#!/bin/env python # pylint: disable=no-member """ simple implementation of training api: `/sdapi/v1/train` - supports: create embedding, image preprocess, train embedding (with all known parameters) - does not (yet) support: create hyper-network, train hyper-network - compatible with progress api: `/sdapi/v1/progress` - if interrupted, auto-continues from last known step - create and preprocess executed as sync jobs - train is executed as async job with progress monitoring """ import argparse import asyncio import logging import math import os import sys import time import importlib from pathlib import Path, PurePath import filetype from PIL import Image sys.path.append(os.path.join(os.path.dirname(__file__), 'modules')) from modules.util import Map, log, set_logfile from modules.sdapi import close, get, interrupt, post, progress, session from modules.process import process_images from modules.grid import grid create_preview = importlib.import_module('modules.preview-embeddings').create_preview plot = importlib.import_module('modules.train-losschart').plot extract = importlib.import_module('modules.video-extract').extract gen_loss_rate_str = importlib.import_module('modules.train-lossrate').gen_loss_rate_str images = [] args = {} options = None cmdflags = None args = Map({ "training_model": "sd-v15-runwayml.ckpt", "extract_video": { "rate": 0, "fps": 5, "vstart": 0, "vend": 0 }, "create_embedding": { "name": "test", "num_vectors_per_token": 1, "overwrite_old": False, "init_text": "*" }, "preprocess": { "id_task": 0, "process_src": "", "process_dst": "", "process_width": 512, "process_height": 512, "process_flip": False, "process_split": False, "process_caption": True, "process_caption_deepbooru": False, "preprocess_txt_action": "ignore", "process_focal_crop": True, "process_focal_crop_face_weight": 0.9, "process_focal_crop_entropy_weight": 0.3, "process_focal_crop_edges_weight": 0.5, "process_focal_crop_debug": False, "split_threshold": 0.5, "overlap_ratio": 0.2, "process_multicrop": None, "process_multicrop_mindim": None, "process_multicrop_maxdim": None, "process_multicrop_minarea": None, "process_multicrop_maxarea": None, "process_multicrop_objective": None, "process_multicrop_threshold": None, }, "train_embedding": { "id_task": 0, "embedding_name": "", "learn_rate": -1, "batch_size": 1, "steps": 500, "data_root": "", "log_directory": "train/log", "template_filename": "subject_filewords.txt", "gradient_step": 20, "training_width": 512, "training_height": 512, "shuffle_tags": False, "tag_drop_out": 0, "clip_grad_mode": "disabled", "clip_grad_value": "0.1", "latent_sampling_method": "deterministic", "create_image_every": -1, "save_embedding_every": -1, "save_image_with_stored_embedding": False, "preview_from_txt2img": False, "preview_prompt": "", "preview_negative_prompt": "blurry, duplicate, ugly, deformed, low res, watermark, text", "preview_steps": 20, "preview_sampler_index": 0, "preview_cfg_scale": 6, "preview_seed": -1, "preview_width": 512, "preview_height": 512, "varsize": False, }, }) async def plotloss(params): logdir = os.path.abspath(os.path.join(cmdflags['embeddings_dir'], '../train/log')) try: plot(logdir, params.name) except Exception as err: log.warning({ 'loss chart error': err }) async def captions(docs: list): exclude = ['a', 'in', 'on', 'out', 'at', 'the', 'and', 'with', 'next', 'to', 'it', 'for', 'of', 'into', 'that'] d = dict() for f in docs: text = open(f, 'r', encoding='utf-8') for line in text: line = line.strip() line = line.lower() words = line.split(" ") for word in words: if word in exclude: continue d[word] = d[word] + 1 if word in d else 1 pairs = ((value, key) for (key,value) in d.items()) sort = sorted(pairs, reverse = True) if len(sort) > 10: del sort[10:] d = {k: v for v, k in sort} log.info({ 'top captions': d }) async def preprocess_cleanup(params): log.info({ 'preprocess cleanup': params.dst }) for f in Path(params.dst).glob('*.png'): f.unlink() for f in Path(params.dst).glob('*.jpg'): f.unlink() for f in Path(params.dst).glob('*.txt'): f.unlink() try: if os.path.isdir(params.dst): Path(params.dst).rmdir() except Exception as err: log.warning({ 'preprocess cleanup': params.dst, 'error': err }) async def preprocess_builtin(params): global images # pylint: disable=global-statement log.debug({ 'preprocess start' }) files = [os.path.join(params.src, f) for f in os.listdir(params.src) if os.path.isfile(os.path.join(params.src, f))] candidates = [f for f in files if filetype.is_image(f)] not_images = [f for f in files if (not filetype.is_image(f) and not f.endswith('.txt'))] images = [] low_res = [] for f in candidates: img = Image.open(f) mp = (img.size[0] * img.size[1]) / 1024 / 1024 if mp < 1 or img.size[0] < 512 or img.size[1] < 512: low_res.append(f) os.rename(f, f + '.skip') else: images.append(f) log.debug({ 'preprocess skipping': not_images }) log.debug({ 'preprocess low res': low_res }) args.preprocess.process_src = params.src args.preprocess.process_dst = params.dst log.debug({ 'preprocess args': args.preprocess }) _res = await post('/sdapi/v1/preprocess', json = args.preprocess) processed = [os.path.join(params.dst, f) for f in os.listdir(params.dst) if os.path.isfile(os.path.join(params.dst, f))] processed_imgs = [f for f in processed if f.endswith('.png')] processed_docs = [f for f in processed if f.endswith('.txt')] log.info({ 'preprocess': { 'source': params.src, 'destination': params.dst, 'files': len(files), 'images': len(images), 'processed': len(processed_imgs), 'captions': len(processed_docs), 'skipped': len(not_images), 'low-res': len(low_res) } }) if len(processed_docs) > 0: await captions(processed_docs) return len(processed_imgs) async def preprocess(params): global images # pylint: disable=global-statement res = 0 if os.path.isfile(params.src): if not filetype.is_video(params.src): kind = filetype.guess(params.src) log.error({ 'preprocess error': { 'not a valid movie file': params.src, 'guess': kind } }) else: extract_dst = os.path.join(params.dst, 'extract') log.debug({ 'preprocess args': args.extract_video }) images = extract(params.src, extract_dst, rate = args.extract_video.rate, fps = args.extract_video.fps, start = args.extract_video.vstart, end = args.extract_video.vend) # extract keyframes from movie if images > 0: params.src = extract_dst res = await preprocess(params) # call again but now with keyframes else: log.error({ 'preprocess video extract': 'no images' }) elif os.path.isdir(params.src): if params.overwrite: await preprocess_cleanup(params) elif os.path.isdir(params.dst): log.error({ 'preprocess output folder already exists': params.dst }) return 0 if params.preprocess == 'builtin': res = await preprocess_builtin(params) i = [os.path.join(params.dst, f) for f in os.listdir(params.dst) if os.path.isfile(os.path.join(params.dst, f)) and filetype.is_image(os.path.join(params.dst, f))] images = [Image.open(img) for img in i] res = len(images) elif params.preprocess == 'custom': t0 = time.perf_counter() args.preprocess.process_src = params.src args.preprocess.process_dst = params.dst process_images(src = params.src, dst = params.dst) i = [os.path.join(params.dst, f) for f in os.listdir(params.dst) if os.path.isfile(os.path.join(params.dst, f)) and filetype.is_image(os.path.join(params.dst, f))] images = [Image.open(img) for img in i] t1 = time.perf_counter() log.info({ 'preprocess': { 'source': params.src, 'destination': params.dst, 'images': len(images), 'time': round(t1 - t0, 2) } }) res = len(images) else: args.preprocess.process_dst = params.src i = [os.path.join(params.src, f) for f in os.listdir(params.src) if os.path.isfile(os.path.join(params.src, f)) and filetype.is_image(os.path.join(params.src, f))] images = [Image.open(img) for img in i] res = len(images) else: log.error({ 'preprocess error': { 'not a valid input': params.src } }) if len(images) > 0: img = grid(images, labels = None, width = 2048, height = 2048, border = 8, square = True) logdir = os.path.abspath(os.path.join(cmdflags['embeddings_dir'], '../train/log')) Path(logdir).mkdir(parents = True, exist_ok = True) fn = os.path.join(logdir, params.name + '.inputs.jpg') img.save(fn) log.info({ 'preprocess input grid': fn }) return res async def check(params): global options # pylint: disable=global-statement options = await get('/sdapi/v1/options') global cmdflags # pylint: disable=global-statement cmdflags = await get('/sdapi/v1/cmd-flags') logdir = os.path.abspath(os.path.join(cmdflags['embeddings_dir'], '../train/log', params.name)) logfile = os.path.abspath(os.path.join(cmdflags['embeddings_dir'], '../train/log', params.name + '.train.log')) set_logfile(logfile) log.info({ 'checking server options' }) options['training_xattention_optimizations'] = False options['training_image_repeats_per_epoch'] = 1 if params.skipmodel: log.info({ 'using model': options['sd_model_checkpoint'] }) else: log.debug({ 'check model': args.training_model }) if len(args.training_model) > 0 and not options['sd_model_checkpoint'].startswith(args.training_model): models = await get('/sdapi/v1/sd-models') models = [obj["title"] for obj in models] found = [i for i in models if i.startswith(args.training_model)] if len(found) == 0: log.error({ 'model not found': args.training_model, 'available': models }) exit() else: log.warning({ 'switching model': found[0] }) options['sd_model_checkpoint'] = found[0] log.debug({ 'check embedding': params.name }) lst = os.path.join(cmdflags['embeddings_dir']) log.debug({ 'embeddings folder': lst }) path = Path(cmdflags['embeddings_dir']).glob(f'{params.name}.pt*') matches = [f for f in path] for match in matches: if params.overwrite: log.info({ 'delete embedding': match.name }) os.remove(os.path.join(cmdflags['embeddings_dir'], match.name)) else: log.error({ 'embedding exists': match.name }) await close() exit() f = os.path.join(logdir, 'train.csv') if os.path.isfile(f): if params.overwrite: log.info({ 'delete training log': f }) os.remove(os.path.join(logdir, 'train.csv')) else: log.warning({ 'training log exists': f }) f = os.path.join(logdir, '..', params.name, '.png') if os.path.isfile(f): if params.overwrite: log.info({ 'delete training graph': f }) os.remove(f) log.debug({ 'options': 'update' }) await post('/sdapi/v1/options', options) return async def create(params): log.debug({ 'create start' }) if not os.path.isdir(args.preprocess.process_dst): log.error({ 'train source not found': args.preprocess.process_dst }) exit() if params.vectors == -1: # dynamically determine number of vectors depending on number of input images if len(images) <= 20: vectors = 2 elif len(images) <= 100: vectors = 4 else: vectors = 6 else: vectors = params.vectors if os.path.exists(params.name) and os.path.isfile(params.name): log.info({ 'deleting existing embedding': { 'name': params.name } }) os.remove(params.name) args.create_embedding.name = params.name words = params.init.split(',') if len(words) > vectors: params.init = ','.join(words[:vectors]) log.warning({ 'create embedding init words cut': params.init }) args.create_embedding.init_text = params.init args.create_embedding.num_vectors_per_token = vectors log.debug({ 'create args': args.create_embedding }) res = await post('/sdapi/v1/create/embedding', args.create_embedding) if 'info' in res: log.info({ 'create embedding': { 'name': params.name, 'init': params.init, 'vectors': vectors, 'message': res.info } }) else: log.error({ 'create failed:', res }) return None log.debug({ 'create end' }) return params.name async def train(params): log.debug({ 'train start' }) args.train_embedding.embedding_name = params.name imgs = [f for f in os.listdir(args.preprocess.process_dst) if os.path.isfile(os.path.join(args.preprocess.process_dst, f)) and filetype.is_image(os.path.join(args.preprocess.process_dst, f))] args.train_embedding.data_root = args.preprocess.process_dst if len(imgs) == 0: log.error({ 'train no input images in folder': args.preprocess.process_dst }) return if params.grad == -1: args.train_embedding.gradient_step = len(imgs) // args.train_embedding.batch_size divisor = args.train_embedding.gradient_step // 60 args.train_embedding.gradient_step = args.train_embedding.gradient_step // (1 + divisor) log.info({ 'dynamic gradient step': args.train_embedding.gradient_step }) if params.steps == -1: args.train_embedding.steps = params.maxsteps // args.train_embedding.gradient_step log.info({ 'dynamic steps': args.train_embedding.steps, 'estimated total steps': args.train_embedding.steps * args.train_embedding.gradient_step * args.train_embedding.batch_size }) epoch_size = args.train_embedding.batch_size * args.train_embedding.gradient_step if args.train_embedding.create_image_every == -1: args.train_embedding.create_image_every = args.train_embedding.steps // 10 if args.train_embedding.save_embedding_every == -1: args.train_embedding.save_embedding_every = args.train_embedding.steps // 10 if args.train_embedding.learn_rate == -1: loss_args = { "steps": args.train_embedding.steps, "step": args.train_embedding.create_image_every, "loss_start": params.rstart, "loss_end": params.rend, "loss_type": 'power', "power": params.rdescend } args.train_embedding.learn_rate = gen_loss_rate_str(**loss_args) log.info({ 'dynamic learn-rate': loss_args }) log.debug({ 'learn rate': args.train_embedding.learn_rate, 'params': loss_args }) log.info({ 'train embedding': { 'name': params.name, 'source': args.preprocess.process_dst, 'images': len(imgs), 'steps': args.train_embedding.steps, 'batch': args.train_embedding.batch_size, 'gradient-step': args.train_embedding.gradient_step, 'sampling': args.train_embedding.latent_sampling_method, 'epoch-size': epoch_size } }) log.info({ 'learn rate': args.train_embedding.learn_rate }) log.debug({ 'train args': args.train_embedding }) t0 = time.time() res = await post('/sdapi/v1/train/embedding', args.train_embedding) log.info({ 'train result': res }) t1 = time.time() log.info({ 'train embedding finished': { 'name': params.name, 'time': round(t1 - t0) } }) log.debug({ 'train end': res.info if 'info' in res else res }) return async def pipeline(params): log.debug({ 'pipeline start' }) # interrupt await interrupt() # preprocess num = await preprocess(params) if num == 0: log.warning({ 'preprocess': 'no resulting images'}) return # create embedding name = await create(params) if not params.name in name: log.error({ 'create embedding failed': name }) return # train embedding await train(params) await plotloss(params) # create_preview(params.name, params.init) log.debug({ 'pipeline end' }) return async def monitor(params): step = 0 t0 = time.perf_counter() t1 = time.perf_counter() log.info({' starting monitor': t0 }) finished = 0 while True: await asyncio.sleep(params.monitor) res = await progress() if not 'state' in res: log.info({ 'monitor disconnected': res }) break if (res.state.job_count == params.steps and res.state.job_no >= res.state.job_count) or (res.eta_relative < 0) or (res.interrupted) or (res.state.job_count == 0): # need exit case if interrupted or failed if res.interrupted: log.info({ 'monitor interrupted': { 'embedding': params.name } }) break # exit for monitor job else: finished += 1 if finished >= 2: # do it more than once since preprocessing job can finish just in time for monitor to finish log.info({ 'monitor finished': { 'embedding': params.name } }) break else: if res.state.job_no == 0: step = 0 t0 = time.perf_counter() t1 = time.perf_counter() try: if 'Loss:' in res.textinfo: text = res.textinfo.split('
')[0].split() loss = float(text[-1]) else: loss = -1 except: loss = -1 if math.isnan(loss): log.error({ 'monitor': { 'progress': round(100 * res.progress), 'embedding': params.name, 'eta': round(res.eta_relative), 'step': res.state.job_no, 'steps': res.state.job_count, 'loss': 'nan' } }) await interrupt() else: elapsed = t1 - t0 log.info({ 'monitor': { 'job': res.state.job, 'progress': round(100 * res.progress), 'embedding': params.name, 'epoch': (1 + res.state.job_no // len(images)) if len(images) > 0 else 'n/a', 'step': res.state.job_no, 'steps': res.state.job_count, 'loss': loss if loss > -1 else 'n/a', 'total': round(1.0 * elapsed * res.state.job_count / res.state.job_no) if res.state.job_no > 0 and t1 != t0 else 'n/a', 'elapsed': round(elapsed), 'remaining': round(res.eta_relative), 'it/s': round((res.state.job_no - step) / (time.perf_counter() - t1), 2) } }) if step % 10 == 0: await plotloss(params) step = res.state.job_no t1 = time.perf_counter() return async def main(): parser = argparse.ArgumentParser(description="sd train pipeline") parser.add_argument("--name", type = str, required = True, help = "embedding name, set to auto to use src folder name") parser.add_argument("--src", type = str, required = True, help = "source image folder or movie file") parser.add_argument("--init", type = str, default = "person", required = False, help = "initialization class, default: %(default)s") parser.add_argument("--dst", type = str, default = "/tmp", required = False, help = "destination image folder for processed images, default: %(default)s") parser.add_argument("--steps", type = int, default = -1, required = False, help = "training steps, default: %(default)s") parser.add_argument("--maxsteps", type = int, default = 5000, required = False, help = "max training steps used when dynamic gradient is active, default: %(default)s") parser.add_argument("--vectors", type = int, default = -1, required = False, help = "number of vectors per token, default: dynamic based on number of input images") parser.add_argument("--batch", type = int, default = 1, required = False, help = "batch size, default: %(default)s") parser.add_argument("--rate", type = str, default = "", required = False, help = "learn rate, default: dynamic") parser.add_argument("--rstart", type = float, default = 0.02, required = False, help = "starting learn rate if using dynamic rate, default: %(default)s") parser.add_argument("--rend", type = float, default = 0.0005, required = False, help = "ending learn rate if using dynamic rate, default: %(default)s") parser.add_argument("--rdescend", type = float, default = 2, required = False, help = "learn rate descend power when using dynamic rate, default: %(default)s") parser.add_argument("--grad", type = int, default = -1, required = False, help = "accumulate gradient over n images, default: : %(default)s") parser.add_argument("--type", type = str, default = 'subject', required = False, help = "training type: subject/style/unknown, default: %(default)s") parser.add_argument('--overwrite', default = False, action='store_true', help = "overwrite existing embedding, default: %(default)s") parser.add_argument("--vstart", type = float, default = 0, required = False, help = "if processing video skip first n seconds, default: %(default)s") parser.add_argument("--vend", type = float, default = 0, required = False, help = "if processing video skip last n seconds, default: %(default)s") parser.add_argument('--skipcaption', default = False, action='store_true', help = "do not auto-generate captions, default: %(default)s") parser.add_argument('--skipmodel', default = False, action='store_true', help = "skip model validation and switch, default: %(default)s") parser.add_argument('--preprocess', type = str, choices=['builtin', 'custom', 'none'], default = 'custom', help = "preprocessing type, default: %(default)s") parser.add_argument('--nocleanup', default = False, action='store_true', help = "skip cleanup after completion, default: %(default)s") parser.add_argument("--monitor", type = int, default = 30, required = False, help = "progress monitor frequency, default: : %(default)s") parser.add_argument('--debug', default = False, action='store_true', help = "print extra debug information, default: %(default)s") params = parser.parse_args() if params.debug: log.setLevel(logging.DEBUG) log.debug({ 'debug': True }) log.debug({ 'args': params.__dict__ }) home = Path(sys.argv[0]).parent global args # pylint: disable=global-statement if params.vstart > 0: args.extract_video.vstart = params.vstart if params.vend > 0: args.extract_video.vend = params.vend if params.steps > -1: args.train_embedding.steps = params.steps if params.batch > -1: args.train_embedding.batch_size = params.batch if params.rate != '': args.train_embedding.learn_rate = params.rate if params.grad > -1: args.train_embedding.gradient_step = params.grad if params.type == 'subject': if params.skipcaption: args.train_embedding.template_filename = 'subject.txt' args.preprocess.process_caption = False else: args.train_embedding.template_filename = 'subject_filewords.txt' elif params.type == 'style': if params.skipcaption: args.train_embedding.template_filename = 'style.txt' args.preprocess.process_caption = False else: args.train_embedding.template_filename = 'style_filewords.txt' else: if params.skipcaption: args.train_embedding.template_filename = 'unknown.txt' args.preprocess.process_caption = False else: args.train_embedding.template_filename = 'unknown_filewords.txt' if params.name == 'auto': params.name = PurePath(params.src).name log.info({ 'training name': params.name }) if params.dst == "/tmp": params.dst = os.path.join("/tmp/train", params.name) log.debug({ 'args': params.__dict__ }) params.src = os.path.abspath(params.src) params.dst = os.path.abspath(params.dst) try: await session() await check(params) a = asyncio.create_task(pipeline(params)) b = asyncio.create_task(monitor(params)) await asyncio.gather(a, b) # wait for both pipeline and monitor to finish except Exception as e: log.error({ 'exception': e }) finally: if not params.nocleanup: await preprocess_cleanup(params) await close() return if __name__ == "__main__": log.info({ 'train textual inversion' }) try: asyncio.run(main()) except KeyboardInterrupt: log.warning({ 'interrupted': 'keyboard request' }) # asyncio.run(interrupt())