diff --git a/patches/external_pr/dataset.py b/patches/external_pr/dataset.py new file mode 100644 index 0000000..16d0038 --- /dev/null +++ b/patches/external_pr/dataset.py @@ -0,0 +1,190 @@ +# source:https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4886/files + +import os +import numpy as np +import PIL +import torch +from PIL import Image +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms + +from ..hnutil import get_closest +import random +import tqdm +from modules import devices, shared +import re + +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +re_numbers_at_start = re.compile(r"^[-\d]+\s*") + + +class DatasetEntry: + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, + cond_text=None, pixel_values=None): + self.filename = filename + self.filename_text = filename_text + self.latent_dist = latent_dist + self.latent_sample = latent_sample + self.cond = cond + self.cond_text = cond_text + self.pixel_values = pixel_values + + +class PersonalizedBase(Dataset): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, + cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, + shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'): + re_word = re.compile(shared.opts.dataset_filename_word_regex) if len( + shared.opts.dataset_filename_word_regex) > 0 else None + + self.placeholder_token = placeholder_token + + self.width = width + self.height = height + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + self.dataset = [] + + with open(template_file, "r") as file: + lines = [x.strip() for x in file.readlines()] + + self.lines = lines + + assert data_root, 'dataset directory not specified' + assert os.path.isdir(data_root), "Dataset directory doesn't exist" + assert os.listdir(data_root), "Dataset directory is empty" + + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] + + self.shuffle_tags = shuffle_tags + self.tag_drop_out = tag_drop_out + + print("Preparing dataset...") + for path in tqdm.tqdm(self.image_paths): + if shared.state.interrupted: + raise Exception("inturrupted") + try: # apply variable size here + image = Image.open(path).convert('RGB') + w, h = image.size + r = max(1, w / self.width, h / self.height) # divide by this + w, h = int(w/r), int(h/r) + w, h = get_closest(w), get_closest(h) + image = image.resize((w,h), PIL.Image.LANCZOS) + except Exception: + continue + + text_filename = os.path.splitext(path)[0] + ".txt" + filename = os.path.basename(path) + + if os.path.exists(text_filename): + with open(text_filename, "r", encoding="utf8") as file: + filename_text = file.read() + else: + filename_text = os.path.splitext(filename)[0] + filename_text = re.sub(re_numbers_at_start, '', filename_text) + if re_word: + tokens = re_word.findall(filename_text) + filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) + + npimage = np.array(image).astype(np.uint8) + npimage = (npimage / 127.5 - 1.0).astype(np.float32) + + torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32) + latent_sample = None + + with torch.autocast("cuda"): + latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) + + if latent_sampling_method == "once" or ( + latent_sampling_method == "deterministic" and not isinstance(latent_dist, + DiagonalGaussianDistribution)): + latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) + latent_sampling_method = "once" + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + elif latent_sampling_method == "deterministic": + # Works only for DiagonalGaussianDistribution + latent_dist.std = 0 + latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + elif latent_sampling_method == "random": + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) + + if not (self.tag_drop_out != 0 or self.shuffle_tags): + entry.cond_text = self.create_text(filename_text) + + if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): + with torch.autocast("cuda"): + entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) + + self.dataset.append(entry) + del torchdata + del latent_dist + del latent_sample + + self.length = len(self.dataset) + assert self.length > 0, "No images have been found in the dataset." + self.batch_size = min(batch_size, self.length) + self.gradient_step = min(gradient_step, self.length // self.batch_size) + self.latent_sampling_method = latent_sampling_method + + def create_text(self, filename_text): + text = random.choice(self.lines) + text = text.replace("[name]", self.placeholder_token) + tags = filename_text.split(',') + if self.tag_drop_out != 0: + tags = [t for t in tags if random.random() > self.tag_drop_out] + if self.shuffle_tags: + random.shuffle(tags) + text = text.replace("[filewords]", ','.join(tags)) + return text + + def __len__(self): + return self.length + + def __getitem__(self, i): + entry = self.dataset[i] + if self.tag_drop_out != 0 or self.shuffle_tags: + entry.cond_text = self.create_text(entry.filename_text) + if self.latent_sampling_method == "random": + entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) + return entry + + +class PersonalizedDataLoader(DataLoader): + def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): + super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, + pin_memory=pin_memory) + if latent_sampling_method == "random": + self.collate_fn = collate_wrapper_random + else: + self.collate_fn = collate_wrapper + + +class BatchLoader: + def __init__(self, data): + self.cond_text = [entry.cond_text for entry in data] + self.cond = [entry.cond for entry in data] + self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) + # self.emb_index = [entry.emb_index for entry in data] + # print(self.latent_sample.device) + + def pin_memory(self): + self.latent_sample = self.latent_sample.pin_memory() + return self + + +def collate_wrapper(batch): + return BatchLoader(batch) + + +class BatchLoaderRandom(BatchLoader): + def __init__(self, data): + super().__init__(data) + + def pin_memory(self): + return self + + +def collate_wrapper_random(batch): + return BatchLoaderRandom(batch) \ No newline at end of file diff --git a/patches/external_pr/hypernetwork.py b/patches/external_pr/hypernetwork.py new file mode 100644 index 0000000..fa6b0d1 --- /dev/null +++ b/patches/external_pr/hypernetwork.py @@ -0,0 +1,298 @@ +import csv +import datetime +import glob +import html +import os +import sys +import traceback +import inspect + +import torch +import tqdm +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ExponentialLR + +from modules import shared, sd_models, devices, processing, sd_samplers +from modules.hypernetworks.hypernetwork import optimizer_dict, stack_conds, save_hypernetwork +from modules.textual_inversion.learn_schedule import LearnRateScheduler +from .textual_inversion import validate_train_inputs, write_loss +from ..hypernetwork import Hypernetwork, load_hypernetwork +from . import sd_hijack_checkpoint +from .dataset import PersonalizedBase,PersonalizedDataLoader + + +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, + training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, + create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, + preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, + preview_width, preview_height, + use_beta_scheduler=False, beta_repeat_epoch=4000, min_lr=1e-7, gamma_rate=1): + # images allows training previews to have infotext. Importing it at the top causes a circular import problem. + from modules import images + try: + beta_repeat_epoch = int(beta_repeat_epoch) + assert beta_repeat_epoch > 0, f"Cannot use too small cycle {beta_repeat_epoch}!" + min_lr = float(min_lr) + assert min_lr < 1, f"Cannot use minimum lr with {min_lr}!" + gamma_rate = float(gamma_rate) + assert 0 <= gamma_rate <= 1, f"Cannot use gamma rate with {gamma_rate}!" + except ValueError: + raise RuntimeError("Cannot use advanced LR scheduler settings!") + save_hypernetwork_every = save_hypernetwork_every or 0 + create_image_every = create_image_every or 0 + validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, + template_file, steps, save_hypernetwork_every, create_image_every, + log_directory, name="hypernetwork") + + load_hypernetwork(hypernetwork_name) + assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!" + if not isinstance(shared.loaded_hypernetwork, Hypernetwork): + raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!") + + shared.state.textinfo = "Initializing hypernetwork training..." + shared.state.job_count = steps + + hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) + unload = shared.opts.unload_models_when_training + + if save_hypernetwork_every > 0: + hypernetwork_dir = os.path.join(log_directory, "hypernetworks") + os.makedirs(hypernetwork_dir, exist_ok=True) + else: + hypernetwork_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + hypernetwork = shared.loaded_hypernetwork + checkpoint = sd_models.select_checkpoint() + + initial_step = hypernetwork.step or 0 + if initial_step >= steps: + shared.state.textinfo = f"Model has already been trained beyond specified max steps" + return hypernetwork, filename + + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) + + # dataset loading may take a while, so input validations and early returns should be done before this + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + + pin_memory = shared.opts.pin_memory + + ds = PersonalizedBase(data_root=data_root, width=training_width, + height=training_height, + repeats=shared.opts.training_image_repeats_per_epoch, + placeholder_token=hypernetwork_name, model=shared.sd_model, + cond_model=shared.sd_model.cond_stage_model, + device=devices.device, template_file=template_file, + include_cond=True, batch_size=batch_size, + gradient_step=gradient_step, shuffle_tags=shuffle_tags, + tag_drop_out=tag_drop_out, + latent_sampling_method=latent_sampling_method) + + latent_sampling_method = ds.latent_sampling_method + + dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, + batch_size=ds.batch_size, pin_memory=pin_memory) + + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) + + weights = hypernetwork.weights(True) + + # Here we use optimizer from saved HN, or we can specify as UI option. + if hypernetwork.optimizer_name in optimizer_dict: + optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate) + optimizer_name = hypernetwork.optimizer_name + else: + print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!") + optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) + optimizer_name = 'AdamW' + + if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. + try: + optimizer.load_state_dict(hypernetwork.optimizer_state_dict) + except RuntimeError as e: + print("Cannot resume from saved optimizer!") + print(e) + scheduler_beta = CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=beta_repeat_epoch, T_mult=1, eta_min=min_lr) + + scheduler_gamma = ExponentialLR(optimizer=optimizer, gamma=gamma_rate) + scaler = torch.cuda.amp.GradScaler() + + batch_size = ds.batch_size + gradient_step = ds.gradient_step + # n steps = batch_size * gradient_step * n image processed + steps_per_epoch = len(ds) // batch_size // gradient_step + max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step + loss_step = 0 + _loss_step = 0 # internal + # size = len(ds.indexes) + # loss_dict = defaultdict(lambda : deque(maxlen = 1024)) + # losses = torch.zeros((size,)) + # previous_mean_losses = [0] + # previous_mean_loss = 0 + # print("Mean loss of {} elements".format(size)) + + steps_without_grad = 0 + + last_saved_file = "" + last_saved_image = "" + forced_filename = "" + + pbar = tqdm.tqdm(total=steps - initial_step) + try: + for i in range((steps - initial_step) * gradient_step): + if scheduler.finished: + break + if shared.state.interrupted: + break + for j, batch in enumerate(dl): + # works as a drop_last=True for gradient accumulation + if j == max_steps_per_epoch: + break + if use_beta_scheduler: + scheduler_beta.step(hypernetwork.step) + scheduler_gamma.step(hypernetwork.step) + else: + scheduler.apply(optimizer, hypernetwork.step) + if scheduler.finished: + break + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) + if tag_drop_out != 0 or shuffle_tags: + shared.sd_model.cond_stage_model.to(devices.device) + c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, + non_blocking=pin_memory) + shared.sd_model.cond_stage_model.to(devices.cpu) + else: + c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory) + loss = shared.sd_model(x, c)[0] / gradient_step + del x + del c + + _loss_step += loss.item() + scaler.scale(loss).backward() + # go back until we reach gradient accumulation steps + if (j + 1) % gradient_step != 0: + continue + # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}") + # scaler.unscale_(optimizer) + # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") + # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0) + # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") + scaler.step(optimizer) + scaler.update() + hypernetwork.step += 1 + pbar.update() + optimizer.zero_grad(set_to_none=True) + loss_step = _loss_step + _loss_step = 0 + + steps_done = hypernetwork.step + 1 + + epoch_num = hypernetwork.step // steps_per_epoch + epoch_step = hypernetwork.step % steps_per_epoch + + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}") + if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: + # Before saving, change name to match current checkpoint. + hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') + hypernetwork.optimizer_name = optimizer_name + if shared.opts.save_optimizer_state: + hypernetwork.optimizer_state_dict = optimizer.state_dict() + save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) + hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. + + write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, + { + "loss": f"{loss_step:.7f}", + "learn_rate": scheduler.learn_rate + }) + + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{hypernetwork_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + hypernetwork.eval() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = batch.cond_text[0] + p.steps = 20 + p.width = training_width + p.height = training_height + + preview_text = p.prompt + + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None + + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) + hypernetwork.train() + if image is not None: + shared.state.current_image = image + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, + shared.opts.samples_format, + processed.infotexts[0], p=p, + forced_filename=forced_filename, + save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" + + shared.state.job_no = hypernetwork.step + + shared.state.textinfo = f""" +

+Loss: {loss_step:.7f}
+Step: {steps_done}
+Last prompt: {html.escape(batch.cond_text[0])}
+Last saved hypernetwork: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+

+""" + except Exception: + print(traceback.format_exc(), file=sys.stderr) + finally: + pbar.leave = False + pbar.close() + hypernetwork.eval() + # report_statistics(loss_dict) + + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + hypernetwork.optimizer_name = optimizer_name + if shared.opts.save_optimizer_state: + hypernetwork.optimizer_state_dict = optimizer.state_dict() + save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) + del optimizer + hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + + return hypernetwork, filename \ No newline at end of file diff --git a/patches/external_pr/sd_hijack_checkpoint.py b/patches/external_pr/sd_hijack_checkpoint.py new file mode 100644 index 0000000..6564743 --- /dev/null +++ b/patches/external_pr/sd_hijack_checkpoint.py @@ -0,0 +1,22 @@ +from torch.utils.checkpoint import checkpoint + + +def BasicTransformerBlock_forward(self, x, context=None): + return checkpoint(self._forward, x, context) + +def AttentionBlock_forward(self, x): + return checkpoint(self._forward, x) + +def ResBlock_forward(self, x, emb): + return checkpoint(self._forward, x, emb) + + +try: + import ldm.modules.attention + import ldm.modules.diffusionmodules.model + import ldm.modules.diffusionmodules.openaimodel + ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward +except: + pass diff --git a/patches/external_pr/textual_inversion.py b/patches/external_pr/textual_inversion.py new file mode 100644 index 0000000..2693928 --- /dev/null +++ b/patches/external_pr/textual_inversion.py @@ -0,0 +1,332 @@ +import csv +import datetime +import html +import os +import sys +import traceback + +import torch +import tqdm +from PIL import PngImagePlugin + +from .dataset import PersonalizedBase, PersonalizedDataLoader +from modules import shared, devices, sd_models, images, processing, sd_samplers, sd_hijack +from modules.textual_inversion.image_embedding import caption_image_overlay, insert_image_data_embed, embedding_to_b64 +from modules.textual_inversion.learn_schedule import LearnRateScheduler +from modules.textual_inversion.textual_inversion import save_embedding + +#apply OsError avoid here +delayed_values = {} + +def write_loss(log_directory, filename, step, epoch_len, values): + if shared.opts.training_write_csv_every == 0: + return + + if step % shared.opts.training_write_csv_every != 0: + return + write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True + try: + with open(os.path.join(log_directory, filename), "a+", newline='') as fout: + csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())]) + + if write_csv_header: + csv_writer.writeheader() + if log_directory + filename in delayed_values: + delayed = delayed_values[log_directory + filename] + for step, epoch, epoch_step, values in delayed: + csv_writer.writerow({ + "step": step, + "epoch": epoch, + "epoch_step": epoch_step, + **values, + }) + delayed.clear() + epoch, epoch_step = divmod(step - 1, epoch_len) + csv_writer.writerow({ + "step": step, + "epoch": epoch, + "epoch_step": epoch_step, + **values, + }) + except OSError: + epoch, epoch_step = divmod(step-1, epoch_len) + if log_directory + filename in delayed_values: + delayed_values[log_directory + filename].append((step , epoch, epoch_step, values)) + else: + delayed_values[log_directory + filename] = [(step, epoch, epoch_step, values)] + + +def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, + save_model_every, create_image_every, log_directory, name="embedding"): + assert model_name, f"{name} not selected" + assert learn_rate, "Learning rate is empty or 0" + assert isinstance(batch_size, int), "Batch size must be integer" + assert batch_size > 0, "Batch size must be positive" + assert isinstance(gradient_step, int), "Gradient accumulation step must be integer" + assert gradient_step > 0, "Gradient accumulation step must be positive" + assert data_root, "Dataset directory is empty" + assert os.path.isdir(data_root), "Dataset directory doesn't exist" + assert os.listdir(data_root), "Dataset directory is empty" + assert template_file, "Prompt template file is empty" + assert os.path.isfile(template_file), "Prompt template file doesn't exist" + assert steps, "Max steps is empty or 0" + assert isinstance(steps, int), "Max steps must be integer" + assert steps > 0, "Max steps must be positive" + assert isinstance(save_model_every, int), "Save {name} must be integer" + assert save_model_every >= 0, "Save {name} must be positive or 0" + assert isinstance(create_image_every, int), "Create image must be integer" + assert create_image_every >= 0, "Create image must be positive or 0" + if save_model_every or create_image_every: + assert log_directory, "Log directory is empty" + + +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, + training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, + save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, + preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, + preview_seed, preview_width, preview_height): + save_embedding_every = save_embedding_every or 0 + create_image_every = create_image_every or 0 + validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, + save_embedding_every, create_image_every, log_directory, name="embedding") + + shared.state.textinfo = "Initializing textual inversion training..." + shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name) + unload = shared.opts.unload_models_when_training + + if save_embedding_every > 0: + embedding_dir = os.path.join(log_directory, "embeddings") + os.makedirs(embedding_dir, exist_ok=True) + else: + embedding_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + if create_image_every > 0 and save_image_with_stored_embedding: + images_embeds_dir = os.path.join(log_directory, "image_embeddings") + os.makedirs(images_embeds_dir, exist_ok=True) + else: + images_embeds_dir = None + + hijack = sd_hijack.model_hijack + + embedding = hijack.embedding_db.word_embeddings[embedding_name] + checkpoint = sd_models.select_checkpoint() + + initial_step = embedding.step or 0 + if initial_step >= steps: + shared.state.textinfo = f"Model has already been trained beyond specified max steps" + return embedding, filename + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) + + # dataset loading may take a while, so input validations and early returns should be done before this + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + + pin_memory = shared.opts.pin_memory + + ds = PersonalizedBase(data_root=data_root, width=training_width, + height=training_height, + repeats=shared.opts.training_image_repeats_per_epoch, + placeholder_token=embedding_name, model=shared.sd_model, + cond_model=shared.sd_model.cond_stage_model, + device=devices.device, template_file=template_file, + batch_size=batch_size, gradient_step=gradient_step, + shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, + latent_sampling_method=latent_sampling_method) + + latent_sampling_method = ds.latent_sampling_method + + dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, + batch_size=ds.batch_size, pin_memory=pin_memory) + + if unload: + shared.sd_model.first_stage_model.to(devices.cpu) + + embedding.vec.requires_grad = True + optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) + scaler = torch.cuda.amp.GradScaler() + + batch_size = ds.batch_size + gradient_step = ds.gradient_step + # n steps = batch_size * gradient_step * n image processed + steps_per_epoch = len(ds) // batch_size // gradient_step + max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step + loss_step = 0 + _loss_step = 0 # internal + + last_saved_file = "" + last_saved_image = "" + forced_filename = "" + embedding_yet_to_be_embedded = False + + pbar = tqdm.tqdm(total=steps - initial_step) + try: + for i in range((steps - initial_step) * gradient_step): + if scheduler.finished: + break + if shared.state.interrupted: + break + for j, batch in enumerate(dl): + # works as a drop_last=True for gradient accumulation + if j == max_steps_per_epoch: + break + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + # c = stack_conds(batch.cond).to(devices.device) + # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory) + # print(mask) + # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory) + x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) + c = shared.sd_model.cond_stage_model(batch.cond_text) + loss = shared.sd_model(x, c)[0] / gradient_step + del x + + _loss_step += loss.item() + scaler.scale(loss).backward() + + # go back until we reach gradient accumulation steps + if (j + 1) % gradient_step != 0: + continue + scaler.step(optimizer) + scaler.update() + embedding.step += 1 + pbar.update() + optimizer.zero_grad(set_to_none=True) + loss_step = _loss_step + _loss_step = 0 + + steps_done = embedding.step + 1 + + epoch_num = embedding.step // steps_per_epoch + epoch_step = embedding.step % steps_per_epoch + + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}") + if embedding_dir is not None and steps_done % save_embedding_every == 0: + # Before saving, change name to match current checkpoint. + embedding_name_every = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') + # if shared.opts.save_optimizer_state: + # embedding.optimizer_state_dict = optimizer.state_dict() + save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, + remove_cached_checksum=True) + embedding_yet_to_be_embedded = True + + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, { + "loss": f"{loss_step:.7f}", + "learn_rate": scheduler.learn_rate + }) + + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + + shared.sd_model.first_stage_model.to(devices.device) + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + do_not_reload_embeddings=True, + ) + + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = batch.cond_text[0] + p.steps = 20 + p.width = training_width + p.height = training_height + + preview_text = p.prompt + + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None + + if unload: + shared.sd_model.first_stage_model.to(devices.cpu) + + if image is not None: + shared.state.current_image = image + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, + shared.opts.samples_format, + processed.infotexts[0], p=p, + forced_filename=forced_filename, + save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" + + if save_image_with_stored_embedding and os.path.exists( + last_saved_file) and embedding_yet_to_be_embedded: + + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') + + info = PngImagePlugin.PngInfo() + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) + + title = "<{}>".format(data.get('name', '???')) + + try: + vectorSize = list(data['string_to_param'].values())[0].shape[0] + except Exception as e: + vectorSize = '?' + + checkpoint = sd_models.select_checkpoint() + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '{}v {}s'.format(vectorSize, steps_done) + + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) + + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + embedding_yet_to_be_embedded = False + + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, + shared.opts.samples_format, + processed.infotexts[0], p=p, + forced_filename=forced_filename, + save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" + + shared.state.job_no = embedding.step + + shared.state.textinfo = f""" +

+Loss: {loss_step:.7f}
+Step: {steps_done}
+Last prompt: {html.escape(batch.cond_text[0])}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+

+""" + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) + except Exception: + print(traceback.format_exc(), file=sys.stderr) + pass + finally: + pbar.leave = False + pbar.close() + shared.sd_model.first_stage_model.to(devices.device) + + return embedding, filename \ No newline at end of file diff --git a/patches/external_pr/ui.py b/patches/external_pr/ui.py new file mode 100644 index 0000000..22fd13b --- /dev/null +++ b/patches/external_pr/ui.py @@ -0,0 +1,175 @@ +import html +import os + +from modules import shared, sd_hijack, devices +from modules.paths import script_path +from modules.ui import create_refresh_button, gr_show +from webui import wrap_gradio_gpu_call +from .textual_inversion import train_embedding as train_embedding_external +from .hypernetwork import train_hypernetwork as train_hypernetwork_external +import gradio as gr + + +def train_hypernetwork_ui(*args): + + initial_hypernetwork = shared.loaded_hypernetwork + + assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' + + try: + sd_hijack.undo_optimizations() + + hypernetwork, filename = train_hypernetwork_external(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. +Hypernetwork saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + shared.loaded_hypernetwork = initial_hypernetwork + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + sd_hijack.apply_optimizations() + + +def on_train_gamma_tab(params=None): + with gr.Tab(label="Train_Gamma") as train_gamma: + gr.HTML( + value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") + with gr.Row(): + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted( + sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + create_refresh_button(train_embedding_name, + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: { + "choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, + "refresh_train_embedding_name") + with gr.Row(): + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", + choices=[x for x in shared.hypernetworks.keys()]) + create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, + lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, + "refresh_train_hypernetwork_name") + with gr.Row(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', + placeholder="Embedding Learning rate", value="0.005") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', + placeholder="Hypernetwork Learning rate", value="0.00001") + use_beta_scheduler_checkbox = gr.Checkbox( + label='Show advanced learn rate scheduler options(for Hypernetworks)') + with gr.Row(visible=False) as beta_scheduler_options: + use_beta_scheduler = gr.Checkbox(label='Uses CosineAnnealingWarmRestarts Scheduler') + beta_repeat_epoch = gr.Textbox(label='Epoch for cycle', placeholder="Cycles every nth epoch", value="4000") + min_lr = gr.Textbox(label='Minimum learning rate for beta scheduler', + placeholder="restricts decay value, but does not restrict gamma rate decay", + value="1e-7") + gamma_rate = gr.Textbox(label='Separate learning rate decay for ExponentialLR', + placeholder="Value should be in (0-1]", value="1") + use_beta_scheduler_checkbox.change( + fn=lambda show: gr_show(show), + inputs=[use_beta_scheduler_checkbox], + outputs=[beta_scheduler_options], + ) + batch_size = gr.Number(label='Batch size', value=1, precision=0) + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0) + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", + value="textual_inversion") + template_file = gr.Textbox(label='Prompt template file', + value=os.path.join(script_path, "textual_inversion_templates", + "style_filewords.txt")) + training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + steps = gr.Number(label='Max steps', value=100000, precision=0) + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', + value=500, precision=0) + save_embedding_every = gr.Number( + label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) + preview_from_txt2img = gr.Checkbox( + label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) + with gr.Row(): + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False) + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", + value=0) + with gr.Row(): + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", + choices=['once', 'deterministic', 'random']) + + with gr.Row(): + interrupt_training = gr.Button(value="Interrupt") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') + train_embedding = gr.Button(value="Train Embedding", variant='primary') + ti_output = gr.Text(elem_id="ti_output3", value="", show_label=False) + ti_outcome = gr.HTML(elem_id="ti_error3", value="") + + + train_embedding.click( + fn=wrap_gradio_gpu_call(train_embedding_external, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + embedding_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + steps, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + save_image_with_stored_embedding, + preview_from_txt2img, + *params.txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + train_hypernetwork.click( + fn=wrap_gradio_gpu_call(train_hypernetwork_ui, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_hypernetwork_name, + hypernetwork_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + steps, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + preview_from_txt2img, + *params.txt2img_preview_params, + use_beta_scheduler, + beta_repeat_epoch, + min_lr, + gamma_rate + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + return [(train_gamma, "Train Gamma", "train_gamma")] diff --git a/patches/hypernetwork.py b/patches/hypernetwork.py index 6f22a93..0cb849f 100644 --- a/patches/hypernetwork.py +++ b/patches/hypernetwork.py @@ -180,6 +180,13 @@ class Hypernetwork: for layer in layers: layer.eval() + + def train(self): + for k, layers in self.layers.items(): + for layer in layers: + layer.train() + + def save(self, filename): state_dict = {} optimizer_saved_dict = {} diff --git a/patches/textual_inversion.py b/patches/textual_inversion.py index 798c43b..f77d600 100644 --- a/patches/textual_inversion.py +++ b/patches/textual_inversion.py @@ -45,4 +45,4 @@ def write_loss(log_directory, filename, step, epoch_len, values): else: delayed_values[log_directory + filename] = [(step+1, epoch, epoch_step, values)] -modules.textual_inversion.textual_inversion.write_loss = write_loss +modules.textual_inversion.textual_inversion.write_loss = write_loss \ No newline at end of file diff --git a/scripts/hypernetwork-extensions.py b/scripts/hypernetwork-extensions.py index d085801..41a5096 100644 --- a/scripts/hypernetwork-extensions.py +++ b/scripts/hypernetwork-extensions.py @@ -7,10 +7,13 @@ import gradio as gr from modules.paths import script_path from modules.ui import create_refresh_button, gr_show -import patches.ui as ui import patches.textual_inversion as textual_inversion +import patches.ui as ui +import patches.external_pr.ui as external_patch_ui from webui import wrap_gradio_gpu_call +setattr(shared.opts,'pin_memory', False) + def create_training_tab(params: script_callbacks.UiTrainTabParams = None): with gr.Tab(label="Train_Beta") as train_beta: @@ -149,7 +152,7 @@ def create_extension_tab(params=None): script_callbacks.on_ui_train_tabs(create_training_tab) script_callbacks.on_ui_train_tabs(create_extension_tab) - +script_callbacks.on_ui_train_tabs(external_patch_ui.on_train_gamma_tab) class Script(scripts.Script): def title(self):