Further fixes, SDXL training working (kinda)

401c60e
saunderez 2023-09-03 15:30:40 +10:00
parent eba7cd7c5a
commit 401c60ead4
5 changed files with 28 additions and 27 deletions

View File

@ -643,7 +643,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
weight_decay=weight_decay,
log_every=log_dadapt(True),
no_prox=False,
d0=0.000001
d0=0.000001,
)
return dadaptadan
@ -686,7 +686,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
elif optimizer == "SGD Dadaptation":
from dadaptation import DAdaptSGD
return DAdaptSGD(
dadaptsgd = DAdaptSGD(
params=params_to_optimize,
lr=learning_rate,
weight_decay=weight_decay,
@ -695,6 +695,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
fsdp_in_use=False,
d0=0.000001,
)
return dadptsgd
elif optimizer == "Prodigy":
from pytorch_optimizer import Prodigy
@ -727,7 +728,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
from pytorch_optimizer import CAME
came = CAME(
params = params_to_optimize,
params=params_to_optimize,
lr=learning_rate,
weight_decay=weight_decay,
weight_decouple=True,
@ -740,7 +741,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
elif optimizer == "Lion8bit":
from bitsandbytes.optim import Lion8bit
lion8bit = Lion8bit(
params = params_to_optimize,
params=params_to_optimize,
lr=learning_rate,
betas=(0.9, 0.99),
weight_decay=weight_decay,
@ -754,7 +755,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
elif optimizer == "PagedLion8bit":
from bitsandbytes.optim import PagedLion8bit
pagedLion8bit = PagedLion8bit(
params_to_optimize,
params=params_to_optimize,
lr=learning_rate,
betas=(0.9, 0.99),
weight_decay=0,
@ -768,7 +769,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
elif optimizer == "PagedAdamW8bit":
from bitsandbytes.optim import PagedAdamW8bit
pagedadamw8bit = PagedAdamW8bit(
params_to_optimize,
params=params_to_optimize,
lr=learning_rate,
betas=(0.9, 0.999),
eps=1e-8,
@ -795,6 +796,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
)
def get_noise_scheduler(args):
if args.noise_scheduler == "DEIS":
scheduler_class = DEISMultistepScheduler

View File

@ -381,18 +381,10 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
if args.attention == "xformers" and not shared.force_cpu:
if is_xformers_available():
import xformers
xformerify(unet, False)
xformerify(vae, False)
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.21. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
else:
raise ValueError(
"xformers is not available. Make sure it is installed correctly"
)
xformerify(unet, False)
xformerify(vae, False)
unet = torch2ify(unet)
@ -658,7 +650,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
stop_profiler(profiler)
return result
if train_dataset.train_img_data.count == 0:
if train_dataset.__len__ == 0:
msg = "Please provide a directory with actual images in it."
logger.warning(msg)
status.textinfo = msg
@ -1131,8 +1123,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
elif save_diffusers:
# We are saving weights, we need to ensure revision is saved
if "_tmp" not in weights_dir:
args.save()
args.save()
try:
out_file = None
status.textinfo = (
@ -1140,7 +1131,6 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
)
update_status({"status": status.textinfo})
pbar2.reset(1)
pbar2.set_description("Saving diffusion model")
s_pipeline.save_pretrained(
weights_dir,

View File

@ -13,12 +13,10 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from PIL import features, PngImagePlugin, Image, ExifTags
import os
from typing import List, Tuple, Dict, Union
import numpy as np
import torch
from dreambooth.dataclasses.db_concept import Concept
from dreambooth.dataclasses.prompt_data import PromptData
from helpers.mytqdm import mytqdm
@ -447,8 +445,19 @@ def load_image_directory(db_dir, concept: Concept, is_class: bool = True) -> Lis
return list(zip(img_paths, captions))
def open_image(image_path: str, return_pil: bool = False) -> Union[np.ndarray, Image.Image]:
if return_pil:
return Image.open(image_path)
else:
return np.array(Image.open(image_path))
def trim_image(image: Union[np.ndarray, Image], reso: Tuple[int, int]) -> Union[np.ndarray, Image]:
return image[:reso[0], :reso[1]]
def open_and_trim(image_path: str, reso: Tuple[int, int], return_pil: bool = False) -> Union[np.ndarray, Image]:
# Open image with PIL
return trim_image(open_image(image_path, return_pil), reso) # Open image with PIL
image = Image.open(image_path)
image = rotate_image_straight(image)

View File

@ -28,7 +28,7 @@ def actual_install():
base_dir = os.path.dirname(os.path.realpath(__file__))
try:
repo = git.Repo()
repo = git.Repo
#I decare thisdecafter the eve ts at
revision = repo.rev_parse("HEAD")
except:
@ -116,7 +116,7 @@ def check_bitsandbytes():
"""
if os.name == "nt":
bitsandbytes_version = importlib_metadata.version("bitsandbytes")
if bitsandbytes_version is not "0.41.1":
if bitsandbytes_version != "0.41.1":
try:
pip_install("--force-install","--extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui", "--prefer-binary"",bitsandbytes==0.41.1")
#bnb_src = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bitsandbytes_windows")

View File

@ -11,7 +11,7 @@ tqdm==4.65.0
tomesd~=0.1.2
transformers~=4.32.1; # > 4.26.x causes issues (db extension #1110)
# Get prebuilt Windows wheels from jllllll
bitsandbytes~=0.41.1; sys_platform == 'win32' and platform_machine == 'AMD64' \
bitsandbytes~=0.41.1; sys_platform == 'win32' or platform_machine == 'AMD64' \
--extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui --prefer-binary
# Get Linux and MacOS wheels from PyPi
bitsandbytes~=0.41.1; sys_platform != 'win32' or platform_machine != 'AMD64' --prefer-binary