Add ToMESD support - enabled under Testing tab
parent
bcc266bf29
commit
8c932bea0a
|
|
@ -60,3 +60,4 @@ docs/_build/
|
|||
|
||||
# PyBuilder
|
||||
target/
|
||||
.vscode/launch.json
|
||||
|
|
|
|||
|
|
@ -33,10 +33,11 @@ class DreamboothConfig(BaseModel):
|
|||
deterministic: bool = False
|
||||
disable_logging: bool = False
|
||||
ema_predict: bool = False
|
||||
enable_tomesd: bool = False
|
||||
epoch: int = 0
|
||||
epoch_pause_frequency: int = 0
|
||||
epoch_pause_time: int = 0
|
||||
freeze_clip_normalization: bool = True
|
||||
freeze_clip_normalization: bool = False
|
||||
gradient_accumulation_steps: int = 1
|
||||
gradient_checkpointing: bool = True
|
||||
gradient_set_to_none: bool = True
|
||||
|
|
@ -105,7 +106,7 @@ class DreamboothConfig(BaseModel):
|
|||
src: str = ""
|
||||
stop_text_encoder: float = 1.0
|
||||
strict_tokens: bool = False
|
||||
tenc_weight_decay: float = 0.00
|
||||
tenc_weight_decay: float = 0.01
|
||||
tenc_grad_clip_norm: float = 0.00
|
||||
tf32_enable: bool = False
|
||||
train_batch_size: int = 1
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import torch
|
|||
import torch.backends.cuda
|
||||
import torch.backends.cudnn
|
||||
import torch.utils.checkpoint
|
||||
import tomesd
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils.random import set_seed as set_seed2
|
||||
from diffusers import (
|
||||
|
|
@ -77,6 +78,8 @@ dl.set_verbosity_error()
|
|||
last_samples = []
|
||||
last_prompts = []
|
||||
|
||||
|
||||
|
||||
try:
|
||||
diff_version = importlib_metadata.version("diffusers")
|
||||
version_string = diff_version.split(".")
|
||||
|
|
@ -163,6 +166,9 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
|
|||
result = TrainResult
|
||||
result.config = args
|
||||
|
||||
enable_tomesd = args.enable_tomesd
|
||||
enable_tomesd = True
|
||||
|
||||
set_seed(args.deterministic)
|
||||
|
||||
@find_executable_batch_size(
|
||||
|
|
@ -844,8 +850,12 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
|
|||
|
||||
s_pipeline = s_pipeline.to(accelerator.device)
|
||||
|
||||
printm("Patching model with tomesd.")
|
||||
tomesd.apply_patch(s_pipeline, ratio=0.5)
|
||||
|
||||
with accelerator.autocast(), torch.inference_mode():
|
||||
if save_model:
|
||||
tomesd.remove_patch(s_pipeline)
|
||||
# We are saving weights, we need to ensure revision is saved
|
||||
args.save()
|
||||
try:
|
||||
|
|
@ -885,7 +895,11 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
|
|||
)
|
||||
pbar.update()
|
||||
|
||||
printm("Patching model with tomesd.")
|
||||
tomesd.apply_patch(s_pipeline, ratio=0.5)
|
||||
|
||||
elif save_lora:
|
||||
tomesd.remove_patch(s_pipeline)
|
||||
pbar.set_description("Saving Lora Weights...")
|
||||
# setup directory
|
||||
loras_dir = os.path.join(args.model_dir, "loras")
|
||||
|
|
@ -939,11 +953,15 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
|
|||
compile_checkpoint(args.model_name, reload_models=False, lora_file_name=out_file,
|
||||
log=False, snap_rev=snap_rev, pbar=pbar)
|
||||
printm("Restored, moved to acc.device.")
|
||||
|
||||
printm("Patching model with tomesd.")
|
||||
tomesd.apply_patch(s_pipeline, ratio=0.5)
|
||||
|
||||
except Exception as ex:
|
||||
print(f"Exception saving checkpoint/model: {ex}")
|
||||
traceback.print_exc()
|
||||
pass
|
||||
|
||||
tomesd.remove_patch(s_pipeline)
|
||||
save_dir = args.model_dir
|
||||
if save_image:
|
||||
samples = []
|
||||
|
|
@ -1009,12 +1027,14 @@ def main(class_gen_method: str = "Native Diffusers") -> TrainResult:
|
|||
last_prompts.append(prompt)
|
||||
del samples
|
||||
del prompts
|
||||
|
||||
printm("Patching model with tomesd.")
|
||||
tomesd.apply_patch(s_pipeline, ratio=0.5)
|
||||
except Exception as em:
|
||||
print(f"Exception saving sample: {em}")
|
||||
traceback.print_exc()
|
||||
pass
|
||||
printm("Starting cleanup.")
|
||||
tomesd.remove_patch(s_pipeline)
|
||||
del s_pipeline
|
||||
if save_image:
|
||||
if "generator" in locals():
|
||||
|
|
|
|||
|
|
@ -986,6 +986,7 @@ def on_ui_tabs():
|
|||
)
|
||||
with gr.Tab("Testing", elem_id="TabDebug"):
|
||||
gr.HTML(value="Experimental Settings")
|
||||
db_enable_tomesd = gr.Checkbox(label="Enable ToMeSD")
|
||||
db_disable_logging = gr.Checkbox(label="Disable Logging")
|
||||
db_deterministic = gr.Checkbox(label="Deterministic")
|
||||
db_ema_predict = gr.Checkbox(label="Use EMA for prediction")
|
||||
|
|
@ -1236,6 +1237,7 @@ def on_ui_tabs():
|
|||
db_deterministic,
|
||||
db_disable_logging,
|
||||
db_ema_predict,
|
||||
db_enable_tomesd,
|
||||
db_epoch_pause_frequency,
|
||||
db_epoch_pause_time,
|
||||
db_epochs,
|
||||
|
|
|
|||
Loading…
Reference in New Issue