Commit to see if Dev branch is broken
parent
f5fa07bddc
commit
6fce4da365
|
|
@ -50,6 +50,7 @@ class SchedulerType(Enum):
|
|||
CONSTANT = "constant"
|
||||
CONSTANT_WITH_WARMUP = "constant_with_warmup"
|
||||
|
||||
|
||||
def get_rex_scheduler(optimizer: Optimizer, total_training_steps):
|
||||
"""
|
||||
Returns a learning rate scheduler based on the REx (Relative Exploration) algorithm.
|
||||
|
|
@ -61,6 +62,7 @@ def get_rex_scheduler(optimizer: Optimizer, total_training_steps):
|
|||
Returns:
|
||||
A tuple containing the original optimizer object and a lambda function that can be used to create a PyTorch learning rate scheduler.
|
||||
"""
|
||||
|
||||
def lr_lambda(current_step: int):
|
||||
# https://arxiv.org/abs/2107.04197
|
||||
max_lr = 1
|
||||
|
|
@ -68,7 +70,7 @@ def get_rex_scheduler(optimizer: Optimizer, total_training_steps):
|
|||
d = 0.9
|
||||
|
||||
if current_step < total_training_steps:
|
||||
progress = (current_step / total_training_steps)
|
||||
progress = current_step / total_training_steps
|
||||
div = (1 - d) + (d * (1 - progress))
|
||||
return min_lr + (max_lr - min_lr) * ((1 - progress) / div)
|
||||
else:
|
||||
|
|
@ -77,11 +79,9 @@ def get_rex_scheduler(optimizer: Optimizer, total_training_steps):
|
|||
return LambdaLR(optimizer, lr_lambda)
|
||||
|
||||
|
||||
|
||||
|
||||
# region Newer Schedulers
|
||||
def get_cosine_annealing_scheduler(
|
||||
optimizer: Optimizer, max_iter: int = 500, eta_min: float = 1e-6
|
||||
optimizer: Optimizer, max_iter: int = 500, eta_min: float = 1e-6
|
||||
):
|
||||
"""
|
||||
Adjust LR from initial rate to the minimum specified LR over the maximum number of steps.
|
||||
|
|
@ -101,7 +101,7 @@ def get_cosine_annealing_scheduler(
|
|||
|
||||
|
||||
def get_cosine_annealing_warm_restarts_scheduler(
|
||||
optimizer: Optimizer, t_0: int = 25, t_mult: int = 1, eta_min: float = 1e-6
|
||||
optimizer: Optimizer, t_0: int = 25, t_mult: int = 1, eta_min: float = 1e-6
|
||||
):
|
||||
"""
|
||||
Adjust LR from initial rate to the minimum specified LR over the maximum number of steps.
|
||||
|
|
@ -125,7 +125,7 @@ def get_cosine_annealing_warm_restarts_scheduler(
|
|||
|
||||
|
||||
def get_linear_schedule(
|
||||
optimizer: Optimizer, start_factor: float = 0.5, total_iters: int = 500
|
||||
optimizer: Optimizer, start_factor: float = 0.5, total_iters: int = 500
|
||||
):
|
||||
"""
|
||||
Create a schedule with a learning rate that decreases at a linear rate until it reaches the number of total iters,
|
||||
|
|
@ -147,7 +147,7 @@ def get_linear_schedule(
|
|||
|
||||
|
||||
def get_constant_schedule(
|
||||
optimizer: Optimizer, factor: float = 1.0, total_iters: int = 500
|
||||
optimizer: Optimizer, factor: float = 1.0, total_iters: int = 500
|
||||
):
|
||||
"""
|
||||
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
|
||||
|
|
@ -168,9 +168,10 @@ def get_constant_schedule(
|
|||
|
||||
# endregion
|
||||
|
||||
|
||||
# region originals
|
||||
def get_constant_schedule_with_warmup(
|
||||
optimizer: Optimizer, num_warmup_steps: int, min_lr: float
|
||||
optimizer: Optimizer, num_warmup_steps: int, min_lr: float
|
||||
):
|
||||
"""
|
||||
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
|
||||
|
|
@ -198,7 +199,7 @@ def get_constant_schedule_with_warmup(
|
|||
|
||||
|
||||
def get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps, num_training_steps, min_lr, last_epoch=-1
|
||||
optimizer, num_warmup_steps, num_training_steps, min_lr, last_epoch=-1
|
||||
):
|
||||
"""
|
||||
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
|
||||
|
|
@ -234,12 +235,12 @@ def get_linear_schedule_with_warmup(
|
|||
|
||||
|
||||
def get_cosine_schedule_with_warmup(
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: int,
|
||||
num_training_steps: int,
|
||||
min_lr: float,
|
||||
num_cycles: float = 0.5,
|
||||
last_epoch: int = -1,
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: int,
|
||||
num_training_steps: int,
|
||||
min_lr: float,
|
||||
num_cycles: float = 0.5,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
"""
|
||||
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
||||
|
|
@ -279,12 +280,12 @@ def get_cosine_schedule_with_warmup(
|
|||
|
||||
|
||||
def get_cosine_with_hard_restarts_schedule_with_warmup(
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: int,
|
||||
num_training_steps: int,
|
||||
min_lr: float,
|
||||
num_cycles: int = 1,
|
||||
last_epoch: int = -1,
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: int,
|
||||
num_training_steps: int,
|
||||
min_lr: float,
|
||||
num_cycles: int = 1,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
"""
|
||||
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
||||
|
|
@ -326,13 +327,13 @@ def get_cosine_with_hard_restarts_schedule_with_warmup(
|
|||
|
||||
|
||||
def get_polynomial_decay_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps,
|
||||
num_training_steps,
|
||||
min_lr: float,
|
||||
lr_end=1e-7,
|
||||
power=1.0,
|
||||
last_epoch=-1,
|
||||
optimizer,
|
||||
num_warmup_steps,
|
||||
num_training_steps,
|
||||
min_lr: float,
|
||||
lr_end=1e-7,
|
||||
power=1.0,
|
||||
last_epoch=-1,
|
||||
):
|
||||
"""
|
||||
Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
|
||||
|
|
@ -490,11 +491,9 @@ def get_scheduler(
|
|||
)
|
||||
|
||||
if name == SchedulerType.REX:
|
||||
return get_rex_scheduler(
|
||||
optimizer,
|
||||
total_training_steps=total_training_steps
|
||||
)
|
||||
|
||||
return get_rex_scheduler(optimizer, total_training_steps=total_training_steps)
|
||||
|
||||
|
||||
class UniversalScheduler:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -571,8 +570,8 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
try:
|
||||
if optimizer == "Adafactor":
|
||||
from transformers.optimization import Adafactor
|
||||
adafactor = Adafactor(
|
||||
params=params_to_optimize,
|
||||
return Adafactor(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
clip_threshold=1.0,
|
||||
decay_rate=-0.8,
|
||||
|
|
@ -581,25 +580,23 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
scale_parameter=True,
|
||||
warmup_init=False,
|
||||
)
|
||||
return adafactor
|
||||
|
||||
|
||||
elif optimizer == "CAME":
|
||||
from pytorch_optimizer import CAME
|
||||
came = CAME(
|
||||
params=params_to_optimize,
|
||||
return CAME(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
weight_decouple=True,
|
||||
fixed_decay=False,
|
||||
clip_threshold=1.0,
|
||||
ams_bound=False,
|
||||
)
|
||||
return came
|
||||
)
|
||||
|
||||
elif optimizer == "8bit AdamW":
|
||||
from bitsandbytes.optim import AdamW8bit
|
||||
adamw8bit = AdamW8bit(
|
||||
params=params_to_optimize,
|
||||
return AdamW8bit(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
percentile_clipping=100,
|
||||
|
|
@ -608,12 +605,11 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
amsgrad=False,
|
||||
is_paged=False,
|
||||
)
|
||||
return adamw8bit
|
||||
|
||||
elif optimizer == "Paged 8bit AdamW":
|
||||
|
||||
elif optimizer == "Paged 8bit AdamW":
|
||||
from bitsandbytes.optim import PagedAdamW8bit
|
||||
pagedadamw8bit = PagedAdamW8bit(
|
||||
params=params_to_optimize,
|
||||
return PagedAdamW8bit(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
|
|
@ -623,51 +619,47 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
amsgrad=False,
|
||||
paged=True,
|
||||
)
|
||||
return pagedadamw8bit
|
||||
|
||||
elif optimizer == "Apollo":
|
||||
from pytorch_optimizer import Apollo
|
||||
apollo = Apollo(
|
||||
params=params_to_optimize,
|
||||
return Apollo(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
eight_decay_type='l2',
|
||||
eight_decay_type="l2",
|
||||
init_lr=None,
|
||||
rebound='constant',
|
||||
rebound="constant",
|
||||
)
|
||||
return apollo
|
||||
|
||||
|
||||
elif optimizer == "Lion":
|
||||
from pytorch_optimizer import Lion
|
||||
lion = Lion(
|
||||
params=params_to_optimize,
|
||||
return Lion(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
weight_decouple=True,
|
||||
fixed_decay=False,
|
||||
use_gc=False,
|
||||
adanorm=False
|
||||
adanorm=False,
|
||||
)
|
||||
return lion
|
||||
|
||||
|
||||
elif optimizer == "8bit Lion":
|
||||
from bitsandbytes.optim import Lion8bit
|
||||
lion8bit = Lion8bit(
|
||||
params=params_to_optimize,
|
||||
return Lion8bit(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
betas=(0.9, 0.99),
|
||||
betas=(0.9, 0.99),
|
||||
weight_decay=weight_decay,
|
||||
is_paged=False,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
min_8bit_size=4096,
|
||||
)
|
||||
return lion8bit
|
||||
|
||||
|
||||
elif optimizer == "Paged 8bit Lion":
|
||||
from bitsandbytes.optim import PagedLion8bit
|
||||
pagedLion8bit = PagedLion8bit(
|
||||
params=params_to_optimize,
|
||||
return PagedLion8bit(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
betas=(0.9, 0.99),
|
||||
weight_decay=0,
|
||||
|
|
@ -676,12 +668,11 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
is_paged=True,
|
||||
min_8bit_size=4096,
|
||||
)
|
||||
return pagedLion8bit
|
||||
|
||||
elif optimizer == "AdamW Dadaptation":
|
||||
from dadaptation import DAdaptAdam
|
||||
dadaptadam = DAdaptAdam(
|
||||
params=params_to_optimize,
|
||||
return DAdaptAdam(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
decouple=True,
|
||||
|
|
@ -689,77 +680,33 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
log_every=log_dadapt(True),
|
||||
fsdp_in_use=False,
|
||||
)
|
||||
return dadaptadam
|
||||
|
||||
elif optimizer == "Lion Dadaptation":
|
||||
from dadaptation import DAdaptLion
|
||||
dadaptlion = DAdaptLion(
|
||||
params=params_to_optimize,
|
||||
return DAdaptLion(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
log_every=log_dadapt(True),
|
||||
fsdp_in_use=False,
|
||||
d0=0.000001,
|
||||
)
|
||||
return dadaptlion
|
||||
|
||||
elif optimizer == "Adan Dadaptation":
|
||||
from dadaptation import DAdaptAdan
|
||||
dadaptadan = DAdaptAdan(
|
||||
params=params_to_optimize,
|
||||
return DAdaptAdan(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
log_every=log_dadapt(True),
|
||||
no_prox=False,
|
||||
d0=0.000001,
|
||||
)
|
||||
return dadaptadan
|
||||
|
||||
elif optimizer == "AdanIP Dadaptation":
|
||||
from dadaptation.experimental import DAdaptAdanIP
|
||||
dadaptadanip = DAdaptAdanIP(
|
||||
params=params_to_optimize,
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
log_every=log_dadapt(True),
|
||||
no_prox=False,
|
||||
d0=0.000001
|
||||
)
|
||||
return dadaptadanip
|
||||
|
||||
|
||||
elif optimizer == "SGD Dadaptation":
|
||||
from dadaptation import DAdaptSGD
|
||||
dadaptsgd = DAdaptSGD(
|
||||
params=params_to_optimize,
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
log_every=log_dadapt(True),
|
||||
momentum=0.0,
|
||||
fsdp_in_use=False,
|
||||
d0=0.000001,
|
||||
)
|
||||
return dadaptsgd
|
||||
|
||||
elif optimizer == "Prodigy":
|
||||
from pytorch_optimizer import Prodigy
|
||||
prodigy = Prodigy(
|
||||
params=params_to_optimize,
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
safeguard_warmup=False,
|
||||
d0=1e-6,
|
||||
d_coef=1.0,
|
||||
bias_correction=False,
|
||||
fixed_decay=False,
|
||||
weight_decouple=True,
|
||||
)
|
||||
return prodigy
|
||||
|
||||
|
||||
elif optimizer == "Sophia":
|
||||
from pytorch_optimizer import SophiaH
|
||||
sophia = SophiaH(
|
||||
params=params_to_optimize,
|
||||
return DAdaptSGD(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
weight_decouple=True,
|
||||
|
|
@ -767,20 +714,18 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
hessian_distribution="gaussian",
|
||||
p=0.01,
|
||||
)
|
||||
return sophia
|
||||
|
||||
|
||||
elif optimizer == "Tiger":
|
||||
from pytorch_optimizer import Tiger
|
||||
tiger = Tiger(
|
||||
params=params_to_optimize,
|
||||
return Tiger(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
beta = 0.965,
|
||||
beta=0.965,
|
||||
weight_decay=0.01,
|
||||
weight_decouple=True,
|
||||
fixed_decay=False,
|
||||
)
|
||||
return tiger
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception importing {optimizer}: {e}")
|
||||
traceback.print_exc()
|
||||
|
|
@ -796,7 +741,6 @@ def get_optimizer(optimizer: str, learning_rate: float, weight_decay: float, par
|
|||
)
|
||||
|
||||
|
||||
|
||||
def get_noise_scheduler(args):
|
||||
if args.noise_scheduler == "DEIS":
|
||||
scheduler_class = DEISMultistepScheduler
|
||||
|
|
|
|||
|
|
@ -43,8 +43,7 @@ def load_auto_settings():
|
|||
lowvram = ws.cmd_opts.lowvram
|
||||
config = ws.cmd_opts.config
|
||||
device = ws.device
|
||||
current_epoch = 0
|
||||
current_step = 0
|
||||
|
||||
|
||||
def set_model(new_model):
|
||||
global sd_model
|
||||
|
|
@ -162,8 +161,7 @@ class DreamState:
|
|||
time_left_force_display = False
|
||||
active = False
|
||||
new_ui = False
|
||||
current_epoch = 0
|
||||
current_step = 0
|
||||
|
||||
|
||||
def interrupt(self):
|
||||
if self.status_handler:
|
||||
|
|
@ -195,8 +193,6 @@ class DreamState:
|
|||
"last_status": self.textinfo,
|
||||
"sample_prompts": self.sample_prompts,
|
||||
"active": self.active,
|
||||
"current_epoch": self.current_epoch,
|
||||
"current_step": self.current_step,
|
||||
}
|
||||
|
||||
return obj
|
||||
|
|
@ -311,7 +307,7 @@ def load_vars(root_path = None):
|
|||
data_path, show_progress_every_n_steps, parallel_processing_allowed, dataset_filename_word_regex, dataset_filename_join_string, \
|
||||
device_id, state, disable_safe_unpickle, ckptfix, medvram, lowvram, debug, profile_db, sub_quad_q_chunk_size, sub_quad_kv_chunk_size, \
|
||||
sub_quad_chunk_threshold, CLIP_stop_at_last_layers, sd_model, config, force_cpu, paths, is_auto, device, orig_tensor_to, orig_layer_norm, \
|
||||
orig_tensor_numpy, extension_path, orig_cumsum, orig_Tensor_cumsum, status, state, current_epoch, current_step
|
||||
orig_tensor_numpy, extension_path, orig_cumsum, orig_Tensor_cumsum, status, state
|
||||
|
||||
script_path = os.sep.join(__file__.split(os.sep)[0:-4]) if root_path is None else root_path
|
||||
models_path = os.path.join(script_path, "models")
|
||||
|
|
@ -332,8 +328,7 @@ def load_vars(root_path = None):
|
|||
medvram = False
|
||||
lowvram = False
|
||||
debug = False
|
||||
current_epoch = 0
|
||||
current_step = 0
|
||||
|
||||
profile_db = False
|
||||
sub_quad_q_chunk_size = 1024
|
||||
sub_quad_kv_chunk_size = None
|
||||
|
|
@ -405,8 +400,7 @@ ckptfix = False
|
|||
medvram = False
|
||||
lowvram = False
|
||||
debug = False
|
||||
current_epoch = 0
|
||||
current_step = 0
|
||||
|
||||
profile_db = False
|
||||
sub_quad_q_chunk_size = 1024
|
||||
sub_quad_kv_chunk_size = None
|
||||
|
|
|
|||
|
|
@ -13,16 +13,13 @@ import traceback
|
|||
from contextlib import ExitStack
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from accelerate.utils.megatron_lm import prepare_scheduler
|
||||
from numpy import dtype, float32
|
||||
from pandas import Float32Dtype
|
||||
|
||||
import tomesd
|
||||
import torch
|
||||
import torch.backends.cuda
|
||||
import torch.backends.cudnn
|
||||
import torch.nn.functional as F
|
||||
from accelerate import Accelerator, cpu_offload
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils.random import set_seed as set_seed2
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
|
|
@ -59,10 +56,9 @@ from dreambooth.utils.model_utils import (
|
|||
enable_safe_unpickle,
|
||||
xformerify,
|
||||
torch2ify, unet_attn_processors_state_dict,
|
||||
|
||||
)
|
||||
from dreambooth.utils.text_utils import encode_hidden_state
|
||||
from dreambooth.utils.utils import (cleanup, printm, verify_locon_installed,
|
||||
from dreambooth.utils.utils import (cleanup, printm, verify_locon_installed,
|
||||
patch_accelerator_for_fp16_training)
|
||||
from dreambooth.webhook import send_training_update
|
||||
from dreambooth.xattention import optim_to
|
||||
|
|
@ -345,12 +341,20 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
pbar2.update()
|
||||
pbar2.set_postfix(refresh=True)
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
args.get_pretrained_model_name_or_path(),
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
if args.full_mixed_precision:
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
args.get_pretrained_model_name_or_path(),
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
else:
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
args.get_pretrained_model_name_or_path(),
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
|
||||
if args.model_type == "SDXL":
|
||||
# import correct text encoder class
|
||||
|
|
@ -362,11 +366,20 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
pbar2.update()
|
||||
pbar2.set_postfix(refresh=True)
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
||||
args.get_pretrained_model_name_or_path(),
|
||||
subfolder="text_encoder_2",
|
||||
revision=args.revision,
|
||||
torch_dtype=torch.float32, )
|
||||
if args.full_mixed_precision:
|
||||
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
||||
args.get_pretrained_model_name_or_path(),
|
||||
subfolder="text_encoder_2",
|
||||
revision=args.revision,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
else:
|
||||
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
||||
args.get_pretrained_model_name_or_path(),
|
||||
subfolder="text_encoder_2",
|
||||
revision=args.revision,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
|
||||
printm("Created tenc")
|
||||
pbar2.set_description("Loading VAE...")
|
||||
|
|
@ -376,12 +389,20 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
|
||||
pbar2.set_description("Loading unet...")
|
||||
pbar2.update()
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.get_pretrained_model_name_or_path(),
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
if args.full_mixed_precision:
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.get_pretrained_model_name_or_path(),
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
else:
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.get_pretrained_model_name_or_path(),
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
|
||||
if args.attention == "xformers" and not shared.force_cpu:
|
||||
xformerify(unet, use_lora=args.use_lora)
|
||||
|
|
@ -469,6 +490,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
)
|
||||
|
||||
# Create shared unet/tenc learning rate variables
|
||||
|
||||
learning_rate = args.learning_rate
|
||||
txt_learning_rate = args.txt_learning_rate
|
||||
if args.use_lora:
|
||||
|
|
@ -871,8 +893,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
modules.shared.cmd_opts.disable_safe_unpickle = no_safe
|
||||
global_step = resume_step = args.revision
|
||||
resume_from_checkpoint = True
|
||||
first_epoch = args.lifetime_epoch
|
||||
global_epoch = args.lifetime_epoch
|
||||
first_epoch = args.epoch
|
||||
global_epoch = first_epoch
|
||||
except Exception as lex:
|
||||
logger.warning(f"Exception loading checkpoint: {lex}")
|
||||
logger.debug(" ***** Running training *****")
|
||||
|
|
@ -891,13 +913,12 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
logger.debug(f" First resume step: {resume_step}")
|
||||
logger.debug(f" Lora: {args.use_lora}, Optimizer: {args.optimizer}, Prec: {precision}")
|
||||
logger.debug(f" Gradient Checkpointing: {args.gradient_checkpointing}")
|
||||
logger.debug(f" Min SNR Gamma: {args.min_snr_gamma}")
|
||||
logger.debug(f" EMA: {args.use_ema}")
|
||||
logger.debug(f" UNET: {args.train_unet}")
|
||||
logger.debug(f" Freeze CLIP Normalization Layers: {args.freeze_clip_normalization}")
|
||||
logger.debug(f" Unet LR{' (Lora)' if args.use_lora else ''}: {learning_rate}")
|
||||
logger.debug(f" Tenc LR{' (Lora)' if args.use_lora and stop_text_percentage != 0 else ''}: {tenc_learning_rate}")
|
||||
logger.debug(f" Full Mixed Precision: {args.full_mixed_precision}")
|
||||
logger.debug(f" LR{' (Lora)' if args.use_lora else ''}: {learning_rate}")
|
||||
if stop_text_percentage > 0:
|
||||
logger.debug(f" Tenc LR{' (Lora)' if args.use_lora else ''}: {txt_learning_rate}")
|
||||
logger.debug(f" V2: {args.v2}")
|
||||
|
||||
os.environ.__setattr__("CUDA_LAUNCH_BLOCKING", 1)
|
||||
|
|
@ -911,10 +932,14 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
save_canceled = status.interrupted
|
||||
save_image = False
|
||||
save_model = False
|
||||
save_lora = False
|
||||
|
||||
if not save_canceled and not save_completed:
|
||||
# Check to see if the number of epochs since last save is gt the interval
|
||||
if 0 < save_model_interval <= session_epoch - last_model_save:
|
||||
save_model = True
|
||||
if args.use_lora:
|
||||
save_lora = True
|
||||
last_model_save = session_epoch
|
||||
|
||||
# Repeat for sample images
|
||||
|
|
@ -927,9 +952,9 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
if global_step > 0:
|
||||
save_image = True
|
||||
save_model = True
|
||||
save_lora = True
|
||||
|
||||
save_snapshot = False
|
||||
save_lora = args.use_lora
|
||||
|
||||
if is_epoch_check:
|
||||
if shared.status.do_save_samples:
|
||||
|
|
@ -937,6 +962,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
shared.status.do_save_samples = False
|
||||
|
||||
if shared.status.do_save_model:
|
||||
if args.use_lora:
|
||||
save_lora = True
|
||||
save_model = True
|
||||
shared.status.do_save_model = False
|
||||
|
||||
|
|
@ -1012,7 +1039,6 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
|
||||
printm("Creating pipeline.")
|
||||
if args.model_type == "SDXL":
|
||||
|
||||
s_pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.get_pretrained_model_name_or_path(),
|
||||
unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True),
|
||||
|
|
@ -1025,6 +1051,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
vae=vae.to(accelerator.device),
|
||||
torch_dtype=weight_dtype,
|
||||
revision=args.revision,
|
||||
safety_checker=None,
|
||||
requires_safety_checker=None,
|
||||
)
|
||||
xformerify(s_pipeline.unet,use_lora=args.use_lora)
|
||||
else:
|
||||
|
|
@ -1037,6 +1065,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
vae=vae,
|
||||
torch_dtype=weight_dtype,
|
||||
revision=args.revision,
|
||||
safety_checker=None,
|
||||
requires_safety_checker=None,
|
||||
)
|
||||
xformerify(s_pipeline.unet,use_lora=args.use_lora)
|
||||
xformerify(s_pipeline.vae,use_lora=args.use_lora)
|
||||
|
|
@ -1124,10 +1154,12 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
s_pipeline.scheduler = get_scheduler_class("UniPCMultistep").from_config(
|
||||
s_pipeline.scheduler.config)
|
||||
s_pipeline.scheduler.config.solver_type = "bh2"
|
||||
save_lora = False
|
||||
|
||||
elif save_diffusers:
|
||||
# We are saving weights, we need to ensure revision is saved
|
||||
args.save()
|
||||
if "_tmp" not in weights_dir:
|
||||
args.save()
|
||||
try:
|
||||
out_file = None
|
||||
status.textinfo = (
|
||||
|
|
@ -1527,7 +1559,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.s(latents, noise, timesteps)
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
|
|
@ -1544,9 +1576,17 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
if args.model_type != "SDXL":
|
||||
# TODO: set a prior preservation flag and use that to ensure this ony happens in dreambooth
|
||||
if not args.split_loss and not with_prior_preservation:
|
||||
loss = instance_loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(),
|
||||
reduction="mean")
|
||||
loss *= batch["loss_avg"]
|
||||
if args.min_snr_gamma == 0.0:
|
||||
loss = instance_loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
else:
|
||||
snr = compute_snr(timesteps)
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
loss *= batch["loss_avg"]
|
||||
else:
|
||||
# Predict the noise residual
|
||||
if model_pred.shape[1] == 6:
|
||||
|
|
@ -1556,7 +1596,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
|
||||
|
||||
# Compute instance loss
|
||||
if args.min_snr_gamma == 0.0:
|
||||
# Compute instance loss
|
||||
|
|
@ -1570,13 +1610,24 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(),
|
||||
reduction="mean")
|
||||
# Add the prior loss to the instance loss.
|
||||
loss = loss + args.prior_loss_weight * prior_loss
|
||||
else:
|
||||
# Compute loss
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
if args.min_snr_gamma == 0.0:
|
||||
# Compute loss
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
else:
|
||||
# Calculate loss with min snr
|
||||
snr = compute_snr(timesteps)
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
else:
|
||||
if with_prior_preservation:
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
|
|
@ -1598,27 +1649,19 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
|
||||
# Compute prior loss
|
||||
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
||||
|
||||
# Add the prior loss to the instance loss.
|
||||
loss = loss + args.prior_loss_weight * prior_loss
|
||||
else:
|
||||
if args.min_snr_gamma == 0.0:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
else:
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps)
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if accelerator.sync_gradients and not args.use_lora:
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -7,12 +7,18 @@ import random
|
|||
import re
|
||||
import sys
|
||||
from io import StringIO
|
||||
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
from PIL import features, PngImagePlugin, Image, ExifTags
|
||||
|
||||
import os
|
||||
from typing import List, Tuple, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from dreambooth.dataclasses.db_concept import Concept
|
||||
from dreambooth.dataclasses.prompt_data import PromptData
|
||||
from helpers.mytqdm import mytqdm
|
||||
|
|
@ -441,18 +447,8 @@ def load_image_directory(db_dir, concept: Concept, is_class: bool = True) -> Lis
|
|||
return list(zip(img_paths, captions))
|
||||
|
||||
|
||||
|
||||
|
||||
def open_image(image_path: str, return_pil: bool = False) -> Union[np.ndarray, Image.Image]:
|
||||
if return_pil:
|
||||
return Image.open(image_path)
|
||||
else:
|
||||
return np.array(Image.open(image_path))
|
||||
|
||||
def trim_image(image: Union[np.ndarray, Image.Image], reso: Tuple[int, int]) -> Union[np.ndarray, Image.Image]:
|
||||
return image[:reso[0], :reso[1]]
|
||||
|
||||
def open_and_trim(image_path: str, reso: Tuple[int, int], return_pil: bool = False) -> Union[np.ndarray, Image.Image]:
|
||||
def open_and_trim(image_path: str, reso: Tuple[int, int], return_pil: bool = False) -> Union[np.ndarray, Image]:
|
||||
# Open image with PIL
|
||||
image = Image.open(image_path)
|
||||
image = rotate_image_straight(image)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,7 @@ import sys
|
|||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from diffusers.utils import is_xformers_available
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from dreambooth import shared # noqa
|
||||
|
|
@ -293,4 +292,4 @@ def torch2ify(unet):
|
|||
return unet
|
||||
|
||||
def is_xformers_available():
|
||||
import xformers
|
||||
pass
|
||||
|
|
@ -123,13 +123,13 @@ def list_optimizer():
|
|||
pass
|
||||
try:
|
||||
if shared.device.type != "mps":
|
||||
from bitsandbytes.optim.adamw import AdamW8bit
|
||||
from bitsandbytes.optim import AdamW8bit
|
||||
optimizer_list.append("8bit AdamW")
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
from bitsandbytes.optim.adamw import PagedAdamW8bit
|
||||
from bitsandbytes.optim import PagedAdamW8bit
|
||||
optimizer_list.append("Paged 8bit AdamW")
|
||||
except:
|
||||
pass
|
||||
|
|
@ -171,13 +171,13 @@ def list_optimizer():
|
|||
pass
|
||||
|
||||
try:
|
||||
from bitsandbytes.optim.lion import Lion8bit
|
||||
from bitsandbytes.optim import Lion8bit
|
||||
optimizer_list.append("8bit Lion")
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
from bitsandbytes.optim.lion import PagedLion8bit
|
||||
from bitsandbytes.optim import PagedLion8bit
|
||||
optimizer_list.append("Paged 8bit Lion")
|
||||
except:
|
||||
pass
|
||||
|
|
|
|||
116
postinstall.py
116
postinstall.py
|
|
@ -33,7 +33,9 @@ def actual_install():
|
|||
except:
|
||||
revision = ""
|
||||
|
||||
print("If submitting an issue on github, please provide the full startup log for debugging purposes.")
|
||||
print(
|
||||
"If submitting an issue on github, please provide the full startup log for debugging purposes."
|
||||
)
|
||||
print("")
|
||||
print("Initializing Dreambooth")
|
||||
print(f"Dreambooth revision: {revision}")
|
||||
|
|
@ -53,7 +55,7 @@ def pip_install(*args):
|
|||
output = subprocess.check_output(
|
||||
[sys.executable, "-m", "pip", "install"] + list(args),
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
)
|
||||
for line in output.decode().split("\n"):
|
||||
if "Successfully installed" in line:
|
||||
print(line)
|
||||
|
|
@ -61,7 +63,9 @@ def pip_install(*args):
|
|||
|
||||
def install_requirements():
|
||||
dreambooth_skip_install = os.environ.get("DREAMBOOTH_SKIP_INSTALL", False)
|
||||
req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
|
||||
req_file = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "requirements.txt"
|
||||
)
|
||||
req_file_startup_arg = os.environ.get("REQS_FILE", "requirements_versions.txt")
|
||||
|
||||
if dreambooth_skip_install or req_file == req_file_startup_arg:
|
||||
|
|
@ -74,12 +78,20 @@ def install_requirements():
|
|||
try:
|
||||
pip_install("-r", req_file)
|
||||
|
||||
if has_diffusers and has_tqdm and Version(transformers_version) < Version("4.26.1"):
|
||||
if (
|
||||
has_diffusers
|
||||
and has_tqdm
|
||||
and Version(transformers_version) < Version("4.26.1")
|
||||
):
|
||||
print()
|
||||
print("Does your project take forever to startup?")
|
||||
print("Repetitive dependency installation may be the reason.")
|
||||
print("Automatic1111's base project sets strict requirements on outdated dependencies.")
|
||||
print("If an extension is using a newer version, the dependency is uninstalled and reinstalled twice every startup.")
|
||||
print(
|
||||
"Automatic1111's base project sets strict requirements on outdated dependencies."
|
||||
)
|
||||
print(
|
||||
"If an extension is using a newer version, the dependency is uninstalled and reinstalled twice every startup."
|
||||
)
|
||||
print()
|
||||
except subprocess.CalledProcessError as grepexc:
|
||||
error_msg = grepexc.stdout.decode()
|
||||
|
|
@ -118,23 +130,34 @@ def check_bitsandbytes():
|
|||
if bitsandbytes_version != "0.41.1":
|
||||
try:
|
||||
print("Installing bitsandbytes")
|
||||
pip_install("--force-install","==prefer-binary","--extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui","bitsandbytes==0.41.1")
|
||||
pip_install(
|
||||
"--force-install",
|
||||
"==prefer-binary",
|
||||
"--extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui",
|
||||
"bitsandbytes==0.41.1",
|
||||
)
|
||||
except:
|
||||
print("Bitsandbytes 0.41.1 installation failed.")
|
||||
print("Some features such as 8bit optimizers will be unavailable")
|
||||
print("Please install manually with")
|
||||
print("'python -m pip install bitsandbytes==0.41.1 --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui --prefer-binary --force-install'")
|
||||
print(
|
||||
"'python -m pip install bitsandbytes==0.41.1 --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui --prefer-binary --force-install'"
|
||||
)
|
||||
pass
|
||||
else:
|
||||
if bitsandbytes_version != "0.41.1":
|
||||
try:
|
||||
print("Installing bitsandbytes")
|
||||
pip_install("--force-install","--prefer-binary","bitsandbytes==0.41.1")
|
||||
pip_install(
|
||||
"--force-install", "--prefer-binary", "bitsandbytes==0.41.1"
|
||||
)
|
||||
except:
|
||||
print("Bitsandbytes 0.41.1 installation failed")
|
||||
print("Some features such as 8bit optimizers will be unavailable")
|
||||
print("Install manually with")
|
||||
print("'python -m pip install bitsandbytes==0.41.1 --prefer-binary --force-install'")
|
||||
print(
|
||||
"'python -m pip install bitsandbytes==0.41.1 --prefer-binary --force-install'"
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -149,9 +172,10 @@ class Dependency:
|
|||
def check_versions():
|
||||
import platform
|
||||
from sys import platform as sys_platform
|
||||
is_mac = sys_platform == 'darwin' and platform.machine() == 'arm64'
|
||||
|
||||
#Probably a bad idea but update ALL the dependencies
|
||||
is_mac = sys_platform == "darwin" and platform.machine() == "arm64"
|
||||
|
||||
# Probably a bad idea but update ALL the dependencies
|
||||
dependencies = [
|
||||
Dependency(module="xformers", version="0.0.21", required=False),
|
||||
Dependency(module="torch", version="1.13.1" if is_mac else "2.0.1+cu118"),
|
||||
|
|
@ -159,7 +183,7 @@ def check_versions():
|
|||
Dependency(module="accelerate", version="0.22.0"),
|
||||
Dependency(module="diffusers", version="0.20.1"),
|
||||
Dependency(module="transformers", version="4.25.1"),
|
||||
Dependency(module="bitsandbytes", version="0.41.1"),
|
||||
Dependency(module="bitsandbytes", version="0.41.1"),
|
||||
]
|
||||
|
||||
launch_errors = []
|
||||
|
|
@ -180,13 +204,21 @@ def check_versions():
|
|||
required_version = dependency.version
|
||||
required_comparison = dependency.version_comparison
|
||||
|
||||
if required_comparison == "min" and Version(installed_ver) < Version(required_version):
|
||||
if required_comparison == "min" and Version(installed_ver) < Version(
|
||||
required_version
|
||||
):
|
||||
if dependency.required:
|
||||
launch_errors.append(f"{module} is below the required {required_version} version.")
|
||||
launch_errors.append(
|
||||
f"{module} is below the required {required_version} version."
|
||||
)
|
||||
print(f"[!] {module} version {installed_ver} installed.")
|
||||
|
||||
elif required_comparison == "exact" and Version(installed_ver) != Version(required_version):
|
||||
launch_errors.append(f"{module} is not the required {required_version} version.")
|
||||
elif required_comparison == "exact" and Version(installed_ver) != Version(
|
||||
required_version
|
||||
):
|
||||
launch_errors.append(
|
||||
f"{module} is not the required {required_version} version."
|
||||
)
|
||||
print(f"[!] {module} version {installed_ver} installed.")
|
||||
|
||||
else:
|
||||
|
|
@ -204,7 +236,7 @@ def check_versions():
|
|||
|
||||
def print_requirement_installation_error(err):
|
||||
print("# Requirement installation exception:")
|
||||
for line in err.split('\n'):
|
||||
for line in err.split("\n"):
|
||||
line = line.strip()
|
||||
if line:
|
||||
print(line)
|
||||
|
|
@ -213,27 +245,45 @@ def print_requirement_installation_error(err):
|
|||
def print_xformers_installation_error(err):
|
||||
torch_ver = importlib_metadata.version("torch")
|
||||
print()
|
||||
print("#######################################################################################################")
|
||||
print("# XFORMERS ISSUE DETECTED #")
|
||||
print("#######################################################################################################")
|
||||
print(
|
||||
"#######################################################################################################"
|
||||
)
|
||||
print(
|
||||
"# XFORMERS ISSUE DETECTED #"
|
||||
)
|
||||
print(
|
||||
"#######################################################################################################"
|
||||
)
|
||||
print("#")
|
||||
print(f"# Dreambooth could not find a compatible version of xformers (>= 0.0.21 built with torch {torch_ver})")
|
||||
print("# xformers will not be available for Dreambooth. Consider upgrading to Torch 2.")
|
||||
print(
|
||||
f"# Dreambooth could not find a compatible version of xformers (>= 0.0.21 built with torch {torch_ver})"
|
||||
)
|
||||
print(
|
||||
"# xformers will not be available for Dreambooth. Consider upgrading to Torch 2."
|
||||
)
|
||||
print("#")
|
||||
print("# Xformers installation exception:")
|
||||
for line in err.split('\n'):
|
||||
for line in err.split("\n"):
|
||||
line = line.strip()
|
||||
if line:
|
||||
print(line)
|
||||
print("#")
|
||||
print("#######################################################################################################")
|
||||
print(
|
||||
"#######################################################################################################"
|
||||
)
|
||||
|
||||
|
||||
def print_launch_errors(launch_errors):
|
||||
print()
|
||||
print("#######################################################################################################")
|
||||
print("# LIBRARY ISSUE DETECTED #")
|
||||
print("#######################################################################################################")
|
||||
print(
|
||||
"#######################################################################################################"
|
||||
)
|
||||
print(
|
||||
"# LIBRARY ISSUE DETECTED #"
|
||||
)
|
||||
print(
|
||||
"#######################################################################################################"
|
||||
)
|
||||
print("#")
|
||||
print("# " + "\n# ".join(launch_errors))
|
||||
print("#")
|
||||
|
|
@ -242,20 +292,26 @@ def print_launch_errors(launch_errors):
|
|||
print("# TROUBLESHOOTING")
|
||||
print("# 1. Fully restart your project (not just the webpage)")
|
||||
print("# 2. Update your A1111 project and extensions")
|
||||
print("# 3. Dreambooth requirements should have installed automatically, but you can manually install them")
|
||||
print(
|
||||
"# 3. Dreambooth requirements should have installed automatically, but you can manually install them"
|
||||
)
|
||||
print("# by running the following 4 commands from the A1111 project root:")
|
||||
print("cd venv/Scripts")
|
||||
print("activate")
|
||||
print("cd ../..")
|
||||
print("pip install -r ./extensions/sd_dreambooth_extension/requirements.txt")
|
||||
print("#######################################################################################################")
|
||||
print(
|
||||
"#######################################################################################################"
|
||||
)
|
||||
|
||||
|
||||
def check_torch_unsafe_load():
|
||||
try:
|
||||
from modules import safe
|
||||
|
||||
safe.load = safe.unsafe_torch_load
|
||||
import torch
|
||||
|
||||
torch.load = safe.unsafe_torch_load
|
||||
except:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -521,7 +521,7 @@ def on_ui_tabs():
|
|||
label="Max Resolution",
|
||||
step=64,
|
||||
minimum=128,
|
||||
value=512,
|
||||
value=1024,
|
||||
maximum=2048,
|
||||
elem_id="max_res",
|
||||
)
|
||||
|
|
@ -570,7 +570,7 @@ def on_ui_tabs():
|
|||
label="Offset Noise",
|
||||
minimum=-1,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
step=0.00001,
|
||||
value=0,
|
||||
)
|
||||
db_freeze_clip_normalization = gr.Checkbox(
|
||||
|
|
@ -612,8 +612,8 @@ def on_ui_tabs():
|
|||
db_min_snr_gamma = gr.Slider(
|
||||
label="Min SNR Gamma",
|
||||
minimum=0,
|
||||
maximum=10,
|
||||
step=0.1,
|
||||
maximum=20,
|
||||
step=0.01,
|
||||
visible=True,
|
||||
)
|
||||
db_pad_tokens = gr.Checkbox(
|
||||
|
|
|
|||
Loading…
Reference in New Issue