Fixes for most recent webui

while ensuring compatibility with the previous webui
noise-scheduler
aria1th 2023-01-19 14:29:48 +09:00
parent c216a0fc7f
commit e636a8353d
4 changed files with 31 additions and 7 deletions

View File

@ -1,5 +1,19 @@
from modules import sd_hijack_clip, sd_hijack, shared
from modules.sd_hijack import StableDiffusionModelHijack, EmbeddingsWithFixes, apply_optimizations, fix_checkpoint
from modules.sd_hijack import StableDiffusionModelHijack, EmbeddingsWithFixes, apply_optimizations
try:
from modules.sd_hijack import fix_checkpoint
def clear_any_hijacks():
StableDiffusionModelHijack.hijack = default_hijack
except (ModuleNotFoundError, ImportError):
from modules.sd_hijack_checkpoint import add, remove
def fix_checkpoint():
add()
def clear_any_hijacks():
remove()
StableDiffusionModelHijack.hijack = default_hijack
import ldm.modules.encoders.modules
default_hijack = StableDiffusionModelHijack.hijack
@ -16,8 +30,7 @@ def trigger_sd_hijack(enabled, pretrained_key):
StableDiffusionModelHijack.hijack = default_hijack
def clear_any_hijacks():
StableDiffusionModelHijack.hijack = default_hijack
def create_lambda(model):
def hijack_lambda(self, m):
@ -43,6 +56,7 @@ def create_lambda(model):
fix_checkpoint()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
res = [el]

View File

@ -291,7 +291,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
if hasattr(sd_hijack_checkpoint, 'add'):
sd_hijack_checkpoint.add()
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps - initial_step) * gradient_step):
@ -466,6 +467,8 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.close()
hypernetwork.eval()
shared.parallel_processing_allowed = old_parallel_processing_allowed
if hasattr(sd_hijack_checkpoint, 'remove'):
sd_hijack_checkpoint.remove()
report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
@ -755,7 +758,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
if hasattr(sd_hijack_checkpoint, 'add'):
sd_hijack_checkpoint.add()
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps - initial_step) * gradient_step):
@ -930,6 +934,8 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.close()
hypernetwork.eval()
shared.parallel_processing_allowed = old_parallel_processing_allowed
if hasattr(sd_hijack_checkpoint, 'remove'):
sd_hijack_checkpoint.remove()
if shared.opts.training_enable_tensorboard:
tensorboard_log_hyperparameter(tensorboard_writer, lr=learn_rate,
GA_steps=gradient_step,

View File

@ -14,7 +14,7 @@ from .dataset import PersonalizedBase, PersonalizedDataLoader
from ..scheduler import CosineAnnealingWarmUpRestarts
from ..hnutil import optim_to
from modules import shared, devices, sd_models, images, processing, sd_samplers, sd_hijack
from modules import shared, devices, sd_models, images, processing, sd_samplers, sd_hijack, sd_hijack_checkpoint
from modules.textual_inversion.image_embedding import caption_image_overlay, insert_image_data_embed, embedding_to_b64
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.textual_inversion import save_embedding
@ -299,6 +299,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
img_c = None
pbar = tqdm.tqdm(total=steps - initial_step)
if hasattr(sd_hijack_checkpoint, 'add'):
sd_hijack_checkpoint.add()
try:
for i in range((steps - initial_step) * gradient_step):
if scheduler.finished:
@ -490,4 +492,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
if hasattr(sd_hijack_checkpoint, 'remove'):
sd_hijack_checkpoint.remove()
return embedding, filename

View File

@ -16,7 +16,7 @@ import scripts.xy_grid
from modules.shared import opts
try:
from modules.hashes import sha256
except ImportError or ModuleNotFoundError:
except (ImportError, ModuleNotFoundError):
print("modules.hashes is not found, will use backup module from extension!")
from .hashes_backup import sha256