Fixes for most recent webui
with ensuring compatibility for previous webuinoise-scheduler
parent
c216a0fc7f
commit
e636a8353d
|
|
@ -1,5 +1,19 @@
|
|||
from modules import sd_hijack_clip, sd_hijack, shared
|
||||
from modules.sd_hijack import StableDiffusionModelHijack, EmbeddingsWithFixes, apply_optimizations, fix_checkpoint
|
||||
from modules.sd_hijack import StableDiffusionModelHijack, EmbeddingsWithFixes, apply_optimizations
|
||||
try:
|
||||
from modules.sd_hijack import fix_checkpoint
|
||||
def clear_any_hijacks():
|
||||
StableDiffusionModelHijack.hijack = default_hijack
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
from modules.sd_hijack_checkpoint import add, remove
|
||||
def fix_checkpoint():
|
||||
add()
|
||||
|
||||
def clear_any_hijacks():
|
||||
remove()
|
||||
StableDiffusionModelHijack.hijack = default_hijack
|
||||
|
||||
|
||||
import ldm.modules.encoders.modules
|
||||
|
||||
default_hijack = StableDiffusionModelHijack.hijack
|
||||
|
|
@ -16,8 +30,7 @@ def trigger_sd_hijack(enabled, pretrained_key):
|
|||
StableDiffusionModelHijack.hijack = default_hijack
|
||||
|
||||
|
||||
def clear_any_hijacks():
|
||||
StableDiffusionModelHijack.hijack = default_hijack
|
||||
|
||||
|
||||
def create_lambda(model):
|
||||
def hijack_lambda(self, m):
|
||||
|
|
@ -43,6 +56,7 @@ def create_lambda(model):
|
|||
|
||||
fix_checkpoint()
|
||||
|
||||
|
||||
def flatten(el):
|
||||
flattened = [flatten(children) for children in el.children()]
|
||||
res = [el]
|
||||
|
|
|
|||
|
|
@ -291,7 +291,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
|
||||
if hasattr(sd_hijack_checkpoint, 'add'):
|
||||
sd_hijack_checkpoint.add()
|
||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||
try:
|
||||
for i in range((steps - initial_step) * gradient_step):
|
||||
|
|
@ -466,6 +467,8 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
pbar.close()
|
||||
hypernetwork.eval()
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
if hasattr(sd_hijack_checkpoint, 'remove'):
|
||||
sd_hijack_checkpoint.remove()
|
||||
report_statistics(loss_dict)
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
|
|
@ -755,7 +758,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
|
||||
if hasattr(sd_hijack_checkpoint, 'add'):
|
||||
sd_hijack_checkpoint.add()
|
||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||
try:
|
||||
for i in range((steps - initial_step) * gradient_step):
|
||||
|
|
@ -930,6 +934,8 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
pbar.close()
|
||||
hypernetwork.eval()
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
if hasattr(sd_hijack_checkpoint, 'remove'):
|
||||
sd_hijack_checkpoint.remove()
|
||||
if shared.opts.training_enable_tensorboard:
|
||||
tensorboard_log_hyperparameter(tensorboard_writer, lr=learn_rate,
|
||||
GA_steps=gradient_step,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from .dataset import PersonalizedBase, PersonalizedDataLoader
|
|||
from ..scheduler import CosineAnnealingWarmUpRestarts
|
||||
from ..hnutil import optim_to
|
||||
|
||||
from modules import shared, devices, sd_models, images, processing, sd_samplers, sd_hijack
|
||||
from modules import shared, devices, sd_models, images, processing, sd_samplers, sd_hijack, sd_hijack_checkpoint
|
||||
from modules.textual_inversion.image_embedding import caption_image_overlay, insert_image_data_embed, embedding_to_b64
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
from modules.textual_inversion.textual_inversion import save_embedding
|
||||
|
|
@ -299,6 +299,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
img_c = None
|
||||
|
||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||
if hasattr(sd_hijack_checkpoint, 'add'):
|
||||
sd_hijack_checkpoint.add()
|
||||
try:
|
||||
for i in range((steps - initial_step) * gradient_step):
|
||||
if scheduler.finished:
|
||||
|
|
@ -490,4 +492,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
pbar.close()
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
if hasattr(sd_hijack_checkpoint, 'remove'):
|
||||
sd_hijack_checkpoint.remove()
|
||||
return embedding, filename
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import scripts.xy_grid
|
|||
from modules.shared import opts
|
||||
try:
|
||||
from modules.hashes import sha256
|
||||
except ImportError or ModuleNotFoundError:
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
print("modules.hashes is not found, will use backup module from extension!")
|
||||
from .hashes_backup import sha256
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue