From e636a8353d3b003eb13e29b3d1984a57609d9737 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Thu, 19 Jan 2023 14:29:48 +0900 Subject: [PATCH] Fixes for most recent webui with ensuring compatibility for previous webui --- patches/clip_hijack.py | 20 +++++++++++++++++--- patches/external_pr/hypernetwork.py | 10 ++++++++-- patches/external_pr/textual_inversion.py | 6 +++++- patches/hypernetwork.py | 2 +- 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/patches/clip_hijack.py b/patches/clip_hijack.py index 3156b25..a826ca4 100644 --- a/patches/clip_hijack.py +++ b/patches/clip_hijack.py @@ -1,5 +1,19 @@ from modules import sd_hijack_clip, sd_hijack, shared -from modules.sd_hijack import StableDiffusionModelHijack, EmbeddingsWithFixes, apply_optimizations, fix_checkpoint +from modules.sd_hijack import StableDiffusionModelHijack, EmbeddingsWithFixes, apply_optimizations +try: + from modules.sd_hijack import fix_checkpoint + def clear_any_hijacks(): + StableDiffusionModelHijack.hijack = default_hijack +except (ModuleNotFoundError, ImportError): + from modules.sd_hijack_checkpoint import add, remove + def fix_checkpoint(): + add() + + def clear_any_hijacks(): + remove() + StableDiffusionModelHijack.hijack = default_hijack + + import ldm.modules.encoders.modules default_hijack = StableDiffusionModelHijack.hijack @@ -16,8 +30,7 @@ def trigger_sd_hijack(enabled, pretrained_key): StableDiffusionModelHijack.hijack = default_hijack -def clear_any_hijacks(): - StableDiffusionModelHijack.hijack = default_hijack + def create_lambda(model): def hijack_lambda(self, m): @@ -43,6 +56,7 @@ def create_lambda(model): fix_checkpoint() + def flatten(el): flattened = [flatten(children) for children in el.children()] res = [el] diff --git a/patches/external_pr/hypernetwork.py b/patches/external_pr/hypernetwork.py index 7c45aa7..ebde4aa 100644 --- a/patches/external_pr/hypernetwork.py +++ b/patches/external_pr/hypernetwork.py @@ -291,7 +291,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi last_saved_file = "" last_saved_image = "" forced_filename = "" - + if hasattr(sd_hijack_checkpoint, 'add'): + sd_hijack_checkpoint.add() pbar = tqdm.tqdm(total=steps - initial_step) try: for i in range((steps - initial_step) * gradient_step): @@ -466,6 +467,8 @@ Last saved image: {html.escape(last_saved_image)}
pbar.close() hypernetwork.eval() shared.parallel_processing_allowed = old_parallel_processing_allowed + if hasattr(sd_hijack_checkpoint, 'remove'): + sd_hijack_checkpoint.remove() report_statistics(loss_dict) filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') hypernetwork.optimizer_name = optimizer_name @@ -755,7 +758,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory, last_saved_file = "" last_saved_image = "" forced_filename = "" - + if hasattr(sd_hijack_checkpoint, 'add'): + sd_hijack_checkpoint.add() pbar = tqdm.tqdm(total=steps - initial_step) try: for i in range((steps - initial_step) * gradient_step): @@ -930,6 +934,8 @@ Last saved image: {html.escape(last_saved_image)}
pbar.close() hypernetwork.eval() shared.parallel_processing_allowed = old_parallel_processing_allowed + if hasattr(sd_hijack_checkpoint, 'remove'): + sd_hijack_checkpoint.remove() if shared.opts.training_enable_tensorboard: tensorboard_log_hyperparameter(tensorboard_writer, lr=learn_rate, GA_steps=gradient_step, diff --git a/patches/external_pr/textual_inversion.py b/patches/external_pr/textual_inversion.py index 34b079a..ae5f019 100644 --- a/patches/external_pr/textual_inversion.py +++ b/patches/external_pr/textual_inversion.py @@ -14,7 +14,7 @@ from .dataset import PersonalizedBase, PersonalizedDataLoader from ..scheduler import CosineAnnealingWarmUpRestarts from ..hnutil import optim_to -from modules import shared, devices, sd_models, images, processing, sd_samplers, sd_hijack +from modules import shared, devices, sd_models, images, processing, sd_samplers, sd_hijack, sd_hijack_checkpoint from modules.textual_inversion.image_embedding import caption_image_overlay, insert_image_data_embed, embedding_to_b64 from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.textual_inversion import save_embedding @@ -299,6 +299,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st img_c = None pbar = tqdm.tqdm(total=steps - initial_step) + if hasattr(sd_hijack_checkpoint, 'add'): + sd_hijack_checkpoint.add() try: for i in range((steps - initial_step) * gradient_step): if scheduler.finished: @@ -490,4 +492,6 @@ Last saved image: {html.escape(last_saved_image)}
pbar.close() shared.sd_model.first_stage_model.to(devices.device) shared.parallel_processing_allowed = old_parallel_processing_allowed + if hasattr(sd_hijack_checkpoint, 'remove'): + sd_hijack_checkpoint.remove() return embedding, filename diff --git a/patches/hypernetwork.py b/patches/hypernetwork.py index 8baa063..cbb0fc5 100644 --- a/patches/hypernetwork.py +++ b/patches/hypernetwork.py @@ -16,7 +16,7 @@ import scripts.xy_grid from modules.shared import opts try: from modules.hashes import sha256 -except ImportError or ModuleNotFoundError: +except (ImportError, ModuleNotFoundError): print("modules.hashes is not found, will use backup module from extension!") from .hashes_backup import sha256