Merge pull request #1 from aria1th/beta-apply-gradient-acc
Add external prbeta-apply-bigger-batch-sizes
commit
fcd33cfd2e
|
|
@ -0,0 +1,190 @@
|
||||||
|
# source:https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4886/files
|
||||||
|
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
from ..hnutil import get_closest
|
||||||
|
import random
|
||||||
|
import tqdm
|
||||||
|
from modules import devices, shared
|
||||||
|
import re
|
||||||
|
|
||||||
|
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||||
|
|
||||||
|
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetEntry:
|
||||||
|
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None,
|
||||||
|
cond_text=None, pixel_values=None):
|
||||||
|
self.filename = filename
|
||||||
|
self.filename_text = filename_text
|
||||||
|
self.latent_dist = latent_dist
|
||||||
|
self.latent_sample = latent_sample
|
||||||
|
self.cond = cond
|
||||||
|
self.cond_text = cond_text
|
||||||
|
self.pixel_values = pixel_values
|
||||||
|
|
||||||
|
|
||||||
|
class PersonalizedBase(Dataset):
|
||||||
|
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None,
|
||||||
|
cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1,
|
||||||
|
shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
|
||||||
|
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(
|
||||||
|
shared.opts.dataset_filename_word_regex) > 0 else None
|
||||||
|
|
||||||
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
|
self.width = width
|
||||||
|
self.height = height
|
||||||
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
|
||||||
|
self.dataset = []
|
||||||
|
|
||||||
|
with open(template_file, "r") as file:
|
||||||
|
lines = [x.strip() for x in file.readlines()]
|
||||||
|
|
||||||
|
self.lines = lines
|
||||||
|
|
||||||
|
assert data_root, 'dataset directory not specified'
|
||||||
|
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||||
|
assert os.listdir(data_root), "Dataset directory is empty"
|
||||||
|
|
||||||
|
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||||
|
|
||||||
|
self.shuffle_tags = shuffle_tags
|
||||||
|
self.tag_drop_out = tag_drop_out
|
||||||
|
|
||||||
|
print("Preparing dataset...")
|
||||||
|
for path in tqdm.tqdm(self.image_paths):
|
||||||
|
if shared.state.interrupted:
|
||||||
|
raise Exception("inturrupted")
|
||||||
|
try: # apply variable size here
|
||||||
|
image = Image.open(path).convert('RGB')
|
||||||
|
w, h = image.size
|
||||||
|
r = max(1, w / self.width, h / self.height) # divide by this
|
||||||
|
w, h = int(w/r), int(h/r)
|
||||||
|
w, h = get_closest(w), get_closest(h)
|
||||||
|
image = image.resize((w,h), PIL.Image.LANCZOS)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
text_filename = os.path.splitext(path)[0] + ".txt"
|
||||||
|
filename = os.path.basename(path)
|
||||||
|
|
||||||
|
if os.path.exists(text_filename):
|
||||||
|
with open(text_filename, "r", encoding="utf8") as file:
|
||||||
|
filename_text = file.read()
|
||||||
|
else:
|
||||||
|
filename_text = os.path.splitext(filename)[0]
|
||||||
|
filename_text = re.sub(re_numbers_at_start, '', filename_text)
|
||||||
|
if re_word:
|
||||||
|
tokens = re_word.findall(filename_text)
|
||||||
|
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
|
||||||
|
|
||||||
|
npimage = np.array(image).astype(np.uint8)
|
||||||
|
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||||
|
|
||||||
|
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
||||||
|
latent_sample = None
|
||||||
|
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
||||||
|
|
||||||
|
if latent_sampling_method == "once" or (
|
||||||
|
latent_sampling_method == "deterministic" and not isinstance(latent_dist,
|
||||||
|
DiagonalGaussianDistribution)):
|
||||||
|
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||||
|
latent_sampling_method = "once"
|
||||||
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||||
|
elif latent_sampling_method == "deterministic":
|
||||||
|
# Works only for DiagonalGaussianDistribution
|
||||||
|
latent_dist.std = 0
|
||||||
|
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||||
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||||
|
elif latent_sampling_method == "random":
|
||||||
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
|
||||||
|
|
||||||
|
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||||
|
entry.cond_text = self.create_text(filename_text)
|
||||||
|
|
||||||
|
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||||
|
|
||||||
|
self.dataset.append(entry)
|
||||||
|
del torchdata
|
||||||
|
del latent_dist
|
||||||
|
del latent_sample
|
||||||
|
|
||||||
|
self.length = len(self.dataset)
|
||||||
|
assert self.length > 0, "No images have been found in the dataset."
|
||||||
|
self.batch_size = min(batch_size, self.length)
|
||||||
|
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
||||||
|
self.latent_sampling_method = latent_sampling_method
|
||||||
|
|
||||||
|
def create_text(self, filename_text):
|
||||||
|
text = random.choice(self.lines)
|
||||||
|
text = text.replace("[name]", self.placeholder_token)
|
||||||
|
tags = filename_text.split(',')
|
||||||
|
if self.tag_drop_out != 0:
|
||||||
|
tags = [t for t in tags if random.random() > self.tag_drop_out]
|
||||||
|
if self.shuffle_tags:
|
||||||
|
random.shuffle(tags)
|
||||||
|
text = text.replace("[filewords]", ','.join(tags))
|
||||||
|
return text
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
entry = self.dataset[i]
|
||||||
|
if self.tag_drop_out != 0 or self.shuffle_tags:
|
||||||
|
entry.cond_text = self.create_text(entry.filename_text)
|
||||||
|
if self.latent_sampling_method == "random":
|
||||||
|
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
||||||
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
class PersonalizedDataLoader(DataLoader):
|
||||||
|
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
||||||
|
super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size,
|
||||||
|
pin_memory=pin_memory)
|
||||||
|
if latent_sampling_method == "random":
|
||||||
|
self.collate_fn = collate_wrapper_random
|
||||||
|
else:
|
||||||
|
self.collate_fn = collate_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class BatchLoader:
|
||||||
|
def __init__(self, data):
|
||||||
|
self.cond_text = [entry.cond_text for entry in data]
|
||||||
|
self.cond = [entry.cond for entry in data]
|
||||||
|
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
||||||
|
# self.emb_index = [entry.emb_index for entry in data]
|
||||||
|
# print(self.latent_sample.device)
|
||||||
|
|
||||||
|
def pin_memory(self):
|
||||||
|
self.latent_sample = self.latent_sample.pin_memory()
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def collate_wrapper(batch):
|
||||||
|
return BatchLoader(batch)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchLoaderRandom(BatchLoader):
|
||||||
|
def __init__(self, data):
|
||||||
|
super().__init__(data)
|
||||||
|
|
||||||
|
def pin_memory(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def collate_wrapper_random(batch):
|
||||||
|
return BatchLoaderRandom(batch)
|
||||||
|
|
@ -0,0 +1,298 @@
|
||||||
|
import csv
|
||||||
|
import datetime
|
||||||
|
import glob
|
||||||
|
import html
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ExponentialLR
|
||||||
|
|
||||||
|
from modules import shared, sd_models, devices, processing, sd_samplers
|
||||||
|
from modules.hypernetworks.hypernetwork import optimizer_dict, stack_conds, save_hypernetwork
|
||||||
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
|
from .textual_inversion import validate_train_inputs, write_loss
|
||||||
|
from ..hypernetwork import Hypernetwork, load_hypernetwork
|
||||||
|
from . import sd_hijack_checkpoint
|
||||||
|
from .dataset import PersonalizedBase,PersonalizedDataLoader
|
||||||
|
|
||||||
|
|
||||||
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory,
|
||||||
|
training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method,
|
||||||
|
create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt,
|
||||||
|
preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed,
|
||||||
|
preview_width, preview_height,
|
||||||
|
use_beta_scheduler=False, beta_repeat_epoch=4000, min_lr=1e-7, gamma_rate=1):
|
||||||
|
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||||
|
from modules import images
|
||||||
|
try:
|
||||||
|
beta_repeat_epoch = int(beta_repeat_epoch)
|
||||||
|
assert beta_repeat_epoch > 0, f"Cannot use too small cycle {beta_repeat_epoch}!"
|
||||||
|
min_lr = float(min_lr)
|
||||||
|
assert min_lr < 1, f"Cannot use minimum lr with {min_lr}!"
|
||||||
|
gamma_rate = float(gamma_rate)
|
||||||
|
assert 0 <= gamma_rate <= 1, f"Cannot use gamma rate with {gamma_rate}!"
|
||||||
|
except ValueError:
|
||||||
|
raise RuntimeError("Cannot use advanced LR scheduler settings!")
|
||||||
|
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||||
|
create_image_every = create_image_every or 0
|
||||||
|
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
|
||||||
|
template_file, steps, save_hypernetwork_every, create_image_every,
|
||||||
|
log_directory, name="hypernetwork")
|
||||||
|
|
||||||
|
load_hypernetwork(hypernetwork_name)
|
||||||
|
assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
|
||||||
|
if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
|
||||||
|
raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!")
|
||||||
|
|
||||||
|
shared.state.textinfo = "Initializing hypernetwork training..."
|
||||||
|
shared.state.job_count = steps
|
||||||
|
|
||||||
|
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
|
||||||
|
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||||
|
|
||||||
|
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
|
||||||
|
unload = shared.opts.unload_models_when_training
|
||||||
|
|
||||||
|
if save_hypernetwork_every > 0:
|
||||||
|
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
|
||||||
|
os.makedirs(hypernetwork_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
hypernetwork_dir = None
|
||||||
|
|
||||||
|
if create_image_every > 0:
|
||||||
|
images_dir = os.path.join(log_directory, "images")
|
||||||
|
os.makedirs(images_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
images_dir = None
|
||||||
|
|
||||||
|
hypernetwork = shared.loaded_hypernetwork
|
||||||
|
checkpoint = sd_models.select_checkpoint()
|
||||||
|
|
||||||
|
initial_step = hypernetwork.step or 0
|
||||||
|
if initial_step >= steps:
|
||||||
|
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
||||||
|
return hypernetwork, filename
|
||||||
|
|
||||||
|
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||||
|
|
||||||
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||||
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
|
|
||||||
|
pin_memory = shared.opts.pin_memory
|
||||||
|
|
||||||
|
ds = PersonalizedBase(data_root=data_root, width=training_width,
|
||||||
|
height=training_height,
|
||||||
|
repeats=shared.opts.training_image_repeats_per_epoch,
|
||||||
|
placeholder_token=hypernetwork_name, model=shared.sd_model,
|
||||||
|
cond_model=shared.sd_model.cond_stage_model,
|
||||||
|
device=devices.device, template_file=template_file,
|
||||||
|
include_cond=True, batch_size=batch_size,
|
||||||
|
gradient_step=gradient_step, shuffle_tags=shuffle_tags,
|
||||||
|
tag_drop_out=tag_drop_out,
|
||||||
|
latent_sampling_method=latent_sampling_method)
|
||||||
|
|
||||||
|
latent_sampling_method = ds.latent_sampling_method
|
||||||
|
|
||||||
|
dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method,
|
||||||
|
batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||||
|
|
||||||
|
if unload:
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
|
weights = hypernetwork.weights(True)
|
||||||
|
|
||||||
|
# Here we use optimizer from saved HN, or we can specify as UI option.
|
||||||
|
if hypernetwork.optimizer_name in optimizer_dict:
|
||||||
|
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
|
||||||
|
optimizer_name = hypernetwork.optimizer_name
|
||||||
|
else:
|
||||||
|
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
|
||||||
|
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
|
||||||
|
optimizer_name = 'AdamW'
|
||||||
|
|
||||||
|
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
|
||||||
|
try:
|
||||||
|
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
|
||||||
|
except RuntimeError as e:
|
||||||
|
print("Cannot resume from saved optimizer!")
|
||||||
|
print(e)
|
||||||
|
scheduler_beta = CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=beta_repeat_epoch, T_mult=1, eta_min=min_lr)
|
||||||
|
|
||||||
|
scheduler_gamma = ExponentialLR(optimizer=optimizer, gamma=gamma_rate)
|
||||||
|
scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
|
batch_size = ds.batch_size
|
||||||
|
gradient_step = ds.gradient_step
|
||||||
|
# n steps = batch_size * gradient_step * n image processed
|
||||||
|
steps_per_epoch = len(ds) // batch_size // gradient_step
|
||||||
|
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
||||||
|
loss_step = 0
|
||||||
|
_loss_step = 0 # internal
|
||||||
|
# size = len(ds.indexes)
|
||||||
|
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||||
|
# losses = torch.zeros((size,))
|
||||||
|
# previous_mean_losses = [0]
|
||||||
|
# previous_mean_loss = 0
|
||||||
|
# print("Mean loss of {} elements".format(size))
|
||||||
|
|
||||||
|
steps_without_grad = 0
|
||||||
|
|
||||||
|
last_saved_file = "<none>"
|
||||||
|
last_saved_image = "<none>"
|
||||||
|
forced_filename = "<none>"
|
||||||
|
|
||||||
|
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||||
|
try:
|
||||||
|
for i in range((steps - initial_step) * gradient_step):
|
||||||
|
if scheduler.finished:
|
||||||
|
break
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
for j, batch in enumerate(dl):
|
||||||
|
# works as a drop_last=True for gradient accumulation
|
||||||
|
if j == max_steps_per_epoch:
|
||||||
|
break
|
||||||
|
if use_beta_scheduler:
|
||||||
|
scheduler_beta.step(hypernetwork.step)
|
||||||
|
scheduler_gamma.step(hypernetwork.step)
|
||||||
|
else:
|
||||||
|
scheduler.apply(optimizer, hypernetwork.step)
|
||||||
|
if scheduler.finished:
|
||||||
|
break
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||||
|
if tag_drop_out != 0 or shuffle_tags:
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device,
|
||||||
|
non_blocking=pin_memory)
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
else:
|
||||||
|
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
||||||
|
loss = shared.sd_model(x, c)[0] / gradient_step
|
||||||
|
del x
|
||||||
|
del c
|
||||||
|
|
||||||
|
_loss_step += loss.item()
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
# go back until we reach gradient accumulation steps
|
||||||
|
if (j + 1) % gradient_step != 0:
|
||||||
|
continue
|
||||||
|
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
|
||||||
|
# scaler.unscale_(optimizer)
|
||||||
|
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||||
|
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
|
||||||
|
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
hypernetwork.step += 1
|
||||||
|
pbar.update()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
loss_step = _loss_step
|
||||||
|
_loss_step = 0
|
||||||
|
|
||||||
|
steps_done = hypernetwork.step + 1
|
||||||
|
|
||||||
|
epoch_num = hypernetwork.step // steps_per_epoch
|
||||||
|
epoch_step = hypernetwork.step % steps_per_epoch
|
||||||
|
|
||||||
|
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
||||||
|
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||||
|
# Before saving, change name to match current checkpoint.
|
||||||
|
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||||
|
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
||||||
|
hypernetwork.optimizer_name = optimizer_name
|
||||||
|
if shared.opts.save_optimizer_state:
|
||||||
|
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||||
|
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||||
|
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||||
|
|
||||||
|
write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch,
|
||||||
|
{
|
||||||
|
"loss": f"{loss_step:.7f}",
|
||||||
|
"learn_rate": scheduler.learn_rate
|
||||||
|
})
|
||||||
|
|
||||||
|
if images_dir is not None and steps_done % create_image_every == 0:
|
||||||
|
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||||
|
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||||
|
hypernetwork.eval()
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
|
sd_model=shared.sd_model,
|
||||||
|
do_not_save_grid=True,
|
||||||
|
do_not_save_samples=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if preview_from_txt2img:
|
||||||
|
p.prompt = preview_prompt
|
||||||
|
p.negative_prompt = preview_negative_prompt
|
||||||
|
p.steps = preview_steps
|
||||||
|
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||||
|
p.cfg_scale = preview_cfg_scale
|
||||||
|
p.seed = preview_seed
|
||||||
|
p.width = preview_width
|
||||||
|
p.height = preview_height
|
||||||
|
else:
|
||||||
|
p.prompt = batch.cond_text[0]
|
||||||
|
p.steps = 20
|
||||||
|
p.width = training_width
|
||||||
|
p.height = training_height
|
||||||
|
|
||||||
|
preview_text = p.prompt
|
||||||
|
|
||||||
|
processed = processing.process_images(p)
|
||||||
|
image = processed.images[0] if len(processed.images) > 0 else None
|
||||||
|
|
||||||
|
if unload:
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
hypernetwork.train()
|
||||||
|
if image is not None:
|
||||||
|
shared.state.current_image = image
|
||||||
|
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
|
||||||
|
shared.opts.samples_format,
|
||||||
|
processed.infotexts[0], p=p,
|
||||||
|
forced_filename=forced_filename,
|
||||||
|
save_to_dirs=False)
|
||||||
|
last_saved_image += f", prompt: {preview_text}"
|
||||||
|
|
||||||
|
shared.state.job_no = hypernetwork.step
|
||||||
|
|
||||||
|
shared.state.textinfo = f"""
|
||||||
|
<p>
|
||||||
|
Loss: {loss_step:.7f}<br/>
|
||||||
|
Step: {steps_done}<br/>
|
||||||
|
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
||||||
|
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||||
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
|
</p>
|
||||||
|
"""
|
||||||
|
except Exception:
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
finally:
|
||||||
|
pbar.leave = False
|
||||||
|
pbar.close()
|
||||||
|
hypernetwork.eval()
|
||||||
|
# report_statistics(loss_dict)
|
||||||
|
|
||||||
|
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||||
|
hypernetwork.optimizer_name = optimizer_name
|
||||||
|
if shared.opts.save_optimizer_state:
|
||||||
|
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||||
|
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
||||||
|
del optimizer
|
||||||
|
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
|
return hypernetwork, filename
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def BasicTransformerBlock_forward(self, x, context=None):
|
||||||
|
return checkpoint(self._forward, x, context)
|
||||||
|
|
||||||
|
def AttentionBlock_forward(self, x):
|
||||||
|
return checkpoint(self._forward, x)
|
||||||
|
|
||||||
|
def ResBlock_forward(self, x, emb):
|
||||||
|
return checkpoint(self._forward, x, emb)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ldm.modules.attention
|
||||||
|
import ldm.modules.diffusionmodules.model
|
||||||
|
import ldm.modules.diffusionmodules.openaimodel
|
||||||
|
ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
|
||||||
|
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
|
||||||
|
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
@ -0,0 +1,332 @@
|
||||||
|
import csv
|
||||||
|
import datetime
|
||||||
|
import html
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
from PIL import PngImagePlugin
|
||||||
|
|
||||||
|
from .dataset import PersonalizedBase, PersonalizedDataLoader
|
||||||
|
from modules import shared, devices, sd_models, images, processing, sd_samplers, sd_hijack
|
||||||
|
from modules.textual_inversion.image_embedding import caption_image_overlay, insert_image_data_embed, embedding_to_b64
|
||||||
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
|
from modules.textual_inversion.textual_inversion import save_embedding
|
||||||
|
|
||||||
|
#apply OsError avoid here
|
||||||
|
delayed_values = {}
|
||||||
|
|
||||||
|
def write_loss(log_directory, filename, step, epoch_len, values):
|
||||||
|
if shared.opts.training_write_csv_every == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
if step % shared.opts.training_write_csv_every != 0:
|
||||||
|
return
|
||||||
|
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
|
||||||
|
try:
|
||||||
|
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
|
||||||
|
csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
|
||||||
|
|
||||||
|
if write_csv_header:
|
||||||
|
csv_writer.writeheader()
|
||||||
|
if log_directory + filename in delayed_values:
|
||||||
|
delayed = delayed_values[log_directory + filename]
|
||||||
|
for step, epoch, epoch_step, values in delayed:
|
||||||
|
csv_writer.writerow({
|
||||||
|
"step": step,
|
||||||
|
"epoch": epoch,
|
||||||
|
"epoch_step": epoch_step,
|
||||||
|
**values,
|
||||||
|
})
|
||||||
|
delayed.clear()
|
||||||
|
epoch, epoch_step = divmod(step - 1, epoch_len)
|
||||||
|
csv_writer.writerow({
|
||||||
|
"step": step,
|
||||||
|
"epoch": epoch,
|
||||||
|
"epoch_step": epoch_step,
|
||||||
|
**values,
|
||||||
|
})
|
||||||
|
except OSError:
|
||||||
|
epoch, epoch_step = divmod(step-1, epoch_len)
|
||||||
|
if log_directory + filename in delayed_values:
|
||||||
|
delayed_values[log_directory + filename].append((step , epoch, epoch_step, values))
|
||||||
|
else:
|
||||||
|
delayed_values[log_directory + filename] = [(step, epoch, epoch_step, values)]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps,
|
||||||
|
save_model_every, create_image_every, log_directory, name="embedding"):
|
||||||
|
assert model_name, f"{name} not selected"
|
||||||
|
assert learn_rate, "Learning rate is empty or 0"
|
||||||
|
assert isinstance(batch_size, int), "Batch size must be integer"
|
||||||
|
assert batch_size > 0, "Batch size must be positive"
|
||||||
|
assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
|
||||||
|
assert gradient_step > 0, "Gradient accumulation step must be positive"
|
||||||
|
assert data_root, "Dataset directory is empty"
|
||||||
|
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||||
|
assert os.listdir(data_root), "Dataset directory is empty"
|
||||||
|
assert template_file, "Prompt template file is empty"
|
||||||
|
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
|
||||||
|
assert steps, "Max steps is empty or 0"
|
||||||
|
assert isinstance(steps, int), "Max steps must be integer"
|
||||||
|
assert steps > 0, "Max steps must be positive"
|
||||||
|
assert isinstance(save_model_every, int), "Save {name} must be integer"
|
||||||
|
assert save_model_every >= 0, "Save {name} must be positive or 0"
|
||||||
|
assert isinstance(create_image_every, int), "Create image must be integer"
|
||||||
|
assert create_image_every >= 0, "Create image must be positive or 0"
|
||||||
|
if save_model_every or create_image_every:
|
||||||
|
assert log_directory, "Log directory is empty"
|
||||||
|
|
||||||
|
|
||||||
|
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width,
|
||||||
|
training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every,
|
||||||
|
save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img,
|
||||||
|
preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale,
|
||||||
|
preview_seed, preview_width, preview_height):
|
||||||
|
save_embedding_every = save_embedding_every or 0
|
||||||
|
create_image_every = create_image_every or 0
|
||||||
|
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps,
|
||||||
|
save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||||
|
|
||||||
|
shared.state.textinfo = "Initializing textual inversion training..."
|
||||||
|
shared.state.job_count = steps
|
||||||
|
|
||||||
|
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||||
|
|
||||||
|
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
|
||||||
|
unload = shared.opts.unload_models_when_training
|
||||||
|
|
||||||
|
if save_embedding_every > 0:
|
||||||
|
embedding_dir = os.path.join(log_directory, "embeddings")
|
||||||
|
os.makedirs(embedding_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
embedding_dir = None
|
||||||
|
|
||||||
|
if create_image_every > 0:
|
||||||
|
images_dir = os.path.join(log_directory, "images")
|
||||||
|
os.makedirs(images_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
images_dir = None
|
||||||
|
|
||||||
|
if create_image_every > 0 and save_image_with_stored_embedding:
|
||||||
|
images_embeds_dir = os.path.join(log_directory, "image_embeddings")
|
||||||
|
os.makedirs(images_embeds_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
images_embeds_dir = None
|
||||||
|
|
||||||
|
hijack = sd_hijack.model_hijack
|
||||||
|
|
||||||
|
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
||||||
|
checkpoint = sd_models.select_checkpoint()
|
||||||
|
|
||||||
|
initial_step = embedding.step or 0
|
||||||
|
if initial_step >= steps:
|
||||||
|
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
||||||
|
return embedding, filename
|
||||||
|
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||||
|
|
||||||
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||||
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
|
|
||||||
|
pin_memory = shared.opts.pin_memory
|
||||||
|
|
||||||
|
ds = PersonalizedBase(data_root=data_root, width=training_width,
|
||||||
|
height=training_height,
|
||||||
|
repeats=shared.opts.training_image_repeats_per_epoch,
|
||||||
|
placeholder_token=embedding_name, model=shared.sd_model,
|
||||||
|
cond_model=shared.sd_model.cond_stage_model,
|
||||||
|
device=devices.device, template_file=template_file,
|
||||||
|
batch_size=batch_size, gradient_step=gradient_step,
|
||||||
|
shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out,
|
||||||
|
latent_sampling_method=latent_sampling_method)
|
||||||
|
|
||||||
|
latent_sampling_method = ds.latent_sampling_method
|
||||||
|
|
||||||
|
dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method,
|
||||||
|
batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||||
|
|
||||||
|
if unload:
|
||||||
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
|
embedding.vec.requires_grad = True
|
||||||
|
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||||
|
scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
|
batch_size = ds.batch_size
|
||||||
|
gradient_step = ds.gradient_step
|
||||||
|
# n steps = batch_size * gradient_step * n image processed
|
||||||
|
steps_per_epoch = len(ds) // batch_size // gradient_step
|
||||||
|
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
||||||
|
loss_step = 0
|
||||||
|
_loss_step = 0 # internal
|
||||||
|
|
||||||
|
last_saved_file = "<none>"
|
||||||
|
last_saved_image = "<none>"
|
||||||
|
forced_filename = "<none>"
|
||||||
|
embedding_yet_to_be_embedded = False
|
||||||
|
|
||||||
|
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||||
|
try:
|
||||||
|
for i in range((steps - initial_step) * gradient_step):
|
||||||
|
if scheduler.finished:
|
||||||
|
break
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
for j, batch in enumerate(dl):
|
||||||
|
# works as a drop_last=True for gradient accumulation
|
||||||
|
if j == max_steps_per_epoch:
|
||||||
|
break
|
||||||
|
scheduler.apply(optimizer, embedding.step)
|
||||||
|
if scheduler.finished:
|
||||||
|
break
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
# c = stack_conds(batch.cond).to(devices.device)
|
||||||
|
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
|
||||||
|
# print(mask)
|
||||||
|
# c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
|
||||||
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||||
|
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
||||||
|
loss = shared.sd_model(x, c)[0] / gradient_step
|
||||||
|
del x
|
||||||
|
|
||||||
|
_loss_step += loss.item()
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
# go back until we reach gradient accumulation steps
|
||||||
|
if (j + 1) % gradient_step != 0:
|
||||||
|
continue
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
embedding.step += 1
|
||||||
|
pbar.update()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
loss_step = _loss_step
|
||||||
|
_loss_step = 0
|
||||||
|
|
||||||
|
steps_done = embedding.step + 1
|
||||||
|
|
||||||
|
epoch_num = embedding.step // steps_per_epoch
|
||||||
|
epoch_step = embedding.step % steps_per_epoch
|
||||||
|
|
||||||
|
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
||||||
|
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
||||||
|
# Before saving, change name to match current checkpoint.
|
||||||
|
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||||
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||||
|
# if shared.opts.save_optimizer_state:
|
||||||
|
# embedding.optimizer_state_dict = optimizer.state_dict()
|
||||||
|
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file,
|
||||||
|
remove_cached_checksum=True)
|
||||||
|
embedding_yet_to_be_embedded = True
|
||||||
|
|
||||||
|
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
|
||||||
|
"loss": f"{loss_step:.7f}",
|
||||||
|
"learn_rate": scheduler.learn_rate
|
||||||
|
})
|
||||||
|
|
||||||
|
if images_dir is not None and steps_done % create_image_every == 0:
|
||||||
|
forced_filename = f'{embedding_name}-{steps_done}'
|
||||||
|
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||||
|
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
|
sd_model=shared.sd_model,
|
||||||
|
do_not_save_grid=True,
|
||||||
|
do_not_save_samples=True,
|
||||||
|
do_not_reload_embeddings=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if preview_from_txt2img:
|
||||||
|
p.prompt = preview_prompt
|
||||||
|
p.negative_prompt = preview_negative_prompt
|
||||||
|
p.steps = preview_steps
|
||||||
|
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||||
|
p.cfg_scale = preview_cfg_scale
|
||||||
|
p.seed = preview_seed
|
||||||
|
p.width = preview_width
|
||||||
|
p.height = preview_height
|
||||||
|
else:
|
||||||
|
p.prompt = batch.cond_text[0]
|
||||||
|
p.steps = 20
|
||||||
|
p.width = training_width
|
||||||
|
p.height = training_height
|
||||||
|
|
||||||
|
preview_text = p.prompt
|
||||||
|
|
||||||
|
processed = processing.process_images(p)
|
||||||
|
image = processed.images[0] if len(processed.images) > 0 else None
|
||||||
|
|
||||||
|
if unload:
|
||||||
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
|
if image is not None:
|
||||||
|
shared.state.current_image = image
|
||||||
|
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
|
||||||
|
shared.opts.samples_format,
|
||||||
|
processed.infotexts[0], p=p,
|
||||||
|
forced_filename=forced_filename,
|
||||||
|
save_to_dirs=False)
|
||||||
|
last_saved_image += f", prompt: {preview_text}"
|
||||||
|
|
||||||
|
if save_image_with_stored_embedding and os.path.exists(
|
||||||
|
last_saved_file) and embedding_yet_to_be_embedded:
|
||||||
|
|
||||||
|
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
||||||
|
|
||||||
|
info = PngImagePlugin.PngInfo()
|
||||||
|
data = torch.load(last_saved_file)
|
||||||
|
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
||||||
|
|
||||||
|
title = "<{}>".format(data.get('name', '???'))
|
||||||
|
|
||||||
|
try:
|
||||||
|
vectorSize = list(data['string_to_param'].values())[0].shape[0]
|
||||||
|
except Exception as e:
|
||||||
|
vectorSize = '?'
|
||||||
|
|
||||||
|
checkpoint = sd_models.select_checkpoint()
|
||||||
|
footer_left = checkpoint.model_name
|
||||||
|
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||||
|
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
||||||
|
|
||||||
|
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||||
|
captioned_image = insert_image_data_embed(captioned_image, data)
|
||||||
|
|
||||||
|
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||||
|
embedding_yet_to_be_embedded = False
|
||||||
|
|
||||||
|
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
|
||||||
|
shared.opts.samples_format,
|
||||||
|
processed.infotexts[0], p=p,
|
||||||
|
forced_filename=forced_filename,
|
||||||
|
save_to_dirs=False)
|
||||||
|
last_saved_image += f", prompt: {preview_text}"
|
||||||
|
|
||||||
|
shared.state.job_no = embedding.step
|
||||||
|
|
||||||
|
shared.state.textinfo = f"""
|
||||||
|
<p>
|
||||||
|
Loss: {loss_step:.7f}<br/>
|
||||||
|
Step: {steps_done}<br/>
|
||||||
|
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
||||||
|
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||||
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
|
</p>
|
||||||
|
"""
|
||||||
|
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||||
|
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||||
|
except Exception:
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
pbar.leave = False
|
||||||
|
pbar.close()
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
|
return embedding, filename
|
||||||
|
|
@ -0,0 +1,175 @@
|
||||||
|
import html
|
||||||
|
import os
|
||||||
|
|
||||||
|
from modules import shared, sd_hijack, devices
|
||||||
|
from modules.paths import script_path
|
||||||
|
from modules.ui import create_refresh_button, gr_show
|
||||||
|
from webui import wrap_gradio_gpu_call
|
||||||
|
from .textual_inversion import train_embedding as train_embedding_external
|
||||||
|
from .hypernetwork import train_hypernetwork as train_hypernetwork_external
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
|
def train_hypernetwork_ui(*args):
|
||||||
|
|
||||||
|
initial_hypernetwork = shared.loaded_hypernetwork
|
||||||
|
|
||||||
|
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
||||||
|
|
||||||
|
try:
|
||||||
|
sd_hijack.undo_optimizations()
|
||||||
|
|
||||||
|
hypernetwork, filename = train_hypernetwork_external(*args)
|
||||||
|
|
||||||
|
res = f"""
|
||||||
|
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
|
||||||
|
Hypernetwork saved to {html.escape(filename)}
|
||||||
|
"""
|
||||||
|
return res, ""
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
shared.loaded_hypernetwork = initial_hypernetwork
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
sd_hijack.apply_optimizations()
|
||||||
|
|
||||||
|
|
||||||
|
def on_train_gamma_tab(params=None):
|
||||||
|
with gr.Tab(label="Train_Gamma") as train_gamma:
|
||||||
|
gr.HTML(
|
||||||
|
value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
||||||
|
with gr.Row():
|
||||||
|
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(
|
||||||
|
sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
|
create_refresh_button(train_embedding_name,
|
||||||
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {
|
||||||
|
"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())},
|
||||||
|
"refresh_train_embedding_name")
|
||||||
|
with gr.Row():
|
||||||
|
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork",
|
||||||
|
choices=[x for x in shared.hypernetworks.keys()])
|
||||||
|
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks,
|
||||||
|
lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])},
|
||||||
|
"refresh_train_hypernetwork_name")
|
||||||
|
with gr.Row():
|
||||||
|
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate',
|
||||||
|
placeholder="Embedding Learning rate", value="0.005")
|
||||||
|
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate',
|
||||||
|
placeholder="Hypernetwork Learning rate", value="0.00001")
|
||||||
|
use_beta_scheduler_checkbox = gr.Checkbox(
|
||||||
|
label='Show advanced learn rate scheduler options(for Hypernetworks)')
|
||||||
|
with gr.Row(visible=False) as beta_scheduler_options:
|
||||||
|
use_beta_scheduler = gr.Checkbox(label='Uses CosineAnnealingWarmRestarts Scheduler')
|
||||||
|
beta_repeat_epoch = gr.Textbox(label='Epoch for cycle', placeholder="Cycles every nth epoch", value="4000")
|
||||||
|
min_lr = gr.Textbox(label='Minimum learning rate for beta scheduler',
|
||||||
|
placeholder="restricts decay value, but does not restrict gamma rate decay",
|
||||||
|
value="1e-7")
|
||||||
|
gamma_rate = gr.Textbox(label='Separate learning rate decay for ExponentialLR',
|
||||||
|
placeholder="Value should be in (0-1]", value="1")
|
||||||
|
use_beta_scheduler_checkbox.change(
|
||||||
|
fn=lambda show: gr_show(show),
|
||||||
|
inputs=[use_beta_scheduler_checkbox],
|
||||||
|
outputs=[beta_scheduler_options],
|
||||||
|
)
|
||||||
|
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||||
|
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
|
||||||
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs",
|
||||||
|
value="textual_inversion")
|
||||||
|
template_file = gr.Textbox(label='Prompt template file',
|
||||||
|
value=os.path.join(script_path, "textual_inversion_templates",
|
||||||
|
"style_filewords.txt"))
|
||||||
|
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||||
|
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||||
|
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||||
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable',
|
||||||
|
value=500, precision=0)
|
||||||
|
save_embedding_every = gr.Number(
|
||||||
|
label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
|
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
||||||
|
preview_from_txt2img = gr.Checkbox(
|
||||||
|
label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
|
||||||
|
with gr.Row():
|
||||||
|
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
|
||||||
|
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.",
|
||||||
|
value=0)
|
||||||
|
with gr.Row():
|
||||||
|
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once",
|
||||||
|
choices=['once', 'deterministic', 'random'])
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
interrupt_training = gr.Button(value="Interrupt")
|
||||||
|
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
|
||||||
|
train_embedding = gr.Button(value="Train Embedding", variant='primary')
|
||||||
|
ti_output = gr.Text(elem_id="ti_output3", value="", show_label=False)
|
||||||
|
ti_outcome = gr.HTML(elem_id="ti_error3", value="")
|
||||||
|
|
||||||
|
|
||||||
|
train_embedding.click(
|
||||||
|
fn=wrap_gradio_gpu_call(train_embedding_external, extra_outputs=[gr.update()]),
|
||||||
|
_js="start_training_textual_inversion",
|
||||||
|
inputs=[
|
||||||
|
train_embedding_name,
|
||||||
|
embedding_learn_rate,
|
||||||
|
batch_size,
|
||||||
|
gradient_step,
|
||||||
|
dataset_directory,
|
||||||
|
log_directory,
|
||||||
|
training_width,
|
||||||
|
training_height,
|
||||||
|
steps,
|
||||||
|
shuffle_tags,
|
||||||
|
tag_drop_out,
|
||||||
|
latent_sampling_method,
|
||||||
|
create_image_every,
|
||||||
|
save_embedding_every,
|
||||||
|
template_file,
|
||||||
|
save_image_with_stored_embedding,
|
||||||
|
preview_from_txt2img,
|
||||||
|
*params.txt2img_preview_params,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
ti_output,
|
||||||
|
ti_outcome,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
train_hypernetwork.click(
|
||||||
|
fn=wrap_gradio_gpu_call(train_hypernetwork_ui, extra_outputs=[gr.update()]),
|
||||||
|
_js="start_training_textual_inversion",
|
||||||
|
inputs=[
|
||||||
|
train_hypernetwork_name,
|
||||||
|
hypernetwork_learn_rate,
|
||||||
|
batch_size,
|
||||||
|
gradient_step,
|
||||||
|
dataset_directory,
|
||||||
|
log_directory,
|
||||||
|
training_width,
|
||||||
|
training_height,
|
||||||
|
steps,
|
||||||
|
shuffle_tags,
|
||||||
|
tag_drop_out,
|
||||||
|
latent_sampling_method,
|
||||||
|
create_image_every,
|
||||||
|
save_embedding_every,
|
||||||
|
template_file,
|
||||||
|
preview_from_txt2img,
|
||||||
|
*params.txt2img_preview_params,
|
||||||
|
use_beta_scheduler,
|
||||||
|
beta_repeat_epoch,
|
||||||
|
min_lr,
|
||||||
|
gamma_rate
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
ti_output,
|
||||||
|
ti_outcome,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
interrupt_training.click(
|
||||||
|
fn=lambda: shared.state.interrupt(),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
return [(train_gamma, "Train Gamma", "train_gamma")]
|
||||||
|
|
@ -180,6 +180,13 @@ class Hypernetwork:
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.eval()
|
layer.eval()
|
||||||
|
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
for k, layers in self.layers.items():
|
||||||
|
for layer in layers:
|
||||||
|
layer.train()
|
||||||
|
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
optimizer_saved_dict = {}
|
optimizer_saved_dict = {}
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,13 @@ import gradio as gr
|
||||||
|
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.ui import create_refresh_button, gr_show
|
from modules.ui import create_refresh_button, gr_show
|
||||||
import patches.ui as ui
|
|
||||||
import patches.textual_inversion as textual_inversion
|
import patches.textual_inversion as textual_inversion
|
||||||
|
import patches.ui as ui
|
||||||
|
import patches.external_pr.ui as external_patch_ui
|
||||||
from webui import wrap_gradio_gpu_call
|
from webui import wrap_gradio_gpu_call
|
||||||
|
|
||||||
|
setattr(shared.opts,'pin_memory', False)
|
||||||
|
|
||||||
|
|
||||||
def create_training_tab(params: script_callbacks.UiTrainTabParams = None):
|
def create_training_tab(params: script_callbacks.UiTrainTabParams = None):
|
||||||
with gr.Tab(label="Train_Beta") as train_beta:
|
with gr.Tab(label="Train_Beta") as train_beta:
|
||||||
|
|
@ -149,7 +152,7 @@ def create_extension_tab(params=None):
|
||||||
|
|
||||||
script_callbacks.on_ui_train_tabs(create_training_tab)
|
script_callbacks.on_ui_train_tabs(create_training_tab)
|
||||||
script_callbacks.on_ui_train_tabs(create_extension_tab)
|
script_callbacks.on_ui_train_tabs(create_extension_tab)
|
||||||
|
script_callbacks.on_ui_train_tabs(external_patch_ui.on_train_gamma_tab)
|
||||||
|
|
||||||
class Script(scripts.Script):
|
class Script(scripts.Script):
|
||||||
def title(self):
|
def title(self):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue