Merge pull request #1158 from saunderez/dev

D-Adapt Additions + OOM Resume stuff + Fixes
pull/1176/head
ArrowM 2023-04-05 17:44:47 -05:00 committed by GitHub
commit bcc266bf29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 210 additions and 68 deletions

View File

@ -105,6 +105,8 @@ class DreamboothConfig(BaseModel):
src: str = ""
stop_text_encoder: float = 1.0
strict_tokens: bool = False
tenc_weight_decay: float = 0.00
tenc_grad_clip_norm: float = 0.00
tf32_enable: bool = False
train_batch_size: int = 1
train_imagic: bool = False
@ -121,6 +123,7 @@ class DreamboothConfig(BaseModel):
def __init__(
self,
model_name: str = "",
model_dir: str = "",
v2: bool = False,
src: str = "",
resolution: int = 512,

View File

@ -37,6 +37,7 @@ logger = logging.get_logger(__name__)
class SchedulerType(Enum):
DADAPT_WITH_WARMUP = "dadapt_with_warmup"
LINEAR = "linear"
LINEAR_WITH_WARMUP = "linear_with_warmup"
COSINE = "cosine"
@ -48,6 +49,39 @@ class SchedulerType(Enum):
CONSTANT_WITH_WARMUP = "constant_with_warmup"
def get_dadapt_with_warmup(optimizer, num_warmup_steps: int = 0, unet_lr: float = 1.0, tenc_lr: float = 1.0):
    """
    Linearly warm each param group's LR multiplier from 0 up to its target over
    ``num_warmup_steps`` steps, then hold it constant. Intended for D-Adaptation
    optimizers, where the "LR" acts as a multiplier on the adapted step size.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate. Expected to
            have two param groups: group 0 for the UNET, group 1 for the TENC.
        num_warmup_steps (`int`, *optional*, defaults to 0):
            The number of steps for the warmup phase.
        unet_lr (`float`, *optional*, defaults to 1.0):
            The learning rate used to control D-Adaptation for the UNET.
        tenc_lr (`float`, *optional*, defaults to 1.0):
            The learning rate used to control D-Adaptation for the TENC.
    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedules for TENC and UNET.
    """
    # NOTE: the previous implementation divided by max(lr, num_warmup_steps),
    # which (a) never scaled the ramp by the target LR and (b) produced a wrong
    # ramp whenever lr > num_warmup_steps. max(1, ...) only guards div-by-zero.
    warmup_div = float(max(1, num_warmup_steps))

    def unet_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return unet_lr * float(current_step) / warmup_div
        return unet_lr

    def tenc_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return tenc_lr * float(current_step) / warmup_div
        return tenc_lr

    return LambdaLR(optimizer, [unet_lambda, tenc_lambda], last_epoch=-1, verbose=False)
# region Newer Schedulers
def get_cosine_annealing_scheduler(
optimizer: Optimizer, max_iter: int = 500, eta_min: float = 1e-6
@ -241,8 +275,7 @@ def get_cosine_schedule_with_warmup(
max(1, num_training_steps - num_warmup_steps)
)
return max(
0.0, 0.5 * (1.0 + math.cos(math.pi *
float(num_cycles) * 2.0 * progress))
0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
@ -369,6 +402,8 @@ def get_scheduler(
power: float = 1.0,
factor: float = 0.5,
scale_pos: float = 0.5,
unet_lr: float = 1.0,
tenc_lr: float = 1.0,
):
"""
Unified API to get any scheduler from its name.
@ -396,6 +431,12 @@ def get_scheduler(
scale_pos (`float`, *optional*, defaults to 0.5):
If a lr scheduler has an adjustment point, this is the percentage of training steps at which to
adjust the LR.
unet_lr (`float`, *optional*, defaults to 1.0):
The learning rate used to control D-Adaptation for the UNET
tenc_lr (`float`, *optional*, defaults to 1.0):
The learning rate used to control D-Adaptation for the TENC
"""
name = SchedulerType(name)
break_steps = int(total_training_steps * scale_pos)
@ -451,6 +492,14 @@ def get_scheduler(
num_cycles=num_cycles,
)
if name == SchedulerType.DADAPT_WITH_WARMUP:
return get_dadapt_with_warmup(
optimizer,
num_warmup_steps=num_warmup_steps,
unet_lr=unet_lr,
tenc_lr=tenc_lr,
)
class UniversalScheduler:
def __init__(
@ -466,9 +515,12 @@ class UniversalScheduler:
lr: float = 1e-6,
min_lr: float = 1e-6,
scale_pos: float = 0.5,
unet_lr: float = 1.0,
tenc_lr: float = 1.0,
):
self.current_step = 0
og_schedulers = [
"dadapt_with_warmup",
"constant_with_warmup",
"linear_with_warmup",
"cosine",
@ -490,6 +542,8 @@ class UniversalScheduler:
power=power,
factor=factor,
scale_pos=scale_pos,
unet_lr=unet_lr,
tenc_lr=tenc_lr,
)
def step(self, steps: int = 1, is_epoch: bool = False):
@ -531,14 +585,6 @@ def get_optimizer(args, params_to_optimize):
weight_decay=args.adamw_weight_decay,
)
# elif args.optimizer == "SGD Dadaptation":
# from dadaptation import DAdaptSGD
# return DAdaptSGD(
# params_to_optimize,
# lr=args.learning_rate,
# weight_decay=args.adamw_weight_decay,
# )
elif args.optimizer == "AdamW Dadaptation":
from dadaptation import DAdaptAdam
return DAdaptAdam(
@ -548,13 +594,14 @@ def get_optimizer(args, params_to_optimize):
decouple=True,
)
# elif args.optimizer == "Adagrad Dadaptation":
# from dadaptation import DAdaptAdaGrad
# return DAdaptAdaGrad(
# params_to_optimize,
# lr=args.learning_rate,
# weight_decay=args.adamw_weight_decay,
# )
elif args.optimizer == "AdanIP Dadaptation":
from dreambooth.dadapt_adan_ip import DAdaptAdanIP
return DAdaptAdanIP(
params_to_optimize,
lr=args.learning_rate,
weight_decay=args.adamw_weight_decay,
log_every=5,
)
elif args.optimizer == "Adan Dadaptation":
from dreambooth.dadapt_adan import DAdaptAdan
@ -562,14 +609,7 @@ def get_optimizer(args, params_to_optimize):
params_to_optimize,
lr=args.learning_rate,
weight_decay=args.adamw_weight_decay,
)
elif args.optimizer == "AdanIP Dadaptation":
from dreambooth.dadapt_adan_ip import DAdaptAdanIP
return DAdaptAdanIP(
params_to_optimize,
lr=args.learning_rate,
weight_decay=args.adamw_weight_decay,
log_every=5,
)

View File

@ -44,6 +44,9 @@ def load_auto_settings():
config = ws.cmd_opts.config
device = ws.device
sd_model = ws.sd_model
in_progress = False
in_progress_epoch = 0
in_progress_step = 0
def set_model(new_model):
global sd_model
@ -162,12 +165,15 @@ class DreamState:
def interrupt(self):
    """Flag the training loop to stop at its next interruption check."""
    self.interrupted = True
    # Clear the resume marker — presumably so a deliberate interrupt is not
    # treated as an OOM crash to resume from (see the trainer's OOM-resume
    # tracking); confirm against the training loop.
    self.in_progress = False
def interrupt_after_save(self):
    """Flag the training loop to stop, but only after the next model save completes."""
    self.interrupted_after_save = True
    # Clear the OOM-resume marker, matching the behavior of interrupt().
    self.in_progress = False
def interrupt_after_epoch(self):
    """Flag the training loop to stop, but only once the current epoch finishes."""
    self.interrupted_after_epoch = True
    # Clear the OOM-resume marker, matching the behavior of interrupt().
    self.in_progress = False
def save_samples(self):
    """Request that sample images be generated at the trainer's next opportunity."""
    self.do_save_samples = True
@ -187,7 +193,10 @@ class DreamState:
"sampling_steps": self.sampling_steps,
"last_status": self.textinfo,
"sample_prompts": self.sample_prompts,
"active": self.active
"active": self.active,
"in_progress": self.in_progress,
"in_progress_epoch": self.in_progress_epoch,
"in_progress_step": self.in_progress_step,
}
return obj
@ -215,6 +224,7 @@ class DreamState:
self.job_count = 0
self.job_no = 0
self.active = False
self.in_progress = False
torch_gc()
def nextjob(self):
@ -297,7 +307,7 @@ def load_vars(root_path = None):
data_path, show_progress_every_n_steps, parallel_processing_allowed, dataset_filename_word_regex, dataset_filename_join_string, \
device_id, state, disable_safe_unpickle, ckptfix, medvram, lowvram, debug, profile_db, sub_quad_q_chunk_size, sub_quad_kv_chunk_size, \
sub_quad_chunk_threshold, CLIP_stop_at_last_layers, sd_model, config, force_cpu, paths, is_auto, device, orig_tensor_to, orig_layer_norm, \
orig_tensor_numpy, extension_path, orig_cumsum, orig_Tensor_cumsum, status, state
orig_tensor_numpy, extension_path, orig_cumsum, orig_Tensor_cumsum, status, state, in_progress, in_progress_epoch, in_progress_step
script_path = os.sep.join(__file__.split(os.sep)[0:-4]) if root_path is None else root_path
logger.debug(f"Script path is {script_path}")
@ -319,6 +329,9 @@ def load_vars(root_path = None):
medvram = False
lowvram = False
debug = False
in_progress = False
in_progress_epoch = 0
in_progress_step = 0
profile_db = False
sub_quad_q_chunk_size = 1024
sub_quad_kv_chunk_size = None
@ -390,6 +403,9 @@ ckptfix = False
medvram = False
lowvram = False
debug = False
in_progress = False
in_progress_epoch = 0
in_progress_step = 0
profile_db = False
sub_quad_q_chunk_size = 1024
sub_quad_kv_chunk_size = None
@ -407,5 +423,6 @@ extension_path = ""
status = None
orig_cumsum = torch.cumsum
orig_Tensor_cumsum = torch.Tensor.cumsum
is_auto = load_auto_settings()
load_vars()

View File

@ -108,6 +108,11 @@ try:
except:
pass
def dadapt(optimizer):
    """
    Return True if ``optimizer`` names a D-Adaptation optimizer whose param
    groups expose a ``d`` estimate (used for the DLR logging/status display).

    @param optimizer: The optimizer name selected in the UI (args.optimizer).
    @return: bool — whether D-Adaptation LR reporting applies.
    """
    # "AdanIP Dadaptation" is included for consistency: it is registered in the
    # optimizer list and treated as a D-Adaptation optimizer by the UI toggle.
    return optimizer in ("AdamW Dadaptation", "Adan Dadaptation", "AdanIP Dadaptation")
def set_seed(deterministic: bool):
if deterministic:
@ -437,11 +442,10 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
params_to_optimize = unet.parameters()
optimizer = get_optimizer(args, params_to_optimize)
optimizer.param_groups[1]["weight_decay"] = args.tenc_weight_decay
optimizer.param_groups[1]["grad_clip_norm"] = args.tenc_grad_clip_norm
noise_scheduler = get_noise_scheduler(args)
# tenc_weight_decay = optimizer.param_groups[1]["weight_decay"]
# tenc_weight_decay = args.adamw_weight_decay + 0.02
def cleanup_memory():
try:
if unet:
@ -569,6 +573,8 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
power=args.lr_power,
factor=args.lr_factor,
scale_pos=lr_scale_pos,
unet_lr=args.lora_learning_rate,
tenc_lr=args.lora_txt_learning_rate,
)
# create ema, fix OOM
@ -665,6 +671,15 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
except Exception as lex:
print(f"Exception loading checkpoint: {lex}")
#if shared.in_progress:
# print(" ***** OOM detected. Resuming from last step *****")
# max_train_steps = max_train_steps - shared.in_progress_step
# max_train_epochs = max_train_epochs - shared.in_progress_epoch
# session_epoch = shared.in_progress_epoch
# text_encoder_epochs = (shared.in_progress_epoch/max_train_epochs)*text_encoder_epochs
#else:
# shared.in_progress = True
print(" ***** Running training *****")
if shared.force_cpu:
print(f" TRAINING WITH CPU ONLY")
@ -1260,6 +1275,10 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
optimizer.zero_grad(set_to_none=args.gradient_set_to_none)
#Track current step and epoch for OOM resume
#shared.in_progress_epoch = global_epoch
#shared.in_progress_step = global_step
allocated = round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)
cached = round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1)
last_lr = lr_scheduler.get_last_lr()[0]
@ -1276,27 +1295,58 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
del noisy_latents
del target
if dadapt(args.optimizer):
dlr_unet = optimizer.param_groups[0]["d"]*optimizer.param_groups[0]["lr"]
dlr_tenc = optimizer.param_groups[1]["d"]*optimizer.param_groups[1]["lr"]
loss_step = loss.detach().item()
loss_total += loss_step
if args.split_loss:
logs = {
"lr": float(last_lr),
"loss": float(loss_step),
"inst_loss": float(instance_loss.detach().item()),
"prior_loss": float(prior_loss.detach().item()),
"vram": float(cached),
}
else:
logs = {
"lr": float(last_lr),
"loss": float(loss_step),
"vram": float(cached),
}
status.textinfo2 = (
f"Loss: {'%.2f' % loss_step}, LR: {'{:.2E}'.format(Decimal(last_lr))}, "
f"VRAM: {allocated}/{cached} GB"
)
if args.split_loss:
if dadapt(args.optimizer):
logs = {
"lr": float(dlr_unet),
#"dlr_tenc": float(dlr_tenc),
"loss": float(loss_step),
"inst_loss": float(instance_loss.detach().item()),
"prior_loss": float(prior_loss.detach().item()),
"vram": float(cached),
}
else:
logs = {
"lr": float(last_lr),
"loss": float(loss_step),
"inst_loss": float(instance_loss.detach().item()),
"prior_loss": float(prior_loss.detach().item()),
"vram": float(cached),
}
else:
if dadapt(args.optimizer):
logs = {
"lr": float(dlr_unet),
#"dlr_tenc": float(dlr_tenc),
"loss": float(loss_step),
"vram": float(cached),
}
else:
logs = {
"lr": float(last_lr),
"loss": float(loss_step),
"vram": float(cached),
}
if dadapt(args.optimizer):
status.textinfo2 = (
f"Loss: {'%.2f' % loss_step}, UNET DLR: {'{:.2E}'.format(Decimal(dlr_unet))}, TENC DLR: {'{:.2E}'.format(Decimal(dlr_tenc))}, "
f"VRAM: {allocated}/{cached} GB"
)
else:
status.textinfo2 = (
f"Loss: {'%.2f' % loss_step}, LR: {'{:.2E}'.format(Decimal(last_lr))}, "
f"VRAM: {allocated}/{cached} GB"
)
progress_bar.update(train_batch_size)
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=args.revision)
@ -1317,6 +1367,9 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
# Log completion message
if training_complete or status.interrupted:
shared.in_progress = False
shared.in_progress_step = 0
shared.in_progress_epoch = 0
print(" Training complete (step check).")
if status.interrupted:
state = "cancelled"
@ -1369,6 +1422,9 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
if status.interrupted:
training_complete = True
print("Training complete, interrupted.")
shared.in_progress = False
shared.in_progress_step = 0
shared.in_progress_epoch = 0
break
time.sleep(1)

View File

@ -620,6 +620,7 @@ def load_model_params(model_name):
"""
@param model_name: The name of the model to load.
@return:
db_model_dir: The model directory
db_model_path: The full path to the model directory
db_revision: The current revision of the model
db_v2: If the model requires a v2 config/compilation
@ -646,14 +647,13 @@ def load_model_params(model_name):
loras = get_lora_models(config)
db_lora_models = gr_update(choices=loras)
msg = f"Selected model: '{model_name}'."
return (
config.model_dir,
config.revision,
config.epoch,
"True" if config.v2 else "False",
"True" if config.has_ema else "False",
"True" if config.has_ema and not config.use_lora else "False",
config.src,
config.shared_diffusers_path,
db_model_snapshots,
@ -1045,6 +1045,9 @@ def debug_buckets(model_name, num_epochs, batch_size):
optimizer = AdamW(
placeholder, lr=args.learning_rate, weight_decay=args.adamw_weight_decay
)
if not args.use_lora and args.lr_scheduler == "dadapt_with_warmup":
args.lora_learning_rate = args.learning_rate
args.lora_txt_learning_rate = args.learning_rate
lr_scheduler = UniversalScheduler(
args.lr_scheduler,
@ -1057,6 +1060,8 @@ def debug_buckets(model_name, num_epochs, batch_size):
factor=args.lr_factor,
scale_pos=args.lr_scale_pos,
min_lr=args.learning_rate_min,
unet_lr=args.lora_learning_rate,
tenc_lr=args.lora_txt_learning_rate,
)
sampler = BucketSampler(dataset, args.train_batch_size, True)

View File

@ -144,6 +144,12 @@ def list_optimizer():
except:
pass
try:
from dreambooth.dadapt_adan_ip import DAdaptAdanIP
optimizer_list.append("AdanIP Dadaptation")
except:
pass
return optimizer_list
@ -180,6 +186,7 @@ def list_schedulers():
"polynomial",
"constant",
"constant_with_warmup",
"dadapt_with_warmup",
]

View File

@ -379,4 +379,4 @@ class LogParser:
except:
pass
print("Cleanup log parse.")
return out_images, out_names
return out_images, out_names

View File

@ -312,11 +312,10 @@ function filterArgs(argsCount, arguments) {
let db_titles = {
"API Key": "Used for securing the Web API. Click the refresh button to the right to (re)generate your key, the trash icon to remove it.",
"AdamW Weight Decay": "The weight decay of the AdamW Optimizer. Values closer to 0 closely match your training dataset, and values closer to 1 generalize more and deviate from your training dataset. Default is 1e-2, values lower than 0.1 are recommended.",
"AdamW Weight Decay": "The weight decay of the AdamW Optimizer. Values closer to 0 closely match your training dataset, and values closer to 1 generalize more and deviate from your training dataset. Default is 1e-2, values lower than 0.1 are recommended. For D-Adaptation values between 0.02 and 0.04 are recommended",
"Amount of time to pause between Epochs (s)": "When 'Pause After N Epochs' is greater than 0, this is the amount of time, in seconds, that training will be paused for",
"Apply Horizontal Flip": "Randomly decide to flip images horizontally.",
"Batch Size": "How many images to process at once per training step?",
"Betas": "The betas of the used by the Dadaptation schedulers. Default is 0.9, 0.999.",
"Cache Latents": "When this box is checked latents will be cached. Caching latents will use more VRAM, but improve training speed.",
"Cancel": "Cancel training.",
"Class Batch Size": "How many classifier/regularization images to generate at once.",
@ -336,11 +335,8 @@ let db_titles = {
"Custom Model Name": "A custom name to use when saving .ckpt and .pt files. Subdirectories will also be named this.",
"Dataset Directory": "The directory containing training images.",
"Debug Buckets": "Examine the instance and class images and report any instance images without corresponding class images.",
"Decouple": "Decouple the weight decay from learning rate.",
"Discord Webhook": "Send training samples to a Discord channel after generation.",
"D0": "Initial D estimate for D-adaptation",
"Existing Prompt Contents": "If using [filewords], this tells the string builder how the existing prompts are formatted.",
"EPS": "The epsilon value to use for the Dadaptation optimizers.",
"Extract EMA Weights": "If EMA weights are saved in a model, these will be extracted instead of the full Unet. Probably not necessary for training or fine-tuning.",
"Freeze CLIP Normalization Layers": "Keep the normalization layers of CLIP frozen during training. Advanced usage, may increase model performance and editability.",
"Generate Ckpt": "Generate a checkpoint at the current training level.",
@ -365,18 +361,18 @@ let db_titles = {
"HuggingFace Token": "Your huggingface token to use for cloning files.",
"Instance Prompt": "A prompt describing the subject. Use [Filewords] to parse image filename/.txt to insert existing prompt here.",
"Instance Token": "When using [filewords], this is the instance identifier that is unique to your subject. Should be a single word.",
"Learning Rate Scheduler": "The learning rate scheduler to use. All schedulers use the provided warmup time except for 'constant'.",
"Learning Rate Scheduler": "The learning rate scheduler to use. All schedulers use the provided warmup time except for 'constant'. For dadapt_with_warmup it 10% total steps is recommended. You may need to add additional epochs to compensate.",
"Learning Rate Warmup Steps": "Number of steps for the warmup in the lr scheduler. LR will start at 0 and increase to this value over the specified number of steps.",
"Learning Rate": "The rate at which the model learns. Default is 2e-6.",
"Learning Rate": "The rate at which the model learns. Default is 2e-6. For optimizers with D-Adaptation recommended learning rate is 1.0",
"Load Settings": "Load last saved training parameters for the model.",
"Log Memory": "Log the current GPU memory usage.",
"Lora Model": "The Lora model to load for continued fine-tuning or checkpoint generation.",
"Use Lora Extended": "Trains the Lora model with resnet layers. This will always improves quality and editability, but leads to bigger files.",
"Lora UNET Rank": "The rank for the Lora UNET (Default 4). Higher values = better quality with large file size. Lower values = sacrifice quality with lower file size. Learning rates work differently at different ranks. Saved loras at high precision (fp32) will lead to larger lora files.",
"Lora Text Encoder Rank": "The rank for the Lora Text Encoder (Default 4). Higher values = better quality with large file size. Lower values = sacrifice quality with lower file size. Learning rates work differently at different ranks. Saved loras at high precision (fp32) will lead to larger lora files.",
"Lora Text Learning Rate": "The learning rate at which to train lora text encoder. Regular learning rate is ignored.",
"Lora Text Learning Rate": "The learning rate at which to train lora text encoder. Regular learning rate is ignored. For optimizers with D-Adaptation recommended LR is 1.0",
"Lora Text Weight": "What percentage of the lora weights should be applied to the text encoder when creating a checkpoint.",
"Lora UNET Learning Rate": "The learning rate at which to train lora unet. Regular learning rate is ignored.",
"Lora UNET Learning Rate": "The learning rate at which to train lora unet. Regular learning rate is ignored. For optimizers with D-Adaptation recommended learning rate is 1.0",
"Lora Weight": "What percentage of the lora weights should be applied to the unet when creating a checkpoint.",
"Max Resolution": "The resolution of input images. When using bucketing, this is the maximum size of image buckets.",
"Max Token Length": "Maximum token length to respect. You probably want to leave this at 75.",
@ -385,8 +381,6 @@ let db_titles = {
"Mixed Precision": "Use FP16 or BF16 (if available) will help improve memory performance. Required when using 'xformers'.",
"Model Path": "The URL to the model on huggingface. Should be in the format of 'developer/model_name'.",
"Model": "The model to train.",
"Momentum": "The momentum to use for the optimizer.",
"NoProx": "How to perform the decoupled weight decay.",
"Name": "The name of the model to create.",
"Number of Hard Resets": "Number of hard resets of the lr in cosine_with_restarts scheduler.",
"Number of Samples to Generate": "How many samples to generate per subject.",
@ -427,6 +421,8 @@ let db_titles = {
"Source Checkpoint": "The source checkpoint to extract for training.",
"Step Ratio of Text Encoder Training": "The number of steps per image (Epoch) to train the text encoder for. Set 0.5 for 50% of the epochs",
"Strict Tokens": "Parses instance prompts separated by the following characters [,;.!?], and prevents breaking up tokens when using the tokenizer. Useful if you have prompts separated by a lot of tags.",
"TENC Grad Clip Norm": "Prevents overfit by clipping gradient norms. Default value is 0.0. Recommended value for Lora is 1.0",
"TENC Weight Decay": "The weight decay for the Text Encoder. Values closer to 0 closely match your training dataset, and values closer to 1 generalize more and deviate from your training dataset. Default is 1e-2. For Dreambooth, recommended value is same as AdamW Weight Decay. For Lora recommended value is 0.01-0.02 higher than AdamW Weight Decay.",
"Total Number of Class/Reg Images": "Total number of classification/regularization images to use. If no images exist, they will be generated. Set to 0 to disable prior preservation.",
"Train Imagic Only": "Uses Imagic for training instead of full dreambooth, useful for training with a single instance image.",
"Train Text Encoder": "Enabling this will provide better results and editability, but cost more VRAM.",

View File

@ -508,7 +508,7 @@ def on_ui_tabs():
label="Learning Rate Warmup Steps",
value=0,
step=5,
maximum=10000,
maximum=1000,
)
with gr.Column():
@ -555,7 +555,7 @@ def on_ui_tabs():
label="Step Ratio of Text Encoder Training",
minimum=0,
maximum=1,
step=0.01,
step=0.05,
value=0,
visible=True,
)
@ -582,8 +582,24 @@ def on_ui_tabs():
label="Weight Decay",
minimum=0,
maximum=1,
step=1e-7,
value=1e-2,
step=0.001,
value=0.01,
visible=True,
)
db_tenc_weight_decay = gr.Slider(
label="TENC Weight Decay",
minimum=0,
maximum=1,
step=0.001,
value=0.01,
visible=True,
)
db_tenc_grad_clip_norm = gr.Slider(
label="TENC Gradient Clip Norm",
minimum=0,
maximum=128,
step=0.25,
value=0,
visible=True,
)
db_pad_tokens = gr.Checkbox(
@ -1287,6 +1303,8 @@ def on_ui_tabs():
db_src,
db_stop_text_encoder,
db_strict_tokens,
db_tenc_grad_clip_norm,
db_tenc_weight_decay,
db_tf32_enable,
db_train_batch_size,
db_train_imagic,
@ -1481,7 +1499,7 @@ def on_ui_tabs():
)
def optimizer_changed(opti):
    """
    Gradio change-handler: show the D-Adaptation LR controls only when a
    D-Adaptation optimizer is selected.

    @param opti: The currently selected optimizer name.
    @return: gr.update toggling visibility of the adaptation LR controls.
    """
    # Only the current optimizer list is kept; the stale duplicate assignment
    # referencing the removed "SGD Dadaptation"/"AdaGrad Dadaptation" is gone.
    show_adapt = opti in ["AdamW Dadaptation", "Adan Dadaptation", "AdanIP Dadaptation"]
    return gr.update(visible=show_adapt)