diff --git a/.pylintrc b/.pylintrc index 0fb6ba8dc..b2c67d5b2 100644 --- a/.pylintrc +++ b/.pylintrc @@ -33,7 +33,6 @@ ignore-paths=/usr/lib/.*$, modules/taesd, modules/teacache, modules/todo, - modules/unipc, pipelines/flex2, pipelines/hidream, pipelines/meissonic, diff --git a/modules/devices.py b/modules/devices.py index 8d15fd238..83ec573e4 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -159,13 +159,6 @@ def get_gpu_info(): return { 'error': ex } -def extract_device_id(args, name): # pylint: disable=redefined-outer-name - for x in range(len(args)): - if name in args[x]: - return args[x + 1] - return None - - def get_cuda_device_string(): from modules.shared import cmd_opts if backend == 'ipex': @@ -196,13 +189,6 @@ def get_optimal_device(): return torch.device(get_optimal_device_name()) -def get_device_for(task): # pylint: disable=unused-argument - # if task in cmd_opts.use_cpu: - # log.debug(f'Forcing CPU for task: {task}') - # return cpu - return get_optimal_device() - - def torch_gc(force:bool=False, fast:bool=False, reason:str=None): def get_stats(): mem_dict = memstats.memory_stats() @@ -583,14 +569,6 @@ def set_cuda_params(): log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upcast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} tunable={tunable} fp16={"pass" if fp16_ok else "fail"} bf16={"pass" if bf16_ok else "fail"} optimization="{opts.cross_attention_optimization}"') -def cond_cast_unet(tensor): - return tensor.to(dtype_unet) if unet_needs_upcast else tensor - - -def cond_cast_float(tensor): - return tensor.float() if unet_needs_upcast else tensor - - def randn(seed, shape=None): torch.manual_seed(seed) if backend == 'ipex': diff --git a/modules/errors.py b/modules/errors.py index b505a1d0e..81cfe9379 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -15,12 +15,6 @@ def install(suppress=[]): logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s') -def print_error_explanation(message): - lines = message.strip().split("\n") - for line in lines: - log.error(line) - - def display(e: Exception, task: str, suppress=[]): log.error(f"{task or 'error'}: {type(e).__name__}") console = get_console() diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 6393c9051..75893417e 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -138,15 +138,6 @@ def should_skip(param): return skip -def bind_buttons(buttons, image_component, send_generate_info): - """old function for backwards compatibility; do not use this, use register_paste_params_button""" - for tabname, button in buttons.items(): - source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None - source_tabname = send_generate_info if isinstance(send_generate_info, str) else None - bindings = ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=image_component, source_tabname=source_tabname) - register_paste_params_button(bindings) - - def register_paste_params_button(binding: ParamBinding): registered_param_bindings.append(binding) @@ -194,17 +185,6 @@ def send_image(x): return image -def send_image_and_dimensions(x): - image = x if isinstance(x, Image.Image) else image_from_url_text(x) - if shared.opts.send_size and isinstance(image, Image.Image): - w = image.width - h = image.height - else: - w = gr.update() - h = gr.update() - return image, w, h - - def create_override_settings_dict(text_pairs): res = {} params = {} diff --git a/modules/infotext.py b/modules/infotext.py index 50d443c19..6eef31897 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -28,40 +28,6 @@ def unquote(text): return text -# disabled by default can be enabled if needed -def check_lora(params): - try: - from modules.lora import lora_load - from modules.errors import log # pylint: disable=redefined-outer-name - except Exception: - return - loras = [s.strip() for s in params.get('LoRA hashes', '').split(',')] - found = [] - missing = [] - for l in loras: - lora = lora_load.available_network_hash_lookup.get(l, None) - if lora is not None: - found.append(lora.name) - else: - missing.append(l) - loras = [s.strip() for s in params.get('LoRA networks', '').split(',')] - for l in loras: - lora = lora_load.available_network_aliases.get(l, None) - if lora is not None: - found.append(lora.name) - else: - missing.append(l) - # networks.available_network_aliases.get(name, None) - loras = re_lora.findall(params.get('Prompt', '')) - for l in loras: - lora = lora_load.available_network_aliases.get(l, None) - if lora is not None: - found.append(lora.name) - else: - missing.append(l) - log.debug(f'LoRA: found={list(set(found))} missing={list(set(missing))}') - - def parse(infotext): if not isinstance(infotext, str): return {} @@ -115,7 +81,6 @@ def parse(infotext): params[key] = val debug(f'Param parsed: type={type(params[key])} "{key}"={params[key]} raw="{val}"') - # check_lora(params) return params diff --git a/modules/memstats.py b/modules/memstats.py index f226f3c83..5af35d6be 100644 --- a/modules/memstats.py +++ b/modules/memstats.py @@ -84,10 +84,6 @@ def reset_stats(): pass -def memory_cache(): - return mem - - def ram_stats(): try: process = psutil.Process(os.getpid()) diff --git a/modules/sd_checkpoint.py b/modules/sd_checkpoint.py index 008da4144..162f9db22 100644 --- a/modules/sd_checkpoint.py +++ b/modules/sd_checkpoint.py @@ -160,11 +160,8 @@ def list_models(): def update_model_hashes(): txt = [] lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.hash is None] - # shared.log.info(f'Models list: short hash missing for {len(lst)} out of {len(checkpoints_list)} models') for ckpt in lst: ckpt.hash = model_hash(ckpt.filename) - # txt.append(f'Calculated short hash: {ckpt.title} {ckpt.hash}') - # txt.append(f'Updated short hashes for {len(lst)} out of {len(checkpoints_list)} models') lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.sha256 is None or ckpt.shorthash is None] shared.log.info(f'Models list: hash missing={len(lst)} total={len(checkpoints_list)}') for ckpt in lst: @@ -214,15 +211,6 @@ def get_closet_checkpoint_match(s: str) -> CheckpointInfo: checkpoint_info = CheckpointInfo(s) return checkpoint_info - # reference search - """ - found = sorted([info for info in shared.reference_models.values() if os.path.basename(info['path']).lower().startswith(s.lower())], key=lambda x: len(x['path'])) - if found and len(found) == 1: - checkpoint_info = CheckpointInfo(found[0]['path']) # create a virutal model info - checkpoint_info.type = 'huggingface' - return checkpoint_info - """ - # huggingface search if shared.opts.sd_checkpoint_autodownload and s.count('/') == 1: modelloader.hf_login() @@ -251,13 +239,10 @@ def model_hash(filename): try: with open(filename, "rb") as file: import hashlib - # t0 = time.time() m = hashlib.sha256() file.seek(0x100000) m.update(file.read(0x10000)) shorthash = m.hexdigest()[0:8] - # t1 = time.time() - # shared.log.debug(f'Calculating short hash: {filename} hash={shorthash} time={(t1-t0):.2f}') return shorthash except FileNotFoundError: return 'NOFILE' @@ -280,14 +265,11 @@ def select_checkpoint(op='model'): shared.log.info(" or use --ckpt-dir to specify folder with sd models") shared.log.info(" or use --ckpt to force using specific model") return None - # checkpoint_info = next(iter(checkpoints_list.values())) if model_checkpoint is not None: if model_checkpoint != 'model.safetensors' and model_checkpoint != 'stabilityai/stable-diffusion-xl-base-1.0': shared.log.info(f'Load {op}: search="{model_checkpoint}" not found') else: shared.log.info("Selecting first available checkpoint") - # shared.log.warning(f"Loading fallback checkpoint: {checkpoint_info.title}") - # shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title else: shared.log.info(f'Load {op}: select="{checkpoint_info.title if checkpoint_info is not None else None}"') return checkpoint_info @@ -367,8 +349,6 @@ def read_metadata_from_safetensors(filename): t1 = time.time() global sd_metadata_timer # pylint: disable=global-statement sd_metadata_timer += (t1 - t0) - # except Exception as e: - # shared.log.error(f"Error reading metadata from: {filename} {e}") return res diff --git a/modules/sd_models.py b/modules/sd_models.py index cf57e4c5d..011d75c6f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -620,7 +620,7 @@ def load_diffuser(checkpoint_info=None, timer=None, op='model', revision=None): if debug_load: shared.log.trace(f'Model components: {list(get_signature(sd_model).values())}') - from modules.textual_inversion import textual_inversion + from modules import textual_inversion sd_model.embedding_db = textual_inversion.EmbeddingDatabase() sd_model.embedding_db.add_embedding_dir(shared.opts.embeddings_dir) sd_model.embedding_db.load_textual_inversion_embeddings(force_reload=True) diff --git a/modules/sd_samplers_diffusers.py b/modules/sd_samplers_diffusers.py index 8df2d0dbd..b1f393413 100644 --- a/modules/sd_samplers_diffusers.py +++ b/modules/sd_samplers_diffusers.py @@ -319,9 +319,6 @@ class DiffusionSampler: return self.sampler = sampler - if name == 'DC Solver': - if not hasattr(self.sampler, 'dc_ratios'): - pass # shared.log.debug_log(f'Sampler: class="{self.sampler.__class__.__name__}" config={self.sampler.config}') self.sampler.name = name diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 88eebcdb0..b6c3982e9 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,6 +1,5 @@ import os import glob -from copy import deepcopy import torch from modules import shared, errors, paths, devices, sd_models, sd_detect @@ -12,35 +11,13 @@ loaded_vae_file = None checkpoint_info = None vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE')) debug = os.environ.get('SD_LOAD_DEBUG', None) is not None +unspecified = object() -def get_base_vae(model): - if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: - return base_vae - return None - - -def store_base_vae(model): - global base_vae, checkpoint_info # pylint: disable=global-statement - if checkpoint_info != model.sd_checkpoint_info: - assert not loaded_vae_file, "Trying to store non-base VAE!" - base_vae = deepcopy(model.first_stage_model.state_dict()) - checkpoint_info = model.sd_checkpoint_info - - -def delete_base_vae(): - global base_vae, checkpoint_info # pylint: disable=global-statement - base_vae = None - checkpoint_info = None - - -def restore_base_vae(model): - global loaded_vae_file # pylint: disable=global-statement - if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: - shared.log.info("Restoring base VAE") - _load_vae_dict(model, base_vae) - loaded_vae_file = None - delete_base_vae() +def load_vae_dict(filename): + vae_ckpt = sd_models.read_state_dict(filename, what='vae') + vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} + return vae_dict_1 def get_filename(filepath): @@ -114,35 +91,6 @@ def resolve_vae(checkpoint_file): return None, None -def load_vae_dict(filename): - vae_ckpt = sd_models.read_state_dict(filename, what='vae') - vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} - return vae_dict_1 - - -def load_vae(model, vae_file=None, vae_source="unknown-source"): - global loaded_vae_file # pylint: disable=global-statement - if vae_file: - try: - if not os.path.isfile(vae_file): - shared.log.error(f"VAE not found: model={vae_file} source={vae_source}") - return - store_base_vae(model) - vae_dict_1 = load_vae_dict(vae_file) - _load_vae_dict(model, vae_dict_1) - except Exception as e: - shared.log.error(f"Load VAE failed: model={vae_file} source={vae_source} {e}") - if debug: - errors.display(e, 'VAE') - restore_base_vae(model) - vae_opt = get_filename(vae_file) - if vae_opt not in vae_dict: - vae_dict[vae_opt] = vae_file - elif loaded_vae_file: - restore_base_vae(model) - loaded_vae_file = vae_file - - def apply_vae_config(model_file, vae_file, sd_model): def get_vae_config(): config_file = os.path.join(paths.sd_configs_path, os.path.splitext(os.path.basename(model_file))[0] + '_vae.json') @@ -219,20 +167,6 @@ def load_vae_diffusers(model_file, vae_file=None, vae_source="unknown-source"): return None -# don't call this from outside -def _load_vae_dict(model, vae_dict_1): - model.first_stage_model.load_state_dict(vae_dict_1) - model.first_stage_model.to(devices.dtype_vae) - - -def clear_loaded_vae(): - global loaded_vae_file # pylint: disable=global-statement - loaded_vae_file = None - - -unspecified = object() - - def reload_vae_weights(sd_model=None, vae_file=unspecified): if not sd_model: sd_model = shared.sd_model diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion.py similarity index 86% rename from modules/textual_inversion/textual_inversion.py rename to modules/textual_inversion.py index c24f6cbfd..f6b1c558d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion.py @@ -3,9 +3,7 @@ import os import time import torch import safetensors.torch -from PIL import Image from modules import shared, devices, errors -from modules.textual_inversion.image_embedding import embedding_from_b64, extract_image_data_embed from modules.files_cache import directory_files, directory_mtime, extension_filter @@ -241,9 +239,6 @@ class EmbeddingDatabase: def add_embedding_dir(self, path): self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) - def clear_embedding_dirs(self): - self.embedding_dirs.clear() - def register_embedding(self, embedding, model): self.word_embeddings[embedding.name] = embedding if hasattr(model, 'cond_stage_model'): @@ -306,45 +301,6 @@ class EmbeddingDatabase: errors.display(e, f'Load embedding: name="{embedding.name}" file="{embedding.filename}"') return - def load_from_file(self, path, filename): - name, ext = os.path.splitext(filename) - ext = ext.upper() - - if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: - if '.preview' in filename.lower(): - return - embed_image = Image.open(path) - if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: - data = embedding_from_b64(embed_image.text['sd-ti-embedding']) - else: - data = extract_image_data_embed(embed_image) - if not data: # if data is None, means this is not an embeding, just a preview image - return - elif ext in ['.BIN', '.PT']: - data = torch.load(path, map_location="cpu") - elif ext in ['.SAFETENSORS']: - data = safetensors.torch.load_file(path, device="cpu") - else: - return - - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - # diffuser concepts - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: - if len(data.keys()) != 1: - self.skipped_embeddings[name] = Embedding(None, name=name, filename=path) - return - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - else: - raise RuntimeError(f"Couldn't identify {filename} as textual inversion embedding") - - def load_from_dir(self, embdir): if not shared.sd_loaded: shared.log.info('Skipping embeddings load: model not loaded') @@ -387,14 +343,3 @@ class EmbeddingDatabase: self.previously_displayed_embeddings = displayed_embeddings t1 = time.time() shared.log.info(f"Network load: type=embeddings loaded={len(self.word_embeddings)} skipped={len(self.skipped_embeddings)} time={t1-t0:.2f}") - - - def find_embedding_at_position(self, tokens, offset): - token = tokens[offset] - possible_matches = self.ids_lookup.get(token, None) - if possible_matches is None: - return None, None - for ids, embedding in possible_matches: - if tokens[offset:offset + len(ids)] == ids: - return embedding, len(ids) - return None, None diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py deleted file mode 100644 index 74eb88fd2..000000000 --- a/modules/textual_inversion/image_embedding.py +++ /dev/null @@ -1,213 +0,0 @@ -import base64 -import json -import zlib -import numpy as np -from PIL import Image, ImageDraw, ImageFont -import torch -from modules.shared import opts - - -class EmbeddingEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, torch.Tensor): - return {'TORCHTENSOR': o.cpu().detach().numpy().tolist()} - return json.JSONEncoder.default(self, o) - - -class EmbeddingDecoder(json.JSONDecoder): - def __init__(self, *args, **kwargs): - json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs) - - def object_hook(self, d): # pylint: disable=E0202 - if 'TORCHTENSOR' in d: - return torch.from_numpy(np.array(d['TORCHTENSOR'])) - return d - - -def embedding_to_b64(data): - d = json.dumps(data, cls=EmbeddingEncoder) - return base64.b64encode(d.encode()) - - -def embedding_from_b64(data): - d = base64.b64decode(data) - return json.loads(d, cls=EmbeddingDecoder) - - -def lcg(m=2**32, a=1664525, c=1013904223, seed=0): - while True: - seed = (a * seed + c) % m - yield seed % 255 - - -def xor_block(block): - blk = lcg() - randblock = np.array([next(blk) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape) - return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F) - - -def style_block(block, sequence): - im = Image.new('RGB', (block.shape[1], block.shape[0])) - draw = ImageDraw.Draw(im) - i = 0 - for x in range(-6, im.size[0], 8): - for yi, y in enumerate(range(-6, im.size[1], 8)): - offset = 0 - if yi % 2 == 0: - offset = 4 - shade = sequence[i % len(sequence)] - i += 1 - draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade)) - - fg = np.array(im).astype(np.uint8) & 0xF0 - - return block ^ fg - - -def insert_image_data_embed(image, data): - d = 3 - data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9) - data_np_ = np.frombuffer(data_compressed, np.uint8).copy() - data_np_high = data_np_ >> 4 - data_np_low = data_np_ & 0x0F - - h = image.size[1] - next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h)) - next_size = next_size + ((h*d)-(next_size % (h*d))) - - data_np_low = np.resize(data_np_low, next_size) - data_np_low = data_np_low.reshape((h, -1, d)) - - data_np_high = np.resize(data_np_high, next_size) - data_np_high = data_np_high.reshape((h, -1, d)) - - edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] - edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8) - - data_np_low = style_block(data_np_low, sequence=edge_style) - data_np_low = xor_block(data_np_low) - data_np_high = style_block(data_np_high, sequence=edge_style[::-1]) - data_np_high = xor_block(data_np_high) - - im_low = Image.fromarray(data_np_low, mode='RGB') - im_high = Image.fromarray(data_np_high, mode='RGB') - - background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0)) - background.paste(im_low, (0, 0)) - background.paste(image, (im_low.size[0]+1, 0)) - background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0)) - - return background - - -def crop_black(img, tol=0): - mask = (img > tol).all(2) - mask0, mask1 = mask.any(0), mask.any(1) - col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax() - row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax() - return img[row_start:row_end, col_start:col_end] - - -def extract_image_data_embed(image): - d = 3 - outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F # pylint: disable=E1121 - black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0) - if black_cols[0].shape[0] < 2: - return None - - data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8) - data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8) - - data_block_lower = xor_block(data_block_lower) - data_block_upper = xor_block(data_block_upper) - - data_block = (data_block_upper << 4) | (data_block_lower) - data_block = data_block.flatten().tobytes() - - data = zlib.decompress(data_block) - return json.loads(data, cls=EmbeddingDecoder) - - -def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None): - from math import cos - image = srcimage.copy() - fontsize = 32 - if textfont is None: - textfont = opts.font or 'javascript/notosans-nerdfont-regular.ttf' - - factor = 1.5 - gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0)) - for y in range(image.size[1]): - mag = 1-cos(y/image.size[1]*factor) - mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1)) - gradient.putpixel((0, y), (0, 0, 0, int(mag*255))) - image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size)) - - draw = ImageDraw.Draw(image) - - font = ImageFont.truetype(textfont, fontsize) - padding = 10 - - _, _, w, _h = draw.textbbox((0, 0), title, font=font) - fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72) - font = ImageFont.truetype(textfont, fontsize) - _, _, w, _h = draw.textbbox((0, 0), title, font=font) - draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230)) - - _, _, w, _h = draw.textbbox((0, 0), footerLeft, font=font) - fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72) - _, _, w, _h = draw.textbbox((0, 0), footerMid, font=font) - fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72) - _, _, w, _h = draw.textbbox((0, 0), footerRight, font=font) - fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72) - - font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right)) - - draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230)) - draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230)) - draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230)) - - return image - - -if __name__ == '__main__': - - testEmbed = Image.open('test_embedding.png') - test_data = extract_image_data_embed(testEmbed) - assert test_data is not None - - test_data = embedding_from_b64(testEmbed.text['sd-ti-embedding']) - assert test_data is not None - - new_image = Image.new('RGBA', (512, 512), (255, 255, 200, 255)) - cap_image = caption_image_overlay(new_image, 'title', 'footerLeft', 'footerMid', 'footerRight') - - test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}} - - embedded_image = insert_image_data_embed(cap_image, test_embed) - - retrived_embed = extract_image_data_embed(embedded_image) - - assert str(retrived_embed) == str(test_embed) - - embedded_image2 = insert_image_data_embed(cap_image, retrived_embed) - - assert embedded_image == embedded_image2 - - g = lcg() - shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist() - - reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177, - 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179, - 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193, - 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28, - 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0, - 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185, - 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82, - 204, 86, 73, 222, 44, 198, 118, 240, 97] - - assert shared_random == reference_random - - hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist()) - - assert 12731374 == hunna_kay_random_sum diff --git a/modules/ui.py b/modules/ui.py index be52a04b0..f3f942636 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -67,10 +67,6 @@ def setup_progressbar(*args, **kwargs): # pylint: disable=unused-argument pass -def ordered_ui_categories(): - return ['dimensions', 'sampler', 'seed', 'denoising', 'cfg', 'checkboxes', 'accordions', 'override_settings', 'scripts'] # a1111 compatibility item, not implemented - - def create_ui(startup_timer = None): if startup_timer is None: timer.startup = timer.Timer() diff --git a/modules/ui_common.py b/modules/ui_common.py index 762b99701..44a4c816c 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -323,19 +323,6 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args = No return refresh_button -def create_browse_button(browse_component, elem_id): - def browse(folder): - # import subprocess - if folder is not None: - return gr.update(value = folder) - return gr.update() - - browse_button = ui_components.ToolButton(value=ui_symbols.folder, elem_id=elem_id) - browse_button.click(fn=browse, _js="async () => await browseFolder()", inputs=[browse_component], outputs=[browse_component]) - # browse_button.click(fn=browse, inputs=[browse_component], outputs=[browse_component]) - return browse_button - - def create_override_inputs(tab): # pylint: disable=unused-argument with gr.Row(elem_id=f"{tab}_override_settings_row"): override_settings = gr.Dropdown([], value=None, label="Override settings", visible=False, elem_id=f"{tab}_override_settings", multiselect=True) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 378f6a0db..bc00a82ab 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -69,10 +69,6 @@ def list_extensions(): debug(f'Extension installed without index: {entry}') -def check_access(): - assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags" - - def apply_changes(disable_list, update_list, disable_all): if shared.cmd_opts.disable_extension_access: shared.log.error('Extension: apply changes disallowed because public access is enabled and insecure is not specified') @@ -126,18 +122,6 @@ def check_updates(_id_task, disable_list, search_text, sort_column): return create_html(search_text, sort_column), "Extension update complete | Restart required" -def make_commit_link(commit_hash, remote, text=None): - if text is None: - text = commit_hash[:8] - if remote.startswith("https://github.com/"): - if remote.endswith(".git"): - remote = remote[:-4] - href = remote + "/commit/" + commit_hash - return f'{text}' - else: - return text - - def normalize_git_url(url): if url is None: return "" diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 43837ce2d..eef1c3a52 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -162,9 +162,6 @@ class ExtraNetworksPage: preview = f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}" return preview - def is_empty(self, folder): - return any(files_cache.list_files(folder, ext_filter=['.ckpt', '.safetensors', '.pt', '.json'])) - def create_thumb(self): debug(f'EN create-thumb: {self.name}') created = 0 diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index ea8067fea..55b5b01d1 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -1,7 +1,7 @@ import json import os from modules import shared, sd_models, ui_extra_networks, files_cache -from modules.textual_inversion.textual_inversion import Embedding +from modules.textual_inversion import Embedding class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): diff --git a/modules/ui_javascript.py b/modules/ui_javascript.py index 7bbe59988..6668d33c5 100644 --- a/modules/ui_javascript.py +++ b/modules/ui_javascript.py @@ -115,20 +115,5 @@ def reload_javascript(): gradio.routes.templates.TemplateResponse = template_response -def setup_ui_api(app): - from pydantic import BaseModel, Field # pylint: disable=no-name-in-module - from typing import List - - class QuicksettingsHint(BaseModel): # pylint: disable=too-few-public-methods - name: str = Field(title="Name of the quicksettings field") - label: str = Field(title="Label of the quicksettings field") - - def quicksettings_hint(): - return [QuicksettingsHint(name=k, label=v.label) for k, v in shared.opts.data_labels.items()] - - app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint]) - app.add_api_route("/internal/ping", lambda: {}, methods=["GET"]) - - if not hasattr(shared, 'GradioTemplateResponseOriginal'): shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse diff --git a/modules/ui_sections.py b/modules/ui_sections.py index 7edb31a1c..a0a1b07bd 100644 --- a/modules/ui_sections.py +++ b/modules/ui_sections.py @@ -99,20 +99,6 @@ def create_interrogate_button(tab: str, inputs: list = None, outputs: str = None return button_interrogate -def create_interrogate_buttons(tab): # legacy function - button_interrogate = gr.Button(ui_symbols.int_clip, elem_id=f"{tab}_interrogate", elem_classes=['interrogate-clip']) - button_deepbooru = gr.Button(ui_symbols.int_blip, elem_id=f"{tab}_deepbooru", elem_classes=['interrogate-blip']) - return button_interrogate, button_deepbooru - - -def create_sampler_inputs(tab, accordion=True): - with gr.Accordion(open=False, label="Sampler", elem_id=f"{tab}_sampler", elem_classes=["small-accordion"]) if accordion else gr.Group(): - with gr.Row(elem_id=f"{tab}_row_sampler"): - sd_samplers.set_samplers() - steps, sampler_index = create_sampler_and_steps_selection(sd_samplers.samplers, tab) - return steps, sampler_index - - def create_batch_inputs(tab, accordion=True): with gr.Accordion(open=False, label="Batch", elem_id=f"{tab}_batch", elem_classes=["small-accordion"]) if accordion else gr.Group(): with gr.Row(elem_id=f"{tab}_row_batch"): diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 7a7b75430..b53f7bb0e 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -48,10 +48,6 @@ def get_value_for_setting(key): return gr.update(value=value, **args) -def ordered_ui_categories(): - return ['dimensions', 'sampler', 'seed', 'denoising', 'cfg', 'checkboxes', 'accordions', 'override_settings', 'scripts'] # a1111 compatibility item, not implemented - - def create_setting_component(key, is_quicksettings=False): def fun(): return shared.opts.data[key] if key in shared.opts.data else shared.opts.data_labels[key].default @@ -88,7 +84,6 @@ def create_setting_component(key, is_quicksettings=False): elif info.folder is not None: with gr.Row(): res = comp(label=info.label, value=fun(), elem_id=elem_id, elem_classes="folder-selector", **args) - # ui_common.create_browse_button(res, f"folder_{key}") else: try: res = comp(label=info.label, value=fun(), elem_id=elem_id, **args) diff --git a/modules/ui_txt2img.py b/modules/ui_txt2img.py index a5358dd1f..ecc0678ae 100644 --- a/modules/ui_txt2img.py +++ b/modules/ui_txt2img.py @@ -1,19 +1,9 @@ import gradio as gr from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call -from modules import timer, shared, ui_common, ui_sections, generation_parameters_copypaste, processing, processing_vae, devices, images +from modules import timer, shared, ui_common, ui_sections, generation_parameters_copypaste, processing_vae, images from modules.ui_components import ToolButton # pylint: disable=unused-import -def calc_resolution_hires(width, height, hr_scale, hr_resize_x, hr_resize_y, hr_upscaler): - if hr_upscaler == "None": - return "Hires resize: None" - p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) - p.init_hr() - with devices.autocast(): - p.init([""], [0], [0]) - return f"Hires resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" - - def create_ui(): shared.log.debug('UI initialize: txt2img') import modules.txt2img # pylint: disable=redefined-outer-name diff --git a/modules/unipc/__init__.py b/modules/unipc/__init__.py deleted file mode 100644 index e1265e3fe..000000000 --- a/modules/unipc/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .sampler import UniPCSampler diff --git a/modules/unipc/sampler.py b/modules/unipc/sampler.py deleted file mode 100644 index b5e116d61..000000000 --- a/modules/unipc/sampler.py +++ /dev/null @@ -1,191 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch - -from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC, get_time_steps -from modules import shared, devices - - -class UniPCSampler(object): - def __init__(self, model, **kwargs): - super().__init__() - self.model = model - to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) - self.before_sample = None - self.after_sample = None - self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) - - self.noise_schedule = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod) - - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - # persist steps so we can eventually find denoising strength - self.inflated_steps = ddim_num_steps - - @devices.inference_context() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): - if noise is None: - noise = torch.randn_like(x0) - - # first time we have all the info to get the real parameters from the ui - # value from the hires steps slider: - num_inference_steps = t[0] + 1 - num_inference_steps / self.inflated_steps - self.denoise_steps = max(num_inference_steps, shared.opts.schedulers_solver_order) - - max(self.inflated_steps - self.denoise_steps, 0) - - # actual number of steps we'll run - - all_timesteps = get_time_steps( - self.noise_schedule, - shared.opts.uni_pc_skip_type, - self.noise_schedule.T, - 1./self.noise_schedule.total_N, - self.inflated_steps+1, - t.device, - ) - - # the rest of the timesteps will be used for denoising - self.timesteps = all_timesteps[-(self.denoise_steps+1):] - - latent_timestep = ( - ( # get the timestep of our first denoise step - self.timesteps[:1] - # multiply by number of alphas to get int index - * self.noise_schedule.total_N - ).int() - 1 # minus one for 0-indexed - ).repeat(x0.shape[0]) - - alphas_cumprod = self.alphas_cumprod - sqrt_alpha_prod = alphas_cumprod[latent_timestep] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(x0.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[latent_timestep]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(x0.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - return (sqrt_alpha_prod * x0 + sqrt_one_minus_alpha_prod * noise) - - def decode(self, x_latent, conditioning, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False, callback=None): - # same as in .sample(), i guess - model_type = "v" if self.model.parameterization == "v" else "noise" - - model_fn = model_wrapper( - lambda x, t, c: self.model.apply_model(x, t, c), - self.noise_schedule, - model_type=model_type, - guidance_type="classifier-free", - #condition=conditioning, - #unconditional_condition=unconditional_conditioning, - guidance_scale=unconditional_guidance_scale, - ) - - self.uni_pc = UniPC( - model_fn, - self.noise_schedule, - predict_x0=True, - thresholding=False, - variant=shared.opts.uni_pc_variant, - condition=conditioning, - unconditional_condition=unconditional_conditioning, - before_sample=self.before_sample, - after_sample=self.after_sample, - after_update=self.after_update, - ) - - return self.uni_pc.sample( - x_latent, - steps=self.denoise_steps, - skip_type=shared.opts.uni_pc_skip_type, - method="multistep", - order=shared.opts.schedulers_solver_order, - lower_order_final=shared.opts.schedulers_use_loworder, - denoise_to_zero=True, - timesteps=self.timesteps, - ) - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != devices.device: - attr = attr.to(devices.device) - setattr(self, name, attr) - - def set_hooks(self, before_sample, after_sample, after_update): - self.before_sample = before_sample - self.after_sample = after_sample - self.after_update = after_update - - @devices.inference_context() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): - ctmp = ctmp[0] - cbs = ctmp.shape[0] - if cbs != batch_size: - shared.log.warning(f"UniPC: got {cbs} conditionings but batch-size is {batch_size}") - - elif isinstance(conditioning, list): - for ctmp in conditioning: - if ctmp.shape[0] != batch_size: - shared.log.warning(f"UniPC: Got {cbs} conditionings but batch-size is {batch_size}") - - else: - if conditioning.shape[0] != batch_size: - shared.log.warning(f"UniPC: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - - device = self.model.betas.device - if x_T is None: - img = torch.randn(size, device=device) - else: - img = x_T - - # SD 1.X is "noise", SD 2.X is "v" - model_type = "v" if self.model.parameterization == "v" else "noise" - - model_fn = model_wrapper( - lambda x, t, c: self.model.apply_model(x, t, c), - self.noise_schedule, - model_type=model_type, - guidance_type="classifier-free", - #condition=conditioning, - #unconditional_condition=unconditional_conditioning, - guidance_scale=unconditional_guidance_scale, - ) - - uni_pc = UniPC(model_fn, self.noise_schedule, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) - x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) - - return x.to(device), None diff --git a/modules/unipc/uni_pc.py b/modules/unipc/uni_pc.py deleted file mode 100644 index 6ba3a31fa..000000000 --- a/modules/unipc/uni_pc.py +++ /dev/null @@ -1,779 +0,0 @@ -import torch -import math -import time -from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn -from modules import shared, devices - - -class NoiseScheduleVP: - def __init__( - self, - schedule='discrete', - betas=None, - alphas_cumprod=None, - continuous_beta_0=0.1, - continuous_beta_1=20., - ): - if schedule not in ['discrete', 'linear', 'cosine']: - raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'") - - self.schedule = schedule - if schedule == 'discrete': - if betas is not None: - log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) - else: - assert alphas_cumprod is not None - log_alphas = 0.5 * torch.log(alphas_cumprod) - self.total_N = len(log_alphas) - self.T = 1. - self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) - self.log_alpha_array = log_alphas.reshape((1, -1,)) - else: - self.total_N = 1000 - self.beta_0 = continuous_beta_0 - self.beta_1 = continuous_beta_1 - self.cosine_s = 0.008 - self.cosine_beta_max = 999. - self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s - self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) - self.schedule = schedule - if schedule == 'cosine': - # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. - # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. - self.T = 0.9946 - else: - self.T = 1. - - def marginal_log_mean_coeff(self, t): - """ - Compute log(alpha_t) of a given continuous-time label t in [0, T]. - """ - if self.schedule == 'discrete': - return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) - elif self.schedule == 'linear': - return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 - elif self.schedule == 'cosine': - log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) - log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 - return log_alpha_t - - def marginal_alpha(self, t): - """ - Compute alpha_t of a given continuous-time label t in [0, T]. - """ - return torch.exp(self.marginal_log_mean_coeff(t)) - - def marginal_std(self, t): - """ - Compute sigma_t of a given continuous-time label t in [0, T]. - """ - return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) - - def marginal_lambda(self, t): - """ - Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. - """ - log_mean_coeff = self.marginal_log_mean_coeff(t) - log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) - return log_mean_coeff - log_std - - def inverse_lambda(self, lamb): - """ - Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. - """ - if self.schedule == 'linear': - tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0**2 + tmp - return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) - elif self.schedule == 'discrete': - log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) - t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) - return t.reshape((-1,)) - else: - log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s - t = t_fn(log_alpha) - return t - - -def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs=None, - guidance_type="uncond", - #condition=None, - #unconditional_condition=None, - guidance_scale=1., - classifier_fn=None, - classifier_kwargs=None, -): - """Create a wrapper function for the noise prediction model. - - DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to - firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. - - We support four types of the diffusion model by setting `model_type`: - - 1. "noise": noise prediction model. (Trained by predicting noise). - - 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). - - 3. "v": velocity prediction model. (Trained by predicting the velocity). - The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. - - [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." - arXiv preprint arXiv:2202.00512 (2022). - [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." - arXiv preprint arXiv:2210.02303 (2022). - - 4. "score": marginal score function. (Trained by denoising score matching). - Note that the score function and the noise prediction model follows a simple relationship: - ``` - noise(x_t, t) = -sigma_t * score(x_t, t) - ``` - - We support three types of guided sampling by DPMs by setting `guidance_type`: - 1. "uncond": unconditional sampling by DPMs. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - - 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - - The input `classifier_fn` has the following format: - `` - classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) - `` - - [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," - in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. - - 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. - The input `model` has the following format: - `` - model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score - `` - And if cond == `unconditional_condition`, the model output is the unconditional DPM output. - - [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." - arXiv preprint arXiv:2207.12598 (2022). - - - The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) - or continuous-time labels (i.e. epsilon to T). - - We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: - `` - def model_fn(x, t_continuous) -> noise: - t_input = get_model_input_time(t_continuous) - return noise_pred(model, x, t_input, **model_kwargs) - `` - where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. - - =============================================================== - - Args: - model: A diffusion model with the corresponding format described above. - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - model_type: A `str`. The parameterization type of the diffusion model. - "noise" or "x_start" or "v" or "score". - model_kwargs: A `dict`. A dict for the other inputs of the model function. - guidance_type: A `str`. The type of the guidance for sampling. - "uncond" or "classifier" or "classifier-free". - condition: A pytorch tensor. The condition for the guided sampling. - Only used for "classifier" or "classifier-free" guidance type. - unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. - Only used for "classifier-free" guidance type. - guidance_scale: A `float`. The scale for the guided sampling. - classifier_fn: A classifier function. Only used for the classifier guidance. - classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. - Returns: - A noise prediction model that accepts the noised data and the continuous time as the inputs. - """ - model_kwargs = model_kwargs or {} - classifier_kwargs = classifier_kwargs or {} - - def get_model_input_time(t_continuous): - """ - Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. - For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. - For continuous-time DPMs, we just use `t_continuous`. - """ - if noise_schedule.schedule == 'discrete': - return (t_continuous - 1. / noise_schedule.total_N) * 1000. - else: - return t_continuous - - def noise_pred_fn(x, t_continuous, cond=None): - if t_continuous.reshape((-1,)).shape[0] == 1: - t_continuous = t_continuous.expand((x.shape[0])) - t_input = get_model_input_time(t_continuous) - if cond is None: - output = model(x, t_input, None, **model_kwargs) - else: - output = model(x, t_input, cond, **model_kwargs) - if model_type == "noise": - return output - elif model_type == "x_start": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) - elif model_type == "v": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x - elif model_type == "score": - sigma_t = noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return -expand_dims(sigma_t, dims) * output - - def cond_grad_fn(x, t_input, condition): - """ - Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). - """ - with torch.enable_grad(): - x_in = x.detach().requires_grad_(True) - log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) - return torch.autograd.grad(log_prob.sum(), x_in)[0] - - def model_fn(x, t_continuous, condition, unconditional_condition): - """ - The noise predicition model function that is used for DPM-Solver. - """ - if t_continuous.reshape((-1,)).shape[0] == 1: - t_continuous = t_continuous.expand((x.shape[0])) - if guidance_type == "uncond": - return noise_pred_fn(x, t_continuous) - elif guidance_type == "classifier": - assert classifier_fn is not None - t_input = get_model_input_time(t_continuous) - cond_grad = cond_grad_fn(x, t_input, condition) - sigma_t = noise_schedule.marginal_std(t_continuous) - noise = noise_pred_fn(x, t_continuous) - return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad - elif guidance_type == "classifier-free": - if guidance_scale == 1. or unconditional_condition is None: - return noise_pred_fn(x, t_continuous, cond=condition) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t_continuous] * 2) - if isinstance(condition, dict): - assert isinstance(unconditional_condition, dict) - c_in = {} - for k in condition: - if isinstance(condition[k], list): - c_in[k] = [torch.cat([ - unconditional_condition[k][i], - condition[k][i]]) for i in range(len(condition[k]))] - else: - c_in[k] = torch.cat([ - unconditional_condition[k], - condition[k]]) - elif isinstance(condition, list): - c_in = [] - assert isinstance(unconditional_condition, list) - for i in range(len(condition)): - c_in.append(torch.cat([unconditional_condition[i], condition[i]])) - else: - c_in = torch.cat([unconditional_condition, condition]) - noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) - return noise_uncond + guidance_scale * (noise - noise_uncond) - - assert model_type in ["noise", "x_start", "v"] - assert guidance_type in ["uncond", "classifier", "classifier-free"] - return model_fn - -def get_time_steps(noise_schedule, skip_type, t_T, t_0, N, device): - """Compute the intermediate time steps for sampling. - """ - if skip_type == 'logSNR': - lambda_T = noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) - lambda_0 = noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) - logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) - return noise_schedule.inverse_lambda(logSNR_steps) - elif skip_type == 'time_uniform': - return torch.linspace(t_T, t_0, N + 1).to(device) - elif skip_type == 'time_quadratic': - t_order = 2 - t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) - return t - else: - raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'") - -class UniPC: - def __init__( - self, - model_fn, - noise_schedule, - predict_x0=True, - thresholding=False, - max_val=1., - variant='bh1', - condition=None, - unconditional_condition=None, - before_sample=None, - after_sample=None, - after_update=None - ): - """Construct a UniPC. - - We support both data_prediction and noise_prediction. - """ - self.model_fn_ = model_fn - self.noise_schedule = noise_schedule - self.variant = variant - self.predict_x0 = predict_x0 - self.thresholding = thresholding - self.max_val = max_val - self.condition = condition - self.unconditional_condition = unconditional_condition - self.before_sample = before_sample - self.after_sample = after_sample - self.after_update = after_update - - def dynamic_thresholding_fn(self, x0, t=None): - """ - The dynamic thresholding method. - """ - dims = x0.dim() - p = self.dynamic_thresholding_ratio - s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) - s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) - x0 = torch.clamp(x0, -s, s) / s - return x0 - - def model(self, x, t): - cond = self.condition - uncond = self.unconditional_condition - if self.before_sample is not None: - x, t, cond, uncond = self.before_sample(x, t, cond, uncond) - res = self.model_fn_(x, t, cond, uncond) - if self.after_sample is not None: - x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res) - - if isinstance(res, tuple): - # (None, pred_x0) - res = res[1] - - return res - - def noise_prediction_fn(self, x, t): - """ - Return the noise prediction model. - """ - return self.model(x, t) - - def data_prediction_fn(self, x, t): - """ - Return the data prediction model (with thresholding). - """ - noise = self.noise_prediction_fn(x, t) - dims = x.dim() - alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) - if self.thresholding: - p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. - s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) - s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) - x0 = torch.clamp(x0, -s, s) / s - return x0 - - def model_fn(self, x, t): - """ - Convert the model to the noise prediction model or the data prediction model. - """ - if self.predict_x0: - return self.data_prediction_fn(x, t) - else: - return self.noise_prediction_fn(x, t) - - def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): - """ - Get the order of each step for sampling by the singlestep DPM-Solver. - """ - if order == 3: - K = steps // 3 + 1 - if steps % 3 == 0: - orders = [3,] * (K - 2) + [2, 1] - elif steps % 3 == 1: - orders = [3,] * (K - 1) + [1] - else: - orders = [3,] * (K - 1) + [2] - elif order == 2: - if steps % 2 == 0: - K = steps // 2 - orders = [2,] * K - else: - K = steps // 2 + 1 - orders = [2,] * (K - 1) + [1] - elif order == 1: - K = steps - orders = [1,] * steps - else: - raise ValueError("'order' must be '1' or '2' or '3'.") - if skip_type == 'logSNR': - # To reproduce the results in DPM-Solver paper - timesteps_outer = get_time_steps(self.noise_schedule, skip_type, t_T, t_0, K, device) - else: - timesteps_outer = get_time_steps(self.noise_schedule, skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] - return timesteps_outer, orders - - def denoise_to_zero_fn(self, x, s): - """ - Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. - """ - return self.data_prediction_fn(x, s) - - def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs): - if len(t.shape) == 0: - t = t.view(-1) - if 'bh' in self.variant: - return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs) - else: - assert self.variant == 'vary_coeff' - return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) - - def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): - ns = self.noise_schedule - assert order <= len(model_prev_list) - - # first compute rks - t_prev_0 = t_prev_list[-1] - lambda_prev_0 = ns.marginal_lambda(t_prev_0) - lambda_t = ns.marginal_lambda(t) - model_prev_0 = model_prev_list[-1] - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - log_alpha_t = ns.marginal_log_mean_coeff(t) - alpha_t = torch.exp(log_alpha_t) - - h = lambda_t - lambda_prev_0 - - rks = [] - D1s = [] - for i in range(1, order): - t_prev_i = t_prev_list[-(i + 1)] - model_prev_i = model_prev_list[-(i + 1)] - lambda_prev_i = ns.marginal_lambda(t_prev_i) - rk = (lambda_prev_i - lambda_prev_0) / h - rks.append(rk) - D1s.append((model_prev_i - model_prev_0) / rk) - - rks.append(1.) - rks = torch.tensor(rks, device=x.device) - - K = len(rks) - # build C matrix - C = [] - - col = torch.ones_like(rks) - for k in range(1, K + 1): - C.append(col) - col = col * rks / (k + 1) - C = torch.stack(C, dim=1) - - if len(D1s) > 0: - D1s = torch.stack(D1s, dim=1) # (B, K) - C_inv_p = torch.linalg.inv(C[:-1, :-1]) - A_p = C_inv_p - - if use_corrector: - C_inv = torch.linalg.inv(C) - A_c = C_inv - - hh = -h if self.predict_x0 else h - h_phi_1 = torch.expm1(hh) - h_phi_ks = [] - factorial_k = 1 - h_phi_k = h_phi_1 - for k in range(1, K + 2): - h_phi_ks.append(h_phi_k) - h_phi_k = h_phi_k / hh - 1 / factorial_k - factorial_k *= (k + 1) - - model_t = None - if self.predict_x0: - x_t_ = ( - sigma_t / sigma_prev_0 * x - - alpha_t * h_phi_1 * model_prev_0 - ) - # now predictor - x_t = x_t_ - if len(D1s) > 0: - # compute the residuals for predictor - for k in range(K - 1): - x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) - # now corrector - if use_corrector: - model_t = self.model_fn(x_t, t) - D1_t = (model_t - model_prev_0) - x_t = x_t_ - k = 0 - for k in range(K - 1): - x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) - x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) - else: - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - x_t_ = ( - (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - - (sigma_t * h_phi_1) * model_prev_0 - ) - # now predictor - x_t = x_t_ - if len(D1s) > 0: - # compute the residuals for predictor - for k in range(K - 1): - x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) - # now corrector - if use_corrector: - model_t = self.model_fn(x_t, t) - D1_t = (model_t - model_prev_0) - x_t = x_t_ - k = 0 - for k in range(K - 1): - x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) - x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) - return x_t, model_t - - def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True): - ns = self.noise_schedule - assert order <= len(model_prev_list) - dims = x.dim() - - # first compute rks - t_prev_0 = t_prev_list[-1] - lambda_prev_0 = ns.marginal_lambda(t_prev_0) - lambda_t = ns.marginal_lambda(t) - model_prev_0 = model_prev_list[-1] - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - alpha_t = torch.exp(log_alpha_t) - - h = lambda_t - lambda_prev_0 - - rks = [] - D1s = [] - for i in range(1, order): - t_prev_i = t_prev_list[-(i + 1)] - model_prev_i = model_prev_list[-(i + 1)] - lambda_prev_i = ns.marginal_lambda(t_prev_i) - rk = ((lambda_prev_i - lambda_prev_0) / h)[0] - rks.append(rk) - D1s.append((model_prev_i - model_prev_0) / rk) - - rks.append(1.) - rks = torch.tensor(rks, device=x.device) - - R = [] - b = [] - - hh = -h[0] if self.predict_x0 else h[0] - h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 - h_phi_k = h_phi_1 / hh - 1 - - factorial_i = 1 - - if self.variant == 'bh1': - B_h = hh - elif self.variant == 'bh2': - B_h = torch.expm1(hh) - else: - raise NotImplementedError - - for i in range(1, order + 1): - R.append(torch.pow(rks, i - 1)) - b.append(h_phi_k * factorial_i / B_h) - factorial_i *= (i + 1) - h_phi_k = h_phi_k / hh - 1 / factorial_i - - R = torch.stack(R) - b = torch.tensor(b, device=x.device) - - # now predictor - use_predictor = len(D1s) > 0 and x_t is None - if len(D1s) > 0: - D1s = torch.stack(D1s, dim=1) # (B, K) - if x_t is None: - # for order 2, we use a simplified version - if order == 2: - rhos_p = torch.tensor([0.5], device=b.device) - else: - rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) - else: - D1s = None - - if use_corrector: - # for order 1, we use a simplified version - if order == 1: - rhos_c = torch.tensor([0.5], device=b.device) - else: - rhos_c = torch.linalg.solve(R, b) - - model_t = None - if self.predict_x0: - x_t_ = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0 - ) - - if x_t is None: - if use_predictor: - pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) - else: - pred_res = 0 - x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res - - if use_corrector: - model_t = self.model_fn(x_t, t) - if D1s is not None: - corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) - else: - corr_res = 0 - D1_t = (model_t - model_prev_0) - x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t) - else: - x_t_ = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0 - ) - if x_t is None: - if use_predictor: - pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) - else: - pred_res = 0 - x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res - - if use_corrector: - model_t = self.model_fn(x_t, t) - if D1s is not None: - corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) - else: - corr_res = 0 - D1_t = (model_t - model_prev_0) - x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t) - return x_t, model_t - - - def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', - method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, rtol=0.05, corrector=False, timesteps=None, - ): - t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end - t_T = self.noise_schedule.T if t_start is None else t_start - device = x.device - if method == 'multistep': - if timesteps is None: - timesteps = get_time_steps(self.noise_schedule, skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) - assert steps >= order, "UniPC order must be < sampling steps" - assert timesteps.shape[0] - 1 == steps - with Progress(TextColumn('[cyan]{task.description}'), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), TimeElapsedColumn(), console=shared.console) as progress: - task = progress.add_task(description="Initializing", total=steps) - t = time.time() - with devices.inference_context(): - vec_t = timesteps[0].expand((x.shape[0])) - model_prev_list = [self.model_fn(x, vec_t)] - t_prev_list = [vec_t] - # Init the first `order` values by lower order multistep DPM-Solver. - for init_order in range(1, order): - vec_t = timesteps[init_order].expand(x.shape[0]) - x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) - if model_x is None: - model_x = self.model_fn(x, vec_t) - if self.after_update is not None: - self.after_update(x, model_x) - model_prev_list.append(model_x) - t_prev_list.append(vec_t) - progress.update(task, advance=1, description=f"Progress {round(len(vec_t) * init_order / (time.time() - t), 2)}it/s") - # for step in trange(order, steps + 1): - for step in range(order, steps + 1): - vec_t = timesteps[step].expand(x.shape[0]) - if lower_order_final: - step_order = min(order, steps + 1 - step) - else: - step_order = order - if step == steps: - use_corrector = False - else: - use_corrector = True - x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) - if self.after_update is not None: - self.after_update(x, model_x) - for i in range(order - 1): - t_prev_list[i] = t_prev_list[i + 1] - model_prev_list[i] = model_prev_list[i + 1] - t_prev_list[-1] = vec_t - # We do not need to evaluate the final model value. - if step < steps: - if model_x is None: - model_x = self.model_fn(x, vec_t) - model_prev_list[-1] = model_x - progress.update(task, advance=1, description=f"Progress {round(len(vec_t) * step / (time.time() - t), 2)}it/s") - else: - raise NotImplementedError - if denoise_to_zero: - x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) - return x - - -############################################################# -# other utility functions -############################################################# - -def interpolate_fn(x, xp, yp): - """ - A piecewise linear function y = f(x), using xp and yp as keypoints. - We implement f(x) in a differentiable way (i.e. applicable for autograd). - The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) - - Args: - x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). - xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. - yp: PyTorch tensor with shape [C, K]. - Returns: - The function values f(x), with shape [N, C]. - """ - N, K = x.shape[0], xp.shape[1] - all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) - sorted_all_x, x_indices = torch.sort(all_x, dim=2) - x_idx = torch.argmin(x_indices, dim=2) - cand_start_idx = x_idx - 1 - start_idx = torch.where( - torch.eq(x_idx, 0), - torch.tensor(1, device=x.device), - torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, - ), - ) - end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) - start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) - end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) - start_idx2 = torch.where( - torch.eq(x_idx, 0), - torch.tensor(0, device=x.device), - torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, - ), - ) - y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) - start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) - end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) - cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) - return cand - - -def expand_dims(v, dims): - """ - Expand the tensor `v` to the dim `dims`. - - Args: - `v`: a PyTorch tensor with shape [N]. - `dim`: a `int`. - Returns: - a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. - """ - return v[(...,) + (None,)*(dims - 1)] diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index 05f007ba0..1fbf95768 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -2,40 +2,10 @@ import math import gradio as gr from modules import images, scripts_manager from modules.processing import process_images -from modules.shared import opts, state, log +from modules.shared import opts, log import modules.sd_samplers -def draw_xy_grid(xs, ys, x_label, y_label, cell): - res = [] - - ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys] - hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs] - - first_processed = None - - state.job_count = len(xs) * len(ys) - - for iy, y in enumerate(ys): - for ix, x in enumerate(xs): - state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" - - processed, _t = cell(x, y) - if first_processed is None: - first_processed = processed - - res.append(processed.images[0]) - - if images.check_grid_size(res): - grid = images.image_grid(res, rows=len(ys)) - grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) - first_processed.images = [grid] - else: - first_processed.images = res - - return first_processed - - class Script(scripts_manager.Script): def title(self): return "Prompt matrix" diff --git a/webui.py b/webui.py index 12d1b285a..236b220f2 100644 --- a/webui.py +++ b/webui.py @@ -31,7 +31,7 @@ import modules.upscaler import modules.upscaler_simple import modules.extra_networks import modules.ui_extra_networks -import modules.textual_inversion.textual_inversion +import modules.textual_inversion import modules.script_callbacks import modules.api.middleware