Merge pull request #1 from aria1th/beta-apply-gradient-acc

Add external pr
beta-apply-bigger-batch-sizes
AngelBottomless 2022-11-23 23:38:20 +09:00 committed by GitHub
commit fcd33cfd2e
8 changed files with 1030 additions and 3 deletions

patches/external_pr/dataset.py Normal file

@@ -0,0 +1,190 @@
# source: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4886/files
import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from ..hnutil import get_closest
import random
import tqdm
from modules import devices, shared
import re
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None,
cond_text=None, pixel_values=None):
self.filename = filename
self.filename_text = filename_text
self.latent_dist = latent_dist
self.latent_sample = latent_sample
self.cond = cond
self.cond_text = cond_text
self.pixel_values = pixel_values
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None,
cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1,
shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(
shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
with open(template_file, "r") as file:
lines = [x.strip() for x in file.readlines()]
self.lines = lines
assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
if shared.state.interrupted:
                raise Exception("interrupted")
            try:  # variable-size training: downscale to fit the training resolution, then snap
                image = Image.open(path).convert('RGB')
                w, h = image.size
                r = max(1, w / self.width, h / self.height)  # downscale factor; never upscales
                w, h = int(w / r), int(h / r)
                w, h = get_closest(w), get_closest(h)
                image = image.resize((w, h), PIL.Image.LANCZOS)
except Exception:
continue
text_filename = os.path.splitext(path)[0] + ".txt"
filename = os.path.basename(path)
if os.path.exists(text_filename):
with open(text_filename, "r", encoding="utf8") as file:
filename_text = file.read()
else:
filename_text = os.path.splitext(filename)[0]
filename_text = re.sub(re_numbers_at_start, '', filename_text)
if re_word:
tokens = re_word.findall(filename_text)
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
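            # latent sampling modes:
            #   once          - sample the VAE posterior a single time and cache the latent
            #   deterministic - use the posterior mean (std forced to 0) and cache it
            #   random        - keep the full distribution and draw a fresh sample in __getitem__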
latent_sample = None
with torch.autocast("cuda"):
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
if latent_sampling_method == "once" or (
latent_sampling_method == "deterministic" and not isinstance(latent_dist,
DiagonalGaussianDistribution)):
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
latent_sampling_method = "once"
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "deterministic":
# Works only for DiagonalGaussianDistribution
latent_dist.std = 0
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text)
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
with torch.autocast("cuda"):
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry)
del torchdata
del latent_dist
del latent_sample
self.length = len(self.dataset)
assert self.length > 0, "No images have been found in the dataset."
self.batch_size = min(batch_size, self.length)
self.gradient_step = min(gradient_step, self.length // self.batch_size)
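        # the two clamps above keep batch_size <= len(dataset) and guarantee that at
        # least one full gradient-accumulation cycle fits within a single epoch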
self.latent_sampling_method = latent_sampling_method
def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
tags = filename_text.split(',')
if self.tag_drop_out != 0:
tags = [t for t in tags if random.random() > self.tag_drop_out]
if self.shuffle_tags:
random.shuffle(tags)
text = text.replace("[filewords]", ','.join(tags))
return text
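    # Illustrative expansion (not from the PR): template "a photo of [name], [filewords]"
    # with placeholder "my-hypernet" and filename_text "1girl, solo, smile" becomes
    # "a photo of my-hypernet, 1girl, solo, smile" (tags optionally dropped/shuffled first).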
def __len__(self):
return self.length
def __getitem__(self, i):
entry = self.dataset[i]
if self.tag_drop_out != 0 or self.shuffle_tags:
entry.cond_text = self.create_text(entry.filename_text)
if self.latent_sampling_method == "random":
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry
class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size,
pin_memory=pin_memory)
if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random
else:
self.collate_fn = collate_wrapper
class BatchLoader:
def __init__(self, data):
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
# self.emb_index = [entry.emb_index for entry in data]
# print(self.latent_sample.device)
def pin_memory(self):
self.latent_sample = self.latent_sample.pin_memory()
return self
def collate_wrapper(batch):
return BatchLoader(batch)
class BatchLoaderRandom(BatchLoader):
def __init__(self, data):
super().__init__(data)
def pin_memory(self):
return self
def collate_wrapper_random(batch):
return BatchLoaderRandom(batch)
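
As an aside on the variable-size handling above: get_closest comes from hnutil, which is outside this diff. A minimal sketch of what such a helper plausibly does, assuming it snaps a pixel dimension down to a nearby multiple of 64 (hypothetical implementation, not the extension's actual code):

# Hypothetical stand-in for hnutil.get_closest -- not part of this PR.
def get_closest(value, multiple=64, minimum=64):
    """Snap a pixel dimension down to the nearest multiple of `multiple`."""
    return max(minimum, (value // multiple) * multiple)

# Example: a 640x487 image trained at 512x512 gives r = max(1, 1.25, 0.95) = 1.25,
# so (w, h) -> (512, 389) -> snapped to (512, 384).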

patches/external_pr/hypernetwork.py Normal file

@@ -0,0 +1,298 @@
import csv
import datetime
import glob
import html
import os
import sys
import traceback
import inspect
import torch
import tqdm
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ExponentialLR
from modules import shared, sd_models, devices, processing, sd_samplers
from modules.hypernetworks.hypernetwork import optimizer_dict, stack_conds, save_hypernetwork
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from .textual_inversion import validate_train_inputs, write_loss
from ..hypernetwork import Hypernetwork, load_hypernetwork
from . import sd_hijack_checkpoint
from .dataset import PersonalizedBase, PersonalizedDataLoader
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory,
training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method,
create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt,
preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed,
preview_width, preview_height,
use_beta_scheduler=False, beta_repeat_epoch=4000, min_lr=1e-7, gamma_rate=1):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
    try:
        beta_repeat_epoch = int(beta_repeat_epoch)
        assert beta_repeat_epoch > 0, f"Cycle length for the beta scheduler must be positive, got {beta_repeat_epoch}!"
        min_lr = float(min_lr)
        assert min_lr < 1, f"Minimum learning rate must be below 1, got {min_lr}!"
        gamma_rate = float(gamma_rate)
        assert 0 <= gamma_rate <= 1, f"Gamma rate must be in [0, 1], got {gamma_rate}!"
    except ValueError:
        raise RuntimeError("Could not parse advanced LR scheduler settings!")
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
template_file, steps, save_hypernetwork_every, create_image_every,
log_directory, name="hypernetwork")
load_hypernetwork(hypernetwork_name)
assert shared.loaded_hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
    if not isinstance(shared.loaded_hypernetwork, Hypernetwork):
        raise RuntimeError("Loaded hypernetwork is not a trainable Hypernetwork instance!")
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
unload = shared.opts.unload_models_when_training
if save_hypernetwork_every > 0:
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
os.makedirs(hypernetwork_dir, exist_ok=True)
else:
hypernetwork_dir = None
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
else:
images_dir = None
hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
initial_step = hypernetwork.step or 0
if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
pin_memory = shared.opts.pin_memory
ds = PersonalizedBase(data_root=data_root, width=training_width,
height=training_height,
repeats=shared.opts.training_image_repeats_per_epoch,
placeholder_token=hypernetwork_name, model=shared.sd_model,
cond_model=shared.sd_model.cond_stage_model,
device=devices.device, template_file=template_file,
include_cond=True, batch_size=batch_size,
gradient_step=gradient_step, shuffle_tags=shuffle_tags,
tag_drop_out=tag_drop_out,
latent_sampling_method=latent_sampling_method)
latent_sampling_method = ds.latent_sampling_method
dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method,
batch_size=ds.batch_size, pin_memory=pin_memory)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
weights = hypernetwork.weights(True)
    # Use the optimizer type saved with the hypernetwork; this could later be exposed as a UI option.
if hypernetwork.optimizer_name in optimizer_dict:
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
optimizer_name = hypernetwork.optimizer_name
else:
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
optimizer_name = 'AdamW'
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
try:
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
except RuntimeError as e:
print("Cannot resume from saved optimizer!")
print(e)
scheduler_beta = CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=beta_repeat_epoch, T_mult=1, eta_min=min_lr)
scheduler_gamma = ExponentialLR(optimizer=optimizer, gamma=gamma_rate)
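    # Per the UI labels below: scheduler_beta cycles the LR with warm restarts every
    # beta_repeat_epoch steps (floored at min_lr), while scheduler_gamma applies a
    # separate exponential decay (gamma_rate=1 leaves it inactive). Both are stepped
    # with the global hypernetwork.step inside the training loop.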
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
gradient_step = ds.gradient_step
    # one optimizer step consumes batch_size * gradient_step images
steps_per_epoch = len(ds) // batch_size // gradient_step
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
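    # Illustrative numbers (not from the PR): len(ds)=10, batch_size=3, gradient_step=2
    # -> 10//3 = 3 batches per pass, steps_per_epoch = 3//2 = 1 optimizer step,
    # -> max_steps_per_epoch = 3 - 3%2 = 2, so the odd trailing batch is dropped
    #    instead of being left half-accumulated when the loop breaks at j == max_steps_per_epoch.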
loss_step = 0
_loss_step = 0 # internal
# size = len(ds.indexes)
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
# losses = torch.zeros((size,))
# previous_mean_losses = [0]
# previous_mean_loss = 0
# print("Mean loss of {} elements".format(size))
steps_without_grad = 0
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps - initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
break
for j, batch in enumerate(dl):
                # acts as drop_last=True for the gradient-accumulation cycle
if j == max_steps_per_epoch:
break
if use_beta_scheduler:
scheduler_beta.step(hypernetwork.step)
scheduler_gamma.step(hypernetwork.step)
else:
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if tag_drop_out != 0 or shuffle_tags:
shared.sd_model.cond_stage_model.to(devices.device)
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device,
non_blocking=pin_memory)
shared.sd_model.cond_stage_model.to(devices.cpu)
else:
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0] / gradient_step
del x
del c
_loss_step += loss.item()
scaler.scale(loss).backward()
                # keep accumulating until a full gradient-accumulation cycle completes
if (j + 1) % gradient_step != 0:
continue
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
# scaler.unscale_(optimizer)
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
scaler.step(optimizer)
scaler.update()
hypernetwork.step += 1
pbar.update()
optimizer.zero_grad(set_to_none=True)
loss_step = _loss_step
_loss_step = 0
steps_done = hypernetwork.step + 1
epoch_num = hypernetwork.step // steps_per_epoch
epoch_step = hypernetwork.step % steps_per_epoch
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}")
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch,
{
"loss": f"{loss_step:.7f}",
"learn_rate": scheduler.learn_rate
})
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
hypernetwork.eval()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
)
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = batch.cond_text[0]
p.steps = 20
p.width = training_width
p.height = training_height
preview_text = p.prompt
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork.train()
if image is not None:
shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
shared.opts.samples_format,
processed.infotexts[0], p=p,
forced_filename=forced_filename,
save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
shared.state.textinfo = f"""
<p>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
except Exception:
print(traceback.format_exc(), file=sys.stderr)
finally:
pbar.leave = False
pbar.close()
hypernetwork.eval()
# report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
del optimizer
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
return hypernetwork, filename
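
Both trainers share the same gradient-accumulation core: pre-divide the loss by gradient_step, call backward on every micro-batch, and only step/zero the optimizer once per accumulation cycle. A self-contained sketch of that pattern (dummy model and data, assuming a CUDA device; these are not the webui objects):

import torch

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()
gradient_step = 4  # micro-batches per optimizer step

for j in range(16):
    x = torch.randn(2, 4, device='cuda')
    with torch.autocast("cuda"):
        loss = model(x).pow(2).mean() / gradient_step  # pre-divide so gradients average
    scaler.scale(loss).backward()  # gradients accumulate in param.grad
    if (j + 1) % gradient_step != 0:
        continue  # keep accumulating
    scaler.step(optimizer)  # unscales gradients, then optimizer.step()
    scaler.update()
    optimizer.zero_grad(set_to_none=True)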

patches/external_pr/sd_hijack_checkpoint.py Normal file

@@ -0,0 +1,22 @@
from torch.utils.checkpoint import checkpoint
def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
def ResBlock_forward(self, x, emb):
return checkpoint(self._forward, x, emb)
try:
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
except (ImportError, AttributeError):
    pass
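
This small hijack is what makes larger batch sizes fit in VRAM: torch.utils.checkpoint drops intermediate activations during the forward pass and recomputes them on backward, trading extra compute for a substantial activation-memory saving. The same monkey-patch pattern on a toy module (class and names are illustrative, not from ldm):

import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):  # stand-in for BasicTransformerBlock and friends
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())

    def _forward(self, x):  # the ldm blocks expose _forward the same way
        return self.net(x)

    def forward(self, x):
        return self._forward(x)

def checkpointed_forward(self, x):
    # with the default (reentrant) checkpoint, at least one input should
    # require grad, or gradients may not flow through the recomputation
    return checkpoint(self._forward, x)

Block.forward = checkpointed_forward  # patch the class, as the try-block above does for ldm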

patches/external_pr/textual_inversion.py Normal file

@@ -0,0 +1,332 @@
import csv
import datetime
import html
import os
import sys
import traceback
import torch
import tqdm
from PIL import PngImagePlugin
from .dataset import PersonalizedBase, PersonalizedDataLoader
from modules import shared, devices, sd_models, images, processing, sd_samplers, sd_hijack
from modules.textual_inversion.image_embedding import caption_image_overlay, insert_image_data_embed, embedding_to_b64
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.textual_inversion import save_embedding
# Buffer CSV rows when writing raises OSError (e.g. the log file is locked) and flush them on the next successful write.
delayed_values = {}
def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
if step % shared.opts.training_write_csv_every != 0:
return
    write_csv_header = not os.path.exists(os.path.join(log_directory, filename))
try:
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
if write_csv_header:
csv_writer.writeheader()
            if log_directory + filename in delayed_values:
                delayed = delayed_values[log_directory + filename]
                # distinct names so the current call's step/values are not clobbered
                for d_step, d_epoch, d_epoch_step, d_values in delayed:
                    csv_writer.writerow({
                        "step": d_step,
                        "epoch": d_epoch,
                        "epoch_step": d_epoch_step,
                        **d_values,
                    })
                delayed.clear()
epoch, epoch_step = divmod(step - 1, epoch_len)
csv_writer.writerow({
"step": step,
"epoch": epoch,
"epoch_step": epoch_step,
**values,
})
except OSError:
epoch, epoch_step = divmod(step-1, epoch_len)
if log_directory + filename in delayed_values:
delayed_values[log_directory + filename].append((step , epoch, epoch_step, values))
else:
delayed_values[log_directory + filename] = [(step, epoch, epoch_step, values)]
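
The delayed_values buffer above exists because appending to the CSV can raise OSError (for example when the file is open and locked in a spreadsheet on Windows); failed rows are queued under the log_directory + filename key and flushed, in order, on the next successful write. A hypothetical call sequence (paths and values are illustrative):

# write_loss("logs/ti", "loss.csv", step=500, epoch_len=100, values={"loss": "0.12"})   # OSError -> row buffered
# write_loss("logs/ti", "loss.csv", step=1000, epoch_len=100, values={"loss": "0.10"})  # succeeds -> buffered row written first, then this one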
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps,
save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
assert isinstance(batch_size, int), "Batch size must be integer"
assert batch_size > 0, "Batch size must be positive"
assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
assert gradient_step > 0, "Gradient accumulation step must be positive"
assert data_root, "Dataset directory is empty"
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
assert template_file, "Prompt template file is empty"
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
assert steps, "Max steps is empty or 0"
assert isinstance(steps, int), "Max steps must be integer"
assert steps > 0, "Max steps must be positive"
    assert isinstance(save_model_every, int), f"Save {name} must be integer"
    assert save_model_every >= 0, f"Save {name} must be positive or 0"
    assert isinstance(create_image_every, int), "Create image must be integer"
    assert create_image_every >= 0, "Create image must be positive or 0"
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width,
training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every,
save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img,
preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale,
preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps,
save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
unload = shared.opts.unload_models_when_training
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
os.makedirs(embedding_dir, exist_ok=True)
else:
embedding_dir = None
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
else:
images_dir = None
if create_image_every > 0 and save_image_with_stored_embedding:
images_embeds_dir = os.path.join(log_directory, "image_embeddings")
os.makedirs(images_embeds_dir, exist_ok=True)
else:
images_embeds_dir = None
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
checkpoint = sd_models.select_checkpoint()
initial_step = embedding.step or 0
if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
pin_memory = shared.opts.pin_memory
ds = PersonalizedBase(data_root=data_root, width=training_width,
height=training_height,
repeats=shared.opts.training_image_repeats_per_epoch,
placeholder_token=embedding_name, model=shared.sd_model,
cond_model=shared.sd_model.cond_stage_model,
device=devices.device, template_file=template_file,
batch_size=batch_size, gradient_step=gradient_step,
shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out,
latent_sampling_method=latent_sampling_method)
latent_sampling_method = ds.latent_sampling_method
dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method,
batch_size=ds.batch_size, pin_memory=pin_memory)
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
gradient_step = ds.gradient_step
    # one optimizer step consumes batch_size * gradient_step images
steps_per_epoch = len(ds) // batch_size // gradient_step
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
loss_step = 0
_loss_step = 0 # internal
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps - initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
break
for j, batch in enumerate(dl):
                # acts as drop_last=True for the gradient-accumulation cycle
if j == max_steps_per_epoch:
break
scheduler.apply(optimizer, embedding.step)
if scheduler.finished:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
# c = stack_conds(batch.cond).to(devices.device)
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
# print(mask)
# c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
loss = shared.sd_model(x, c)[0] / gradient_step
del x
_loss_step += loss.item()
scaler.scale(loss).backward()
                # keep accumulating until a full gradient-accumulation cycle completes
if (j + 1) % gradient_step != 0:
continue
scaler.step(optimizer)
scaler.update()
embedding.step += 1
pbar.update()
optimizer.zero_grad(set_to_none=True)
loss_step = _loss_step
_loss_step = 0
steps_done = embedding.step + 1
epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % steps_per_epoch
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}")
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
# if shared.opts.save_optimizer_state:
# embedding.optimizer_state_dict = optimizer.state_dict()
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file,
remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
"loss": f"{loss_step:.7f}",
"learn_rate": scheduler.learn_rate
})
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
do_not_reload_embeddings=True,
)
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = batch.cond_text[0]
p.steps = 20
p.width = training_width
p.height = training_height
preview_text = p.prompt
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
if image is not None:
shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
shared.opts.samples_format,
processed.infotexts[0], p=p,
forced_filename=forced_filename,
save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
if save_image_with_stored_embedding and os.path.exists(
last_saved_file) and embedding_yet_to_be_embedded:
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embedding_to_b64(data))
title = "<{}>".format(data.get('name', '???'))
                        try:
                            vector_size = list(data['string_to_param'].values())[0].shape[0]
                        except Exception:
                            vector_size = '?'
                        checkpoint = sd_models.select_checkpoint()
                        footer_left = checkpoint.model_name
                        footer_mid = '[{}]'.format(checkpoint.hash)
                        footer_right = '{}v {}s'.format(vector_size, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
shared.state.job_no = embedding.step
shared.state.textinfo = f"""
<p>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
finally:
pbar.leave = False
pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
return embedding, filename

patches/external_pr/ui.py Normal file

@@ -0,0 +1,175 @@
import html
import os
from modules import shared, sd_hijack, devices
from modules.paths import script_path
from modules.ui import create_refresh_button, gr_show
from webui import wrap_gradio_gpu_call
from .textual_inversion import train_embedding as train_embedding_external
from .hypernetwork import train_hypernetwork as train_hypernetwork_external
import gradio as gr
def train_hypernetwork_ui(*args):
initial_hypernetwork = shared.loaded_hypernetwork
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
try:
sd_hijack.undo_optimizations()
hypernetwork, filename = train_hypernetwork_external(*args)
res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
Hypernetwork saved to {html.escape(filename)}
"""
return res, ""
finally:
shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations()
def on_train_gamma_tab(params=None):
with gr.Tab(label="Train_Gamma") as train_gamma:
gr.HTML(
value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(
sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name,
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {
"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())},
"refresh_train_embedding_name")
with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork",
choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks,
lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])},
"refresh_train_hypernetwork_name")
with gr.Row():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate',
placeholder="Embedding Learning rate", value="0.005")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate',
placeholder="Hypernetwork Learning rate", value="0.00001")
            use_beta_scheduler_checkbox = gr.Checkbox(
                label='Show advanced learning rate scheduler options (for Hypernetworks)')
with gr.Row(visible=False) as beta_scheduler_options:
use_beta_scheduler = gr.Checkbox(label='Uses CosineAnnealingWarmRestarts Scheduler')
beta_repeat_epoch = gr.Textbox(label='Epoch for cycle', placeholder="Cycles every nth epoch", value="4000")
            min_lr = gr.Textbox(label='Minimum learning rate for beta scheduler',
                                placeholder="Floors the cosine annealing; does not floor the gamma decay",
                                value="1e-7")
gamma_rate = gr.Textbox(label='Separate learning rate decay for ExponentialLR',
placeholder="Value should be in (0-1]", value="1")
use_beta_scheduler_checkbox.change(
fn=lambda show: gr_show(show),
inputs=[use_beta_scheduler_checkbox],
outputs=[beta_scheduler_options],
)
batch_size = gr.Number(label='Batch size', value=1, precision=0)
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs",
value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file',
value=os.path.join(script_path, "textual_inversion_templates",
"style_filewords.txt"))
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable',
value=500, precision=0)
save_embedding_every = gr.Number(
label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_from_txt2img = gr.Checkbox(
label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
with gr.Row():
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.",
value=0)
with gr.Row():
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once",
choices=['once', 'deterministic', 'random'])
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
train_embedding = gr.Button(value="Train Embedding", variant='primary')
ti_output = gr.Text(elem_id="ti_output3", value="", show_label=False)
ti_outcome = gr.HTML(elem_id="ti_error3", value="")
train_embedding.click(
fn=wrap_gradio_gpu_call(train_embedding_external, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
embedding_learn_rate,
batch_size,
gradient_step,
dataset_directory,
log_directory,
training_width,
training_height,
steps,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every,
save_embedding_every,
template_file,
save_image_with_stored_embedding,
preview_from_txt2img,
*params.txt2img_preview_params,
],
outputs=[
ti_output,
ti_outcome,
]
)
train_hypernetwork.click(
fn=wrap_gradio_gpu_call(train_hypernetwork_ui, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
train_hypernetwork_name,
hypernetwork_learn_rate,
batch_size,
gradient_step,
dataset_directory,
log_directory,
training_width,
training_height,
steps,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every,
save_embedding_every,
template_file,
preview_from_txt2img,
*params.txt2img_preview_params,
use_beta_scheduler,
beta_repeat_epoch,
min_lr,
gamma_rate
],
outputs=[
ti_output,
ti_outcome,
]
)
interrupt_training.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
return [(train_gamma, "Train Gamma", "train_gamma")]
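
The advanced-options row above uses the standard Gradio show/hide idiom: the checkbox's change event returns a visibility update for the grouped Row (webui's gr_show is a thin wrapper over gr.update(visible=...)). Reduced to a standalone sketch:

import gradio as gr

with gr.Blocks() as demo:
    show_advanced = gr.Checkbox(label="Show advanced options")
    with gr.Row(visible=False) as advanced:
        gr.Textbox(label="Epoch for cycle", value="4000")
    show_advanced.change(
        fn=lambda show: gr.update(visible=show),  # equivalent of webui's gr_show
        inputs=[show_advanced],
        outputs=[advanced],
    )
# demo.launch()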

patches/hypernetwork.py

@@ -180,6 +180,13 @@ class Hypernetwork:
            for layer in layers:
                layer.eval()

    def train(self):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.train()

    def save(self, filename):
        state_dict = {}
        optimizer_saved_dict = {}

@@ -7,10 +7,13 @@ import gradio as gr
from modules.paths import script_path
from modules.ui import create_refresh_button, gr_show

import patches.textual_inversion as textual_inversion
import patches.ui as ui
import patches.external_pr.ui as external_patch_ui
from webui import wrap_gradio_gpu_call

setattr(shared.opts, 'pin_memory', False)

def create_training_tab(params: script_callbacks.UiTrainTabParams = None):
    with gr.Tab(label="Train_Beta") as train_beta:

@@ -149,7 +152,7 @@ def create_extension_tab(params=None):

script_callbacks.on_ui_train_tabs(create_training_tab)
script_callbacks.on_ui_train_tabs(create_extension_tab)
script_callbacks.on_ui_train_tabs(external_patch_ui.on_train_gamma_tab)

class Script(scripts.Script):
    def title(self):