Add ToMeSD support - enabled under Testing tab

pull/1176/head
saunderez 2023-04-06 23:41:53 +10:00
parent bcc266bf29
commit 8c932bea0a
4 changed files with 29 additions and 5 deletions

1
.gitignore vendored
View File

@ -60,3 +60,4 @@ docs/_build/
# PyBuilder
target/
.vscode/launch.json

View File

@ -33,10 +33,11 @@ class DreamboothConfig(BaseModel):
deterministic: bool = False
disable_logging: bool = False
ema_predict: bool = False
enable_tomesd: bool = False
epoch: int = 0
epoch_pause_frequency: int = 0
epoch_pause_time: int = 0
freeze_clip_normalization: bool = True
freeze_clip_normalization: bool = False
gradient_accumulation_steps: int = 1
gradient_checkpointing: bool = True
gradient_set_to_none: bool = True
@ -105,7 +106,7 @@ class DreamboothConfig(BaseModel):
src: str = ""
stop_text_encoder: float = 1.0
strict_tokens: bool = False
tenc_weight_decay: float = 0.00
tenc_weight_decay: float = 0.01
tenc_grad_clip_norm: float = 0.00
tf32_enable: bool = False
train_batch_size: int = 1

View File

@ -16,6 +16,7 @@ import torch
import torch.backends.cuda
import torch.backends.cudnn
import torch.utils.checkpoint
import tomesd
from accelerate import Accelerator
from accelerate.utils.random import set_seed as set_seed2
from diffusers import (
@ -77,6 +78,8 @@ dl.set_verbosity_error()
last_samples = []
last_prompts = []
try:
diff_version = importlib_metadata.version("diffusers")
version_string = diff_version.split(".")
@ -163,6 +166,9 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
result = TrainResult
result.config = args
enable_tomesd = args.enable_tomesd
enable_tomesd = True
set_seed(args.deterministic)
@find_executable_batch_size(
@ -844,8 +850,12 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
s_pipeline = s_pipeline.to(accelerator.device)
printm("Patching model with tomesd.")
tomesd.apply_patch(s_pipeline, ratio=0.5)
with accelerator.autocast(), torch.inference_mode():
if save_model:
tomesd.remove_patch(s_pipeline)
# We are saving weights, we need to ensure revision is saved
args.save()
try:
@ -885,7 +895,11 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
)
pbar.update()
printm("Patching model with tomesd.")
tomesd.apply_patch(s_pipeline, ratio=0.5)
elif save_lora:
tomesd.remove_patch(s_pipeline)
pbar.set_description("Saving Lora Weights...")
# setup directory
loras_dir = os.path.join(args.model_dir, "loras")
@ -939,11 +953,15 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
compile_checkpoint(args.model_name, reload_models=False, lora_file_name=out_file,
log=False, snap_rev=snap_rev, pbar=pbar)
printm("Restored, moved to acc.device.")
printm("Patching model with tomesd.")
tomesd.apply_patch(s_pipeline, ratio=0.5)
except Exception as ex:
print(f"Exception saving checkpoint/model: {ex}")
traceback.print_exc()
pass
tomesd.remove_patch(s_pipeline)
save_dir = args.model_dir
if save_image:
samples = []
@ -1009,12 +1027,14 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
last_prompts.append(prompt)
del samples
del prompts
printm("Patching model with tomesd.")
tomesd.apply_patch(s_pipeline, ratio=0.5)
except Exception as em:
print(f"Exception saving sample: {em}")
traceback.print_exc()
pass
printm("Starting cleanup.")
tomesd.remove_patch(s_pipeline)
del s_pipeline
if save_image:
if "generator" in locals():

View File

@ -986,6 +986,7 @@ def on_ui_tabs():
)
with gr.Tab("Testing", elem_id="TabDebug"):
gr.HTML(value="Experimental Settings")
db_enable_tomesd = gr.Checkbox(label="Enable ToMeSD")
db_disable_logging = gr.Checkbox(label="Disable Logging")
db_deterministic = gr.Checkbox(label="Deterministic")
db_ema_predict = gr.Checkbox(label="Use EMA for prediction")
@ -1236,6 +1237,7 @@ def on_ui_tabs():
db_deterministic,
db_disable_logging,
db_ema_predict,
db_enable_tomesd,
db_epoch_pause_frequency,
db_epoch_pause_time,
db_epochs,