# automatic/cli/train-ti.py — 592 lines, 25 KiB, executable Python file
#!/bin/env python
# pylint: disable=no-member
"""
simple implementation of training api: `/sdapi/v1/train`
- supports: create embedding, image preprocess, train embedding (with all known parameters)
- does not (yet) support: create hyper-network, train hyper-network
- compatible with progress api: `/sdapi/v1/progress`
- if interrupted, auto-continues from last known step
- create and preprocess executed as sync jobs
- train is executed as async job with progress monitoring
"""
import argparse
import asyncio
import logging
import math
import os
import sys
import time
import importlib
from pathlib import Path, PurePath
import filetype
from PIL import Image
sys.path.append(os.path.join(os.path.dirname(__file__), 'modules'))
from modules.util import Map, log, set_logfile
from modules.sdapi import close, get, interrupt, post, progress, session
from modules.process import process_images
from modules.grid import grid
# these helper modules have hyphens in their filenames, so they cannot be
# loaded with a regular `import` statement; pull the needed symbols via importlib
create_preview = importlib.import_module('modules.preview-embeddings').create_preview
plot = importlib.import_module('modules.train-losschart').plot
extract = importlib.import_module('modules.video-extract').extract
gen_loss_rate_str = importlib.import_module('modules.train-lossrate').gen_loss_rate_str
images = [] # input image list; populated by preprocess(), read by create() and monitor()
args = {} # placeholder; immediately replaced by the Map below
options = None # server options; fetched in check() from /sdapi/v1/options
cmdflags = None # server command-line flags; fetched in check() from /sdapi/v1/cmd-flags
# default request payloads for the train api; values of -1 mean "determine dynamically"
args = Map({
    "training_model": "sd-v15-runwayml.ckpt", # checkpoint to train against; check() switches the server model if needed
    "extract_video": { # keyframe extraction settings for movie inputs (see preprocess())
        "rate": 0,
        "fps": 5,
        "vstart": 0, # skip first n seconds of video
        "vend": 0 # skip last n seconds of video
    },
    "create_embedding": { # payload for /sdapi/v1/create/embedding
        "name": "test",
        "num_vectors_per_token": 1,
        "overwrite_old": False,
        "init_text": "*"
    },
    "preprocess": { # payload for /sdapi/v1/preprocess
        "id_task": 0,
        "process_src": "",
        "process_dst": "",
        "process_width": 512,
        "process_height": 512,
        "process_flip": False,
        "process_split": False,
        "process_caption": True,
        "process_caption_deepbooru": False,
        "preprocess_txt_action": "ignore",
        "process_focal_crop": True,
        "process_focal_crop_face_weight": 0.9,
        "process_focal_crop_entropy_weight": 0.3,
        "process_focal_crop_edges_weight": 0.5,
        "process_focal_crop_debug": False,
        "split_threshold": 0.5,
        "overlap_ratio": 0.2,
        "process_multicrop": None,
        "process_multicrop_mindim": None,
        "process_multicrop_maxdim": None,
        "process_multicrop_minarea": None,
        "process_multicrop_maxarea": None,
        "process_multicrop_objective": None,
        "process_multicrop_threshold": None,
    },
    "train_embedding": { # payload for /sdapi/v1/train/embedding
        "id_task": 0,
        "embedding_name": "",
        "learn_rate": -1, # -1: generate dynamic learn-rate schedule in train()
        "batch_size": 1,
        "steps": 500,
        "data_root": "",
        "log_directory": "train/log",
        "template_filename": "subject_filewords.txt",
        "gradient_step": 20,
        "training_width": 512,
        "training_height": 512,
        "shuffle_tags": False,
        "tag_drop_out": 0,
        "clip_grad_mode": "disabled",
        "clip_grad_value": "0.1",
        "latent_sampling_method": "deterministic",
        "create_image_every": -1, # -1: derived from steps in train()
        "save_embedding_every": -1, # -1: derived from steps in train()
        "save_image_with_stored_embedding": False,
        "preview_from_txt2img": False,
        "preview_prompt": "",
        "preview_negative_prompt": "blurry, duplicate, ugly, deformed, low res, watermark, text",
        "preview_steps": 20,
        "preview_sampler_index": 0,
        "preview_cfg_scale": 6,
        "preview_seed": -1,
        "preview_width": 512,
        "preview_height": 512,
        "varsize": False,
    },
})
async def plotloss(params):
    """Render the training loss chart for the embedding; failures are non-fatal."""
    chart_dir = os.path.abspath(os.path.join(cmdflags['embeddings_dir'], '../train/log'))
    try:
        plot(chart_dir, params.name)
    except Exception as err:
        # chart rendering is best-effort; never abort training over it
        log.warning({ 'loss chart error': err })
async def captions(docs: list):
    """Log the ten most frequent caption words across all caption files.

    Reads every caption file in `docs`, splits lowercased lines on spaces,
    drops common stop words, and logs the top-10 words by frequency.
    """
    # common filler words that carry no caption signal
    exclude = ['a', 'in', 'on', 'out', 'at', 'the', 'and', 'with', 'next', 'to', 'it', 'for', 'of', 'into', 'that']
    d = dict()
    for f in docs:
        # fix: original called open() without ever closing the handle (leaked one fd per caption file)
        with open(f, 'r', encoding='utf-8') as text:
            for line in text:
                words = line.strip().lower().split(" ")
                for word in words:
                    if word in exclude:
                        continue
                    d[word] = d.get(word, 0) + 1
    # keep only the ten most frequent words, highest count first
    pairs = ((value, key) for (key, value) in d.items())
    sort = sorted(pairs, reverse = True)
    if len(sort) > 10:
        del sort[10:]
    d = {k: v for v, k in sort}
    log.info({ 'top captions': d })
async def preprocess_cleanup(params):
    """Delete processed artifacts (png/jpg/txt) from the destination folder, then remove the folder if possible."""
    log.info({ 'preprocess cleanup': params.dst })
    dst = Path(params.dst)
    for pattern in ('*.png', '*.jpg', '*.txt'):
        for entry in dst.glob(pattern):
            entry.unlink()
    try:
        # rmdir only succeeds on an empty folder; anything left behind is logged, not raised
        if os.path.isdir(params.dst):
            dst.rmdir()
    except Exception as err:
        log.warning({ 'preprocess cleanup': params.dst, 'error': err })
async def preprocess_builtin(params):
    """Preprocess input images using the server's `/sdapi/v1/preprocess` endpoint.

    Filters out non-image files and low-resolution images (the latter are
    renamed with a `.skip` suffix so they are excluded from later passes),
    submits the source folder to the server, then reports counts and logs the
    most frequent caption words. Returns the number of processed images.
    """
    global images # pylint: disable=global-statement
    log.debug({ 'preprocess start' })
    files = [os.path.join(params.src, f) for f in os.listdir(params.src) if os.path.isfile(os.path.join(params.src, f))]
    candidates = [f for f in files if filetype.is_image(f)]
    not_images = [f for f in files if (not filetype.is_image(f) and not f.endswith('.txt'))]
    images = []
    low_res = []
    for f in candidates:
        # fix: original leaked one open file handle per image; only the size is needed
        with Image.open(f) as img:
            width, height = img.size
        mp = (width * height) / 1024 / 1024
        if mp < 1 or width < 512 or height < 512:
            low_res.append(f)
            os.rename(f, f + '.skip') # exclude low-res image from preprocessing
        else:
            images.append(f)
    log.debug({ 'preprocess skipping': not_images })
    log.debug({ 'preprocess low res': low_res })
    args.preprocess.process_src = params.src
    args.preprocess.process_dst = params.dst
    log.debug({ 'preprocess args': args.preprocess })
    _res = await post('/sdapi/v1/preprocess', json = args.preprocess)
    processed = [os.path.join(params.dst, f) for f in os.listdir(params.dst) if os.path.isfile(os.path.join(params.dst, f))]
    processed_imgs = [f for f in processed if f.endswith('.png')]
    processed_docs = [f for f in processed if f.endswith('.txt')]
    log.info({ 'preprocess': {
        'source': params.src,
        'destination': params.dst,
        'files': len(files),
        'images': len(images),
        'processed': len(processed_imgs),
        'captions': len(processed_docs),
        'skipped': len(not_images),
        'low-res': len(low_res) }
    })
    if len(processed_docs) > 0:
        await captions(processed_docs)
    return len(processed_imgs)
async def preprocess(params):
    """Prepare training images from a movie file or an image folder.

    For a movie file, extracts keyframes and recurses on the extracted folder.
    For a folder, runs the configured preprocessing mode (builtin server api,
    custom local processing, or none) and saves an input-grid preview image.
    Returns the number of resulting images (0 on error).
    """
    global images # pylint: disable=global-statement
    res = 0
    if os.path.isfile(params.src):
        # single file input: only movies are accepted here
        if not filetype.is_video(params.src):
            kind = filetype.guess(params.src)
            log.error({ 'preprocess error': { 'not a valid movie file': params.src, 'guess': kind } })
        else:
            extract_dst = os.path.join(params.dst, 'extract')
            log.debug({ 'preprocess args': args.extract_video })
            # note: extract() returns a count here, not a list — `images` is
            # temporarily an int until the recursive call repopulates it
            images = extract(params.src, extract_dst, rate = args.extract_video.rate, fps = args.extract_video.fps, start = args.extract_video.vstart, end = args.extract_video.vend) # extract keyframes from movie
            if images > 0:
                params.src = extract_dst
                res = await preprocess(params) # call again but now with keyframes
            else:
                log.error({ 'preprocess video extract': 'no images' })
    elif os.path.isdir(params.src):
        if params.overwrite:
            await preprocess_cleanup(params)
        elif os.path.isdir(params.dst):
            # refuse to mix new output into an existing destination
            log.error({ 'preprocess output folder already exists': params.dst })
            return 0
        if params.preprocess == 'builtin':
            # server-side preprocessing via /sdapi/v1/preprocess
            res = await preprocess_builtin(params)
            i = [os.path.join(params.dst, f) for f in os.listdir(params.dst) if os.path.isfile(os.path.join(params.dst, f)) and filetype.is_image(os.path.join(params.dst, f))]
            images = [Image.open(img) for img in i]
            res = len(images)
        elif params.preprocess == 'custom':
            # local preprocessing via modules.process
            t0 = time.perf_counter()
            args.preprocess.process_src = params.src
            args.preprocess.process_dst = params.dst
            process_images(src = params.src, dst = params.dst)
            i = [os.path.join(params.dst, f) for f in os.listdir(params.dst) if os.path.isfile(os.path.join(params.dst, f)) and filetype.is_image(os.path.join(params.dst, f))]
            images = [Image.open(img) for img in i]
            t1 = time.perf_counter()
            log.info({ 'preprocess': { 'source': params.src, 'destination': params.dst, 'images': len(images), 'time': round(t1 - t0, 2) } })
            res = len(images)
        else:
            # no preprocessing: train directly from the source folder
            args.preprocess.process_dst = params.src
            i = [os.path.join(params.src, f) for f in os.listdir(params.src) if os.path.isfile(os.path.join(params.src, f)) and filetype.is_image(os.path.join(params.src, f))]
            images = [Image.open(img) for img in i]
            res = len(images)
    else:
        log.error({ 'preprocess error': { 'not a valid input': params.src } })
    if len(images) > 0:
        # save a contact-sheet grid of all inputs next to the training logs
        img = grid(images, labels = None, width = 2048, height = 2048, border = 8, square = True)
        logdir = os.path.abspath(os.path.join(cmdflags['embeddings_dir'], '../train/log'))
        Path(logdir).mkdir(parents = True, exist_ok = True)
        fn = os.path.join(logdir, params.name + '.inputs.jpg')
        img.save(fn)
        log.info({ 'preprocess input grid': fn })
    return res
async def check(params):
    """Validate and prepare server state before training.

    Fetches server options and command-line flags, points logging at the
    per-embedding logfile, switches the active checkpoint to the configured
    training model if needed, and removes (or refuses to overwrite) any
    pre-existing embedding, training log, and loss graph.
    """
    global options # pylint: disable=global-statement
    options = await get('/sdapi/v1/options')
    global cmdflags # pylint: disable=global-statement
    cmdflags = await get('/sdapi/v1/cmd-flags')
    logdir = os.path.abspath(os.path.join(cmdflags['embeddings_dir'], '../train/log', params.name))
    logfile = os.path.abspath(os.path.join(cmdflags['embeddings_dir'], '../train/log', params.name + '.train.log'))
    set_logfile(logfile)
    log.info({ 'checking server options' })
    # force settings known to interfere with embedding training
    options['training_xattention_optimizations'] = False
    options['training_image_repeats_per_epoch'] = 1
    if params.skipmodel:
        log.info({ 'using model': options['sd_model_checkpoint'] })
    else:
        log.debug({ 'check model': args.training_model })
        if len(args.training_model) > 0 and not options['sd_model_checkpoint'].startswith(args.training_model):
            models = await get('/sdapi/v1/sd-models')
            models = [obj["title"] for obj in models]
            found = [i for i in models if i.startswith(args.training_model)]
            if len(found) == 0:
                log.error({ 'model not found': args.training_model, 'available': models })
                exit()
            else:
                log.warning({ 'switching model': found[0] })
                options['sd_model_checkpoint'] = found[0]
    log.debug({ 'check embedding': params.name })
    lst = os.path.join(cmdflags['embeddings_dir'])
    log.debug({ 'embeddings folder': lst })
    path = Path(cmdflags['embeddings_dir']).glob(f'{params.name}.pt*')
    matches = [f for f in path]
    for match in matches:
        if params.overwrite:
            log.info({ 'delete embedding': match.name })
            os.remove(os.path.join(cmdflags['embeddings_dir'], match.name))
        else:
            log.error({ 'embedding exists': match.name })
            await close()
            exit()
    f = os.path.join(logdir, 'train.csv')
    if os.path.isfile(f):
        if params.overwrite:
            log.info({ 'delete training log': f })
            os.remove(os.path.join(logdir, 'train.csv'))
        else:
            log.warning({ 'training log exists': f })
    # fix: original joined '.png' as a separate path component ('<name>/.png'),
    # so the stale loss graph '<name>.png' was never found nor deleted
    f = os.path.join(logdir, '..', params.name + '.png')
    if os.path.isfile(f):
        if params.overwrite:
            log.info({ 'delete training graph': f })
            os.remove(f)
    log.debug({ 'options': 'update' })
    await post('/sdapi/v1/options', options)
    return
async def create(params):
    """Create a textual-inversion embedding via `/sdapi/v1/create/embedding`.

    Vector count is taken from params, or derived from the number of input
    images when params.vectors == -1; init words are truncated to the vector
    count. Returns the embedding name on success, None on failure.
    """
    log.debug({ 'create start' })
    if not os.path.isdir(args.preprocess.process_dst):
        log.error({ 'train source not found': args.preprocess.process_dst })
        exit()
    if params.vectors == -1: # dynamically determine number of vectors depending on number of input images
        if len(images) <= 20:
            vectors = 2
        elif len(images) <= 100:
            vectors = 4
        else:
            vectors = 6
    else:
        vectors = params.vectors
    if os.path.isfile(params.name): # isfile() already implies existence
        log.info({ 'deleting existing embedding': { 'name': params.name } })
        os.remove(params.name)
    args.create_embedding.name = params.name
    words = params.init.split(',')
    if len(words) > vectors: # cannot have more init words than vectors
        params.init = ','.join(words[:vectors])
        log.warning({ 'create embedding init words cut': params.init })
    args.create_embedding.init_text = params.init
    args.create_embedding.num_vectors_per_token = vectors
    log.debug({ 'create args': args.create_embedding })
    res = await post('/sdapi/v1/create/embedding', args.create_embedding)
    if 'info' in res:
        log.info({ 'create embedding': { 'name': params.name, 'init': params.init, 'vectors': vectors, 'message': res.info } })
    else:
        # fix: original passed a set literal `{ 'create failed:', res }` (comma
        # instead of colon) where a dict was intended
        log.error({ 'create failed': res })
        return None
    log.debug({ 'create end' })
    return params.name
async def train(params):
    """Start embedding training via `/sdapi/v1/train/embedding`.

    Fills in dynamic values (-1 sentinels) for gradient step, step count,
    image/embedding save frequency, and learn-rate schedule, then submits the
    training job and waits for the result.
    """
    log.debug({ 'train start' })
    args.train_embedding.embedding_name = params.name
    imgs = [f for f in os.listdir(args.preprocess.process_dst) if os.path.isfile(os.path.join(args.preprocess.process_dst, f)) and filetype.is_image(os.path.join(args.preprocess.process_dst, f))]
    args.train_embedding.data_root = args.preprocess.process_dst
    if len(imgs) == 0:
        log.error({ 'train no input images in folder': args.preprocess.process_dst })
        return
    if params.grad == -1:
        # accumulate gradients over roughly one pass of the dataset, capped to a sane size
        # fix: clamp to >= 1 — fewer images than batch_size previously produced a
        # zero gradient step and a ZeroDivisionError in the dynamic-steps math below
        args.train_embedding.gradient_step = max(1, len(imgs) // args.train_embedding.batch_size)
        divisor = args.train_embedding.gradient_step // 60
        args.train_embedding.gradient_step = args.train_embedding.gradient_step // (1 + divisor)
        log.info({ 'dynamic gradient step': args.train_embedding.gradient_step })
    if params.steps == -1:
        # derive step count from the step budget and gradient accumulation
        args.train_embedding.steps = params.maxsteps // args.train_embedding.gradient_step
        log.info({ 'dynamic steps': args.train_embedding.steps, 'estimated total steps': args.train_embedding.steps * args.train_embedding.gradient_step * args.train_embedding.batch_size })
    epoch_size = args.train_embedding.batch_size * args.train_embedding.gradient_step
    if args.train_embedding.create_image_every == -1:
        args.train_embedding.create_image_every = args.train_embedding.steps // 10
    if args.train_embedding.save_embedding_every == -1:
        args.train_embedding.save_embedding_every = args.train_embedding.steps // 10
    if args.train_embedding.learn_rate == -1:
        # generate a descending power-curve learn-rate schedule string
        loss_args = {
            "steps": args.train_embedding.steps,
            "step": args.train_embedding.create_image_every,
            "loss_start": params.rstart,
            "loss_end": params.rend,
            "loss_type": 'power',
            "power": params.rdescend
        }
        args.train_embedding.learn_rate = gen_loss_rate_str(**loss_args)
        log.info({ 'dynamic learn-rate': loss_args })
        log.debug({ 'learn rate': args.train_embedding.learn_rate, 'params': loss_args })
    log.info({ 'train embedding': {
        'name': params.name,
        'source': args.preprocess.process_dst,
        'images': len(imgs),
        'steps': args.train_embedding.steps,
        'batch': args.train_embedding.batch_size,
        'gradient-step': args.train_embedding.gradient_step,
        'sampling': args.train_embedding.latent_sampling_method,
        'epoch-size': epoch_size }
    })
    log.info({ 'learn rate': args.train_embedding.learn_rate })
    log.debug({ 'train args': args.train_embedding })
    t0 = time.time()
    res = await post('/sdapi/v1/train/embedding', args.train_embedding)
    log.info({ 'train result': res })
    t1 = time.time()
    log.info({ 'train embedding finished': { 'name': params.name, 'time': round(t1 - t0) } })
    log.debug({ 'train end': res.info if 'info' in res else res })
    return
async def pipeline(params):
    """Run the full training pipeline: interrupt, preprocess, create, train, plot."""
    log.debug({ 'pipeline start' })
    # interrupt any job currently running on the server
    await interrupt()
    # preprocess
    num = await preprocess(params)
    if num == 0:
        log.warning({ 'preprocess': 'no resulting images'})
        return
    # create embedding
    name = await create(params)
    # fix: create() returns None on failure; `params.name in None` raised TypeError
    if name is None or params.name not in name:
        log.error({ 'create embedding failed': name })
        return
    # train embedding
    await train(params)
    await plotloss(params)
    # create_preview(params.name, params.init)
    log.debug({ 'pipeline end' })
    return
async def monitor(params):
    """Poll `/sdapi/v1/progress` every `params.monitor` seconds and log progress.

    Exits when the job finishes, is interrupted, or the server disconnects.
    Interrupts the server job if the reported loss turns NaN, and refreshes the
    loss chart roughly every tenth poll.
    """
    step = 0
    t0 = time.perf_counter()
    t1 = time.perf_counter()
    log.info({ 'starting monitor': t0 }) # fix: key was "' starting monitor'" with a stray leading space
    finished = 0
    while True:
        await asyncio.sleep(params.monitor)
        res = await progress()
        if not 'state' in res:
            log.info({ 'monitor disconnected': res })
            break
        if (res.state.job_count == params.steps and res.state.job_no >= res.state.job_count) or (res.eta_relative < 0) or (res.interrupted) or (res.state.job_count == 0): # need exit case if interrupted or failed
            if res.interrupted:
                log.info({ 'monitor interrupted': { 'embedding': params.name } })
                break # exit for monitor job
            else:
                finished += 1
                if finished >= 2: # do it more than once since preprocessing job can finish just in time for monitor to finish
                    log.info({ 'monitor finished': { 'embedding': params.name } })
                    break
        else:
            if res.state.job_no == 0: # job restarted: reset counters and timers
                step = 0
                t0 = time.perf_counter()
                t1 = time.perf_counter()
            try:
                # parse loss from the free-form textinfo string, e.g. "... Loss: 0.123<br/>..."
                if 'Loss:' in res.textinfo:
                    text = res.textinfo.split('<br/>')[0].split()
                    loss = float(text[-1])
                else:
                    loss = -1
            except Exception: # fix: bare `except:` also swallowed SystemExit/KeyboardInterrupt
                loss = -1
            if math.isnan(loss):
                # NaN loss means training has diverged; stop the server job
                log.error({ 'monitor': { 'progress': round(100 * res.progress), 'embedding': params.name, 'eta': round(res.eta_relative), 'step': res.state.job_no, 'steps': res.state.job_count, 'loss': 'nan' } })
                await interrupt()
            else:
                elapsed = t1 - t0
                log.info({ 'monitor': {
                    'job': res.state.job,
                    'progress': round(100 * res.progress),
                    'embedding': params.name,
                    'epoch': (1 + res.state.job_no // len(images)) if len(images) > 0 else 'n/a',
                    'step': res.state.job_no,
                    'steps': res.state.job_count,
                    'loss': loss if loss > -1 else 'n/a',
                    'total': round(1.0 * elapsed * res.state.job_count / res.state.job_no) if res.state.job_no > 0 and t1 != t0 else 'n/a',
                    'elapsed': round(elapsed),
                    'remaining': round(res.eta_relative),
                    'it/s': round((res.state.job_no - step) / (time.perf_counter() - t1), 2) }
                })
                if step % 10 == 0:
                    await plotloss(params)
                step = res.state.job_no
                t1 = time.perf_counter()
    return
async def main():
    """Parse CLI arguments, apply overrides to the default payloads, and run.

    Launches the training pipeline and the progress monitor as concurrent
    tasks; always cleans up the preprocess destination (unless --nocleanup)
    and closes the http session on exit.
    """
    parser = argparse.ArgumentParser(description="sd train pipeline")
    parser.add_argument("--name", type = str, required = True, help = "embedding name, set to auto to use src folder name")
    parser.add_argument("--src", type = str, required = True, help = "source image folder or movie file")
    parser.add_argument("--init", type = str, default = "person", required = False, help = "initialization class, default: %(default)s")
    parser.add_argument("--dst", type = str, default = "/tmp", required = False, help = "destination image folder for processed images, default: %(default)s")
    parser.add_argument("--steps", type = int, default = -1, required = False, help = "training steps, default: %(default)s")
    parser.add_argument("--maxsteps", type = int, default = 5000, required = False, help = "max training steps used when dynamic gradient is active, default: %(default)s")
    parser.add_argument("--vectors", type = int, default = -1, required = False, help = "number of vectors per token, default: dynamic based on number of input images")
    parser.add_argument("--batch", type = int, default = 1, required = False, help = "batch size, default: %(default)s")
    parser.add_argument("--rate", type = str, default = "", required = False, help = "learn rate, default: dynamic")
    parser.add_argument("--rstart", type = float, default = 0.02, required = False, help = "starting learn rate if using dynamic rate, default: %(default)s")
    parser.add_argument("--rend", type = float, default = 0.0005, required = False, help = "ending learn rate if using dynamic rate, default: %(default)s")
    parser.add_argument("--rdescend", type = float, default = 2, required = False, help = "learn rate descend power when using dynamic rate, default: %(default)s")
    parser.add_argument("--grad", type = int, default = -1, required = False, help = "accumulate gradient over n images, default: : %(default)s")
    parser.add_argument("--type", type = str, default = 'subject', required = False, help = "training type: subject/style/unknown, default: %(default)s")
    parser.add_argument('--overwrite', default = False, action='store_true', help = "overwrite existing embedding, default: %(default)s")
    parser.add_argument("--vstart", type = float, default = 0, required = False, help = "if processing video skip first n seconds, default: %(default)s")
    parser.add_argument("--vend", type = float, default = 0, required = False, help = "if processing video skip last n seconds, default: %(default)s")
    parser.add_argument('--skipcaption', default = False, action='store_true', help = "do not auto-generate captions, default: %(default)s")
    parser.add_argument('--skipmodel', default = False, action='store_true', help = "skip model validation and switch, default: %(default)s")
    parser.add_argument('--preprocess', type = str, choices=['builtin', 'custom', 'none'], default = 'custom', help = "preprocessing type, default: %(default)s")
    parser.add_argument('--nocleanup', default = False, action='store_true', help = "skip cleanup after completion, default: %(default)s")
    parser.add_argument("--monitor", type = int, default = 30, required = False, help = "progress monitor frequency, default: : %(default)s")
    parser.add_argument('--debug', default = False, action='store_true', help = "print extra debug information, default: %(default)s")
    params = parser.parse_args()
    if params.debug:
        log.setLevel(logging.DEBUG)
        log.debug({ 'debug': True })
    log.debug({ 'args': params.__dict__ })
    # note: removed unused local `home` and the unneeded `global args` statement
    # (only attributes of `args` are mutated, which requires no global declaration)
    # apply cli overrides to the default payloads; sentinel values keep the dynamic defaults
    if params.vstart > 0:
        args.extract_video.vstart = params.vstart
    if params.vend > 0:
        args.extract_video.vend = params.vend
    if params.steps > -1:
        args.train_embedding.steps = params.steps
    if params.batch > -1:
        args.train_embedding.batch_size = params.batch
    if params.rate != '':
        args.train_embedding.learn_rate = params.rate
    if params.grad > -1:
        args.train_embedding.gradient_step = params.grad
    # select the prompt template by training type; the *_filewords variants use captions
    base = params.type if params.type in ('subject', 'style') else 'unknown'
    if params.skipcaption:
        args.train_embedding.template_filename = base + '.txt'
        args.preprocess.process_caption = False
    else:
        args.train_embedding.template_filename = base + '_filewords.txt'
    if params.name == 'auto': # derive the embedding name from the source folder name
        params.name = PurePath(params.src).name
        log.info({ 'training name': params.name })
    if params.dst == "/tmp":
        params.dst = os.path.join("/tmp/train", params.name)
    log.debug({ 'args': params.__dict__ })
    params.src = os.path.abspath(params.src)
    params.dst = os.path.abspath(params.dst)
    try:
        await session()
        await check(params)
        a = asyncio.create_task(pipeline(params))
        b = asyncio.create_task(monitor(params))
        await asyncio.gather(a, b) # wait for both pipeline and monitor to finish
    except Exception as e:
        log.error({ 'exception': e })
    finally:
        if not params.nocleanup:
            await preprocess_cleanup(params)
        await close()
    return
if __name__ == "__main__":
log.info({ 'train textual inversion' })
try:
asyncio.run(main())
except KeyboardInterrupt:
log.warning({ 'interrupted': 'keyboard request' })
# asyncio.run(interrupt())