Further fixes, SDXL training working (kinda)
parent
eba7cd7c5a
commit
401c60ead4
|
|
@ -643,7 +643,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
weight_decay=weight_decay,
|
||||
log_every=log_dadapt(True),
|
||||
no_prox=False,
|
||||
d0=0.000001
|
||||
d0=0.000001,
|
||||
)
|
||||
return dadaptadan
|
||||
|
||||
|
|
@ -686,7 +686,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
|
||||
elif optimizer == "SGD Dadaptation":
|
||||
from dadaptation import DAdaptSGD
|
||||
return DAdaptSGD(
|
||||
dadaptsgd = DAdaptSGD(
|
||||
params=params_to_optimize,
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
|
|
@ -695,6 +695,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
fsdp_in_use=False,
|
||||
d0=0.000001,
|
||||
)
|
||||
return dadptsgd
|
||||
|
||||
elif optimizer == "Prodigy":
|
||||
from pytorch_optimizer import Prodigy
|
||||
|
|
@ -727,7 +728,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
from pytorch_optimizer import CAME
|
||||
came = CAME(
|
||||
|
||||
params = params_to_optimize,
|
||||
params=params_to_optimize,
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
weight_decouple=True,
|
||||
|
|
@ -740,7 +741,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
elif optimizer == "Lion8bit":
|
||||
from bitsandbytes.optim import Lion8bit
|
||||
lion8bit = Lion8bit(
|
||||
params = params_to_optimize,
|
||||
params=params_to_optimize,
|
||||
lr=learning_rate,
|
||||
betas=(0.9, 0.99),
|
||||
weight_decay=weight_decay,
|
||||
|
|
@ -754,7 +755,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
elif optimizer == "PagedLion8bit":
|
||||
from bitsandbytes.optim import PagedLion8bit
|
||||
pagedLion8bit = PagedLion8bit(
|
||||
params_to_optimize,
|
||||
params=params_to_optimize,
|
||||
lr=learning_rate,
|
||||
betas=(0.9, 0.99),
|
||||
weight_decay=0,
|
||||
|
|
@ -768,7 +769,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
elif optimizer == "PagedAdamW8bit":
|
||||
from bitsandbytes.optim import PagedAdamW8bit
|
||||
pagedadamw8bit = PagedAdamW8bit(
|
||||
params_to_optimize,
|
||||
params=params_to_optimize,
|
||||
lr=learning_rate,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
|
|
@ -795,6 +796,7 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
)
|
||||
|
||||
|
||||
|
||||
def get_noise_scheduler(args):
|
||||
if args.noise_scheduler == "DEIS":
|
||||
scheduler_class = DEISMultistepScheduler
|
||||
|
|
|
|||
|
|
@ -381,18 +381,10 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
if args.attention == "xformers" and not shared.force_cpu:
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
xformerify(unet, False)
|
||||
xformerify(vae, False)
|
||||
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.21. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"xformers is not available. Make sure it is installed correctly"
|
||||
)
|
||||
xformerify(unet, False)
|
||||
xformerify(vae, False)
|
||||
|
||||
unet = torch2ify(unet)
|
||||
|
||||
|
|
@ -658,7 +650,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
stop_profiler(profiler)
|
||||
return result
|
||||
|
||||
if train_dataset.train_img_data.count == 0:
|
||||
if train_dataset.__len__ == 0:
|
||||
msg = "Please provide a directory with actual images in it."
|
||||
logger.warning(msg)
|
||||
status.textinfo = msg
|
||||
|
|
@ -1131,8 +1123,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
|
||||
elif save_diffusers:
|
||||
# We are saving weights, we need to ensure revision is saved
|
||||
if "_tmp" not in weights_dir:
|
||||
args.save()
|
||||
args.save()
|
||||
try:
|
||||
out_file = None
|
||||
status.textinfo = (
|
||||
|
|
@ -1140,7 +1131,6 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
)
|
||||
update_status({"status": status.textinfo})
|
||||
pbar2.reset(1)
|
||||
|
||||
pbar2.set_description("Saving diffusion model")
|
||||
s_pipeline.save_pretrained(
|
||||
weights_dir,
|
||||
|
|
|
|||
|
|
@ -13,12 +13,10 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
|
|||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
from PIL import features, PngImagePlugin, Image, ExifTags
|
||||
|
||||
import os
|
||||
from typing import List, Tuple, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from dreambooth.dataclasses.db_concept import Concept
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from helpers.mytqdm import mytqdm
|
||||
|
|
@ -447,8 +445,19 @@ def load_image_directory(db_dir, concept: Concept, is_class: bool = True) -> Lis
|
|||
return list(zip(img_paths, captions))
|
||||
|
||||
|
||||
|
||||
|
||||
def open_image(image_path: str, return_pil: bool = False) -> Union[np.ndarray, Image.Image]:
|
||||
if return_pil:
|
||||
return Image.open(image_path)
|
||||
else:
|
||||
return np.array(Image.open(image_path))
|
||||
|
||||
def trim_image(image: Union[np.ndarray, Image], reso: Tuple[int, int]) -> Union[np.ndarray, Image]:
|
||||
return image[:reso[0], :reso[1]]
|
||||
|
||||
def open_and_trim(image_path: str, reso: Tuple[int, int], return_pil: bool = False) -> Union[np.ndarray, Image]:
|
||||
# Open image with PIL
|
||||
return trim_image(open_image(image_path, return_pil), reso) # Open image with PIL
|
||||
image = Image.open(image_path)
|
||||
image = rotate_image_straight(image)
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def actual_install():
|
|||
base_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
try:
|
||||
repo = git.Repo()
|
||||
repo = git.Repo
|
||||
#I decare thisdecafter the eve ts at
|
||||
revision = repo.rev_parse("HEAD")
|
||||
except:
|
||||
|
|
@ -116,7 +116,7 @@ def check_bitsandbytes():
|
|||
"""
|
||||
if os.name == "nt":
|
||||
bitsandbytes_version = importlib_metadata.version("bitsandbytes")
|
||||
if bitsandbytes_version is not "0.41.1":
|
||||
if bitsandbytes_version != "0.41.1":
|
||||
try:
|
||||
pip_install("--force-install","--extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui", "--prefer-binary"",bitsandbytes==0.41.1")
|
||||
#bnb_src = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bitsandbytes_windows")
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ tqdm==4.65.0
|
|||
tomesd~=0.1.2
|
||||
transformers~=4.32.1; # > 4.26.x causes issues (db extension #1110)
|
||||
# Get prebuilt Windows wheels from jllllll
|
||||
bitsandbytes~=0.41.1; sys_platform == 'win32' and platform_machine == 'AMD64' \
|
||||
bitsandbytes~=0.41.1; sys_platform == 'win32' or platform_machine == 'AMD64' \
|
||||
--extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui --prefer-binary
|
||||
# Get Linux and MacOS wheels from PyPi
|
||||
bitsandbytes~=0.41.1; sys_platform != 'win32' or platform_machine != 'AMD64' --prefer-binary
|
||||
|
|
|
|||
Loading…
Reference in New Issue