From 8c932bea0a67818bfc1253a07baae390b04870da Mon Sep 17 00:00:00 2001 From: saunderez Date: Thu, 6 Apr 2023 23:41:53 +1000 Subject: [PATCH] Add ToMESD support - enabled under Testing tab --- .gitignore | 3 ++- dreambooth/dataclasses/db_config.py | 5 +++-- dreambooth/train_dreambooth.py | 24 ++++++++++++++++++++++-- scripts/main.py | 2 ++ 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 51a033a..e460194 100644 --- a/.gitignore +++ b/.gitignore @@ -59,4 +59,5 @@ coverage.xml docs/_build/ # PyBuilder -target/ \ No newline at end of file +target/ +.vscode/launch.json diff --git a/dreambooth/dataclasses/db_config.py b/dreambooth/dataclasses/db_config.py index 419f8b2..d202293 100644 --- a/dreambooth/dataclasses/db_config.py +++ b/dreambooth/dataclasses/db_config.py @@ -33,10 +33,11 @@ class DreamboothConfig(BaseModel): deterministic: bool = False disable_logging: bool = False ema_predict: bool = False + enable_tomesd: bool = False epoch: int = 0 epoch_pause_frequency: int = 0 epoch_pause_time: int = 0 - freeze_clip_normalization: bool = True + freeze_clip_normalization: bool = False gradient_accumulation_steps: int = 1 gradient_checkpointing: bool = True gradient_set_to_none: bool = True @@ -105,7 +106,7 @@ class DreamboothConfig(BaseModel): src: str = "" stop_text_encoder: float = 1.0 strict_tokens: bool = False - tenc_weight_decay: float = 0.00 + tenc_weight_decay: float = 0.01 tenc_grad_clip_norm: float = 0.00 tf32_enable: bool = False train_batch_size: int = 1 diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py index 2b5a297..3278bb8 100644 --- a/dreambooth/train_dreambooth.py +++ b/dreambooth/train_dreambooth.py @@ -16,6 +16,7 @@ import torch import torch.backends.cuda import torch.backends.cudnn import torch.utils.checkpoint +import tomesd from accelerate import Accelerator from accelerate.utils.random import set_seed as set_seed2 from diffusers import ( @@ -77,6 +78,8 @@ 
 dl.set_verbosity_error()
 
 last_samples = []
 last_prompts = []
+
+
 try:
     diff_version = importlib_metadata.version("diffusers")
     version_string = diff_version.split(".")
@@ -163,6 +166,9 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
     result = TrainResult
     result.config = args
+    enable_tomesd = args.enable_tomesd
+    printm(f"ToMeSD enabled: {enable_tomesd}")
+
     set_seed(args.deterministic)
 
     @find_executable_batch_size(
@@ -844,8 +850,12 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
             s_pipeline = s_pipeline.to(accelerator.device)
 
+            if args.enable_tomesd:
+                tomesd.apply_patch(s_pipeline, ratio=0.5)
+
             with accelerator.autocast(), torch.inference_mode():
                 if save_model:
+                    if args.enable_tomesd: tomesd.remove_patch(s_pipeline)
                     # We are saving weights, we need to ensure revision is saved
                     args.save()
                     try:
@@ -885,7 +895,11 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
                         )
                         pbar.update()
 
+                        if args.enable_tomesd:
+                            tomesd.apply_patch(s_pipeline, ratio=0.5)
+
                 elif save_lora:
+                    if args.enable_tomesd: tomesd.remove_patch(s_pipeline)
                     pbar.set_description("Saving Lora Weights...")
                     # setup directory
                     loras_dir = os.path.join(args.model_dir, "loras")
@@ -939,11 +953,15 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
                         compile_checkpoint(args.model_name, reload_models=False, lora_file_name=out_file, log=False, snap_rev=snap_rev, pbar=pbar)
 
             printm("Restored, moved to acc.device.")
+
+            if args.enable_tomesd:
+                tomesd.apply_patch(s_pipeline, ratio=0.5)
+
         except Exception as ex:
             print(f"Exception saving checkpoint/model: {ex}")
             traceback.print_exc()
             pass
-
+        if args.enable_tomesd: tomesd.remove_patch(s_pipeline)
         save_dir = args.model_dir
         if save_image:
             samples = []
@@ -1009,12 +1027,14 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
                             last_prompts.append(prompt)
                         del samples
                         del prompts
-
+                        if args.enable_tomesd:
+                            tomesd.apply_patch(s_pipeline, ratio=0.5)
                     except Exception as em:
                         print(f"Exception saving sample: 
{em}") traceback.print_exc() pass printm("Starting cleanup.") + tomesd.remove_patch(s_pipeline) del s_pipeline if save_image: if "generator" in locals(): diff --git a/scripts/main.py b/scripts/main.py index 36a2030..87746ef 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -986,6 +986,7 @@ def on_ui_tabs(): ) with gr.Tab("Testing", elem_id="TabDebug"): gr.HTML(value="Experimental Settings") + db_enable_tomesd = gr.Checkbox(label="Enable ToMeSD") db_disable_logging = gr.Checkbox(label="Disable Logging") db_deterministic = gr.Checkbox(label="Deterministic") db_ema_predict = gr.Checkbox(label="Use EMA for prediction") @@ -1236,6 +1237,7 @@ def on_ui_tabs(): db_deterministic, db_disable_logging, db_ema_predict, + db_enable_tomesd, db_epoch_pause_frequency, db_epoch_pause_time, db_epochs,