diff --git a/.pylintrc b/.pylintrc
index 0fb6ba8dc..b2c67d5b2 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -33,7 +33,6 @@ ignore-paths=/usr/lib/.*$,
modules/taesd,
modules/teacache,
modules/todo,
- modules/unipc,
pipelines/flex2,
pipelines/hidream,
pipelines/meissonic,
diff --git a/modules/devices.py b/modules/devices.py
index 8d15fd238..83ec573e4 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -159,13 +159,6 @@ def get_gpu_info():
return { 'error': ex }
-def extract_device_id(args, name): # pylint: disable=redefined-outer-name
- for x in range(len(args)):
- if name in args[x]:
- return args[x + 1]
- return None
-
-
def get_cuda_device_string():
from modules.shared import cmd_opts
if backend == 'ipex':
@@ -196,13 +189,6 @@ def get_optimal_device():
return torch.device(get_optimal_device_name())
-def get_device_for(task): # pylint: disable=unused-argument
- # if task in cmd_opts.use_cpu:
- # log.debug(f'Forcing CPU for task: {task}')
- # return cpu
- return get_optimal_device()
-
-
def torch_gc(force:bool=False, fast:bool=False, reason:str=None):
def get_stats():
mem_dict = memstats.memory_stats()
@@ -583,14 +569,6 @@ def set_cuda_params():
log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upcast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} tunable={tunable} fp16={"pass" if fp16_ok else "fail"} bf16={"pass" if bf16_ok else "fail"} optimization="{opts.cross_attention_optimization}"')
-def cond_cast_unet(tensor):
- return tensor.to(dtype_unet) if unet_needs_upcast else tensor
-
-
-def cond_cast_float(tensor):
- return tensor.float() if unet_needs_upcast else tensor
-
-
def randn(seed, shape=None):
torch.manual_seed(seed)
if backend == 'ipex':
diff --git a/modules/errors.py b/modules/errors.py
index b505a1d0e..81cfe9379 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -15,12 +15,6 @@ def install(suppress=[]):
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s')
-def print_error_explanation(message):
- lines = message.strip().split("\n")
- for line in lines:
- log.error(line)
-
-
def display(e: Exception, task: str, suppress=[]):
log.error(f"{task or 'error'}: {type(e).__name__}")
console = get_console()
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 6393c9051..75893417e 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -138,15 +138,6 @@ def should_skip(param):
return skip
-def bind_buttons(buttons, image_component, send_generate_info):
- """old function for backwards compatibility; do not use this, use register_paste_params_button"""
- for tabname, button in buttons.items():
- source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
- source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
- bindings = ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=image_component, source_tabname=source_tabname)
- register_paste_params_button(bindings)
-
-
def register_paste_params_button(binding: ParamBinding):
registered_param_bindings.append(binding)
@@ -194,17 +185,6 @@ def send_image(x):
return image
-def send_image_and_dimensions(x):
- image = x if isinstance(x, Image.Image) else image_from_url_text(x)
- if shared.opts.send_size and isinstance(image, Image.Image):
- w = image.width
- h = image.height
- else:
- w = gr.update()
- h = gr.update()
- return image, w, h
-
-
def create_override_settings_dict(text_pairs):
res = {}
params = {}
diff --git a/modules/infotext.py b/modules/infotext.py
index 50d443c19..6eef31897 100644
--- a/modules/infotext.py
+++ b/modules/infotext.py
@@ -28,40 +28,6 @@ def unquote(text):
return text
-# disabled by default can be enabled if needed
-def check_lora(params):
- try:
- from modules.lora import lora_load
- from modules.errors import log # pylint: disable=redefined-outer-name
- except Exception:
- return
- loras = [s.strip() for s in params.get('LoRA hashes', '').split(',')]
- found = []
- missing = []
- for l in loras:
- lora = lora_load.available_network_hash_lookup.get(l, None)
- if lora is not None:
- found.append(lora.name)
- else:
- missing.append(l)
- loras = [s.strip() for s in params.get('LoRA networks', '').split(',')]
- for l in loras:
- lora = lora_load.available_network_aliases.get(l, None)
- if lora is not None:
- found.append(lora.name)
- else:
- missing.append(l)
- # networks.available_network_aliases.get(name, None)
- loras = re_lora.findall(params.get('Prompt', ''))
- for l in loras:
- lora = lora_load.available_network_aliases.get(l, None)
- if lora is not None:
- found.append(lora.name)
- else:
- missing.append(l)
- log.debug(f'LoRA: found={list(set(found))} missing={list(set(missing))}')
-
-
def parse(infotext):
if not isinstance(infotext, str):
return {}
@@ -115,7 +81,6 @@ def parse(infotext):
params[key] = val
debug(f'Param parsed: type={type(params[key])} "{key}"={params[key]} raw="{val}"')
- # check_lora(params)
return params
diff --git a/modules/memstats.py b/modules/memstats.py
index f226f3c83..5af35d6be 100644
--- a/modules/memstats.py
+++ b/modules/memstats.py
@@ -84,10 +84,6 @@ def reset_stats():
pass
-def memory_cache():
- return mem
-
-
def ram_stats():
try:
process = psutil.Process(os.getpid())
diff --git a/modules/sd_checkpoint.py b/modules/sd_checkpoint.py
index 008da4144..162f9db22 100644
--- a/modules/sd_checkpoint.py
+++ b/modules/sd_checkpoint.py
@@ -160,11 +160,8 @@ def list_models():
def update_model_hashes():
txt = []
lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.hash is None]
- # shared.log.info(f'Models list: short hash missing for {len(lst)} out of {len(checkpoints_list)} models')
for ckpt in lst:
ckpt.hash = model_hash(ckpt.filename)
- # txt.append(f'Calculated short hash: {ckpt.title} {ckpt.hash}')
- # txt.append(f'Updated short hashes for {len(lst)} out of {len(checkpoints_list)} models')
lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.sha256 is None or ckpt.shorthash is None]
shared.log.info(f'Models list: hash missing={len(lst)} total={len(checkpoints_list)}')
for ckpt in lst:
@@ -214,15 +211,6 @@ def get_closet_checkpoint_match(s: str) -> CheckpointInfo:
checkpoint_info = CheckpointInfo(s)
return checkpoint_info
- # reference search
- """
- found = sorted([info for info in shared.reference_models.values() if os.path.basename(info['path']).lower().startswith(s.lower())], key=lambda x: len(x['path']))
- if found and len(found) == 1:
- checkpoint_info = CheckpointInfo(found[0]['path']) # create a virutal model info
- checkpoint_info.type = 'huggingface'
- return checkpoint_info
- """
-
# huggingface search
if shared.opts.sd_checkpoint_autodownload and s.count('/') == 1:
modelloader.hf_login()
@@ -251,13 +239,10 @@ def model_hash(filename):
try:
with open(filename, "rb") as file:
import hashlib
- # t0 = time.time()
m = hashlib.sha256()
file.seek(0x100000)
m.update(file.read(0x10000))
shorthash = m.hexdigest()[0:8]
- # t1 = time.time()
- # shared.log.debug(f'Calculating short hash: {filename} hash={shorthash} time={(t1-t0):.2f}')
return shorthash
except FileNotFoundError:
return 'NOFILE'
@@ -280,14 +265,11 @@ def select_checkpoint(op='model'):
shared.log.info(" or use --ckpt-dir to specify folder with sd models")
shared.log.info(" or use --ckpt to force using specific model")
return None
- # checkpoint_info = next(iter(checkpoints_list.values()))
if model_checkpoint is not None:
if model_checkpoint != 'model.safetensors' and model_checkpoint != 'stabilityai/stable-diffusion-xl-base-1.0':
shared.log.info(f'Load {op}: search="{model_checkpoint}" not found')
else:
shared.log.info("Selecting first available checkpoint")
- # shared.log.warning(f"Loading fallback checkpoint: {checkpoint_info.title}")
- # shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
else:
shared.log.info(f'Load {op}: select="{checkpoint_info.title if checkpoint_info is not None else None}"')
return checkpoint_info
@@ -367,8 +349,6 @@ def read_metadata_from_safetensors(filename):
t1 = time.time()
global sd_metadata_timer # pylint: disable=global-statement
sd_metadata_timer += (t1 - t0)
- # except Exception as e:
- # shared.log.error(f"Error reading metadata from: {filename} {e}")
return res
diff --git a/modules/sd_models.py b/modules/sd_models.py
index cf57e4c5d..011d75c6f 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -620,7 +620,7 @@ def load_diffuser(checkpoint_info=None, timer=None, op='model', revision=None):
if debug_load:
shared.log.trace(f'Model components: {list(get_signature(sd_model).values())}')
- from modules.textual_inversion import textual_inversion
+ from modules import textual_inversion
sd_model.embedding_db = textual_inversion.EmbeddingDatabase()
sd_model.embedding_db.add_embedding_dir(shared.opts.embeddings_dir)
sd_model.embedding_db.load_textual_inversion_embeddings(force_reload=True)
diff --git a/modules/sd_samplers_diffusers.py b/modules/sd_samplers_diffusers.py
index 8df2d0dbd..b1f393413 100644
--- a/modules/sd_samplers_diffusers.py
+++ b/modules/sd_samplers_diffusers.py
@@ -319,9 +319,6 @@ class DiffusionSampler:
return
self.sampler = sampler
- if name == 'DC Solver':
- if not hasattr(self.sampler, 'dc_ratios'):
- pass
# shared.log.debug_log(f'Sampler: class="{self.sampler.__class__.__name__}" config={self.sampler.config}')
self.sampler.name = name
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 88eebcdb0..b6c3982e9 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,6 +1,5 @@
import os
import glob
-from copy import deepcopy
import torch
from modules import shared, errors, paths, devices, sd_models, sd_detect
@@ -12,35 +11,13 @@ loaded_vae_file = None
checkpoint_info = None
vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE'))
debug = os.environ.get('SD_LOAD_DEBUG', None) is not None
+unspecified = object()
-def get_base_vae(model):
- if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
- return base_vae
- return None
-
-
-def store_base_vae(model):
- global base_vae, checkpoint_info # pylint: disable=global-statement
- if checkpoint_info != model.sd_checkpoint_info:
- assert not loaded_vae_file, "Trying to store non-base VAE!"
- base_vae = deepcopy(model.first_stage_model.state_dict())
- checkpoint_info = model.sd_checkpoint_info
-
-
-def delete_base_vae():
- global base_vae, checkpoint_info # pylint: disable=global-statement
- base_vae = None
- checkpoint_info = None
-
-
-def restore_base_vae(model):
- global loaded_vae_file # pylint: disable=global-statement
- if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
- shared.log.info("Restoring base VAE")
- _load_vae_dict(model, base_vae)
- loaded_vae_file = None
- delete_base_vae()
+def load_vae_dict(filename):
+ vae_ckpt = sd_models.read_state_dict(filename, what='vae')
+ vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ return vae_dict_1
def get_filename(filepath):
@@ -114,35 +91,6 @@ def resolve_vae(checkpoint_file):
return None, None
-def load_vae_dict(filename):
- vae_ckpt = sd_models.read_state_dict(filename, what='vae')
- vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
- return vae_dict_1
-
-
-def load_vae(model, vae_file=None, vae_source="unknown-source"):
- global loaded_vae_file # pylint: disable=global-statement
- if vae_file:
- try:
- if not os.path.isfile(vae_file):
- shared.log.error(f"VAE not found: model={vae_file} source={vae_source}")
- return
- store_base_vae(model)
- vae_dict_1 = load_vae_dict(vae_file)
- _load_vae_dict(model, vae_dict_1)
- except Exception as e:
- shared.log.error(f"Load VAE failed: model={vae_file} source={vae_source} {e}")
- if debug:
- errors.display(e, 'VAE')
- restore_base_vae(model)
- vae_opt = get_filename(vae_file)
- if vae_opt not in vae_dict:
- vae_dict[vae_opt] = vae_file
- elif loaded_vae_file:
- restore_base_vae(model)
- loaded_vae_file = vae_file
-
-
def apply_vae_config(model_file, vae_file, sd_model):
def get_vae_config():
config_file = os.path.join(paths.sd_configs_path, os.path.splitext(os.path.basename(model_file))[0] + '_vae.json')
@@ -219,20 +167,6 @@ def load_vae_diffusers(model_file, vae_file=None, vae_source="unknown-source"):
return None
-# don't call this from outside
-def _load_vae_dict(model, vae_dict_1):
- model.first_stage_model.load_state_dict(vae_dict_1)
- model.first_stage_model.to(devices.dtype_vae)
-
-
-def clear_loaded_vae():
- global loaded_vae_file # pylint: disable=global-statement
- loaded_vae_file = None
-
-
-unspecified = object()
-
-
def reload_vae_weights(sd_model=None, vae_file=unspecified):
if not sd_model:
sd_model = shared.sd_model
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion.py
similarity index 86%
rename from modules/textual_inversion/textual_inversion.py
rename to modules/textual_inversion.py
index c24f6cbfd..f6b1c558d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion.py
@@ -3,9 +3,7 @@ import os
import time
import torch
import safetensors.torch
-from PIL import Image
from modules import shared, devices, errors
-from modules.textual_inversion.image_embedding import embedding_from_b64, extract_image_data_embed
from modules.files_cache import directory_files, directory_mtime, extension_filter
@@ -241,9 +239,6 @@ class EmbeddingDatabase:
def add_embedding_dir(self, path):
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
- def clear_embedding_dirs(self):
- self.embedding_dirs.clear()
-
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
if hasattr(model, 'cond_stage_model'):
@@ -306,45 +301,6 @@ class EmbeddingDatabase:
errors.display(e, f'Load embedding: name="{embedding.name}" file="{embedding.filename}"')
return
- def load_from_file(self, path, filename):
- name, ext = os.path.splitext(filename)
- ext = ext.upper()
-
- if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
- if '.preview' in filename.lower():
- return
- embed_image = Image.open(path)
- if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
- data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
- else:
- data = extract_image_data_embed(embed_image)
- if not data: # if data is None, means this is not an embeding, just a preview image
- return
- elif ext in ['.BIN', '.PT']:
- data = torch.load(path, map_location="cpu")
- elif ext in ['.SAFETENSORS']:
- data = safetensors.torch.load_file(path, device="cpu")
- else:
- return
-
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
- if len(data.keys()) != 1:
- self.skipped_embeddings[name] = Embedding(None, name=name, filename=path)
- return
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
- else:
- raise RuntimeError(f"Couldn't identify {filename} as textual inversion embedding")
-
-
def load_from_dir(self, embdir):
if not shared.sd_loaded:
shared.log.info('Skipping embeddings load: model not loaded')
@@ -387,14 +343,3 @@ class EmbeddingDatabase:
self.previously_displayed_embeddings = displayed_embeddings
t1 = time.time()
shared.log.info(f"Network load: type=embeddings loaded={len(self.word_embeddings)} skipped={len(self.skipped_embeddings)} time={t1-t0:.2f}")
-
-
- def find_embedding_at_position(self, tokens, offset):
- token = tokens[offset]
- possible_matches = self.ids_lookup.get(token, None)
- if possible_matches is None:
- return None, None
- for ids, embedding in possible_matches:
- if tokens[offset:offset + len(ids)] == ids:
- return embedding, len(ids)
- return None, None
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
deleted file mode 100644
index 74eb88fd2..000000000
--- a/modules/textual_inversion/image_embedding.py
+++ /dev/null
@@ -1,213 +0,0 @@
-import base64
-import json
-import zlib
-import numpy as np
-from PIL import Image, ImageDraw, ImageFont
-import torch
-from modules.shared import opts
-
-
-class EmbeddingEncoder(json.JSONEncoder):
- def default(self, o):
- if isinstance(o, torch.Tensor):
- return {'TORCHTENSOR': o.cpu().detach().numpy().tolist()}
- return json.JSONEncoder.default(self, o)
-
-
-class EmbeddingDecoder(json.JSONDecoder):
- def __init__(self, *args, **kwargs):
- json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
-
- def object_hook(self, d): # pylint: disable=E0202
- if 'TORCHTENSOR' in d:
- return torch.from_numpy(np.array(d['TORCHTENSOR']))
- return d
-
-
-def embedding_to_b64(data):
- d = json.dumps(data, cls=EmbeddingEncoder)
- return base64.b64encode(d.encode())
-
-
-def embedding_from_b64(data):
- d = base64.b64decode(data)
- return json.loads(d, cls=EmbeddingDecoder)
-
-
-def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
- while True:
- seed = (a * seed + c) % m
- yield seed % 255
-
-
-def xor_block(block):
- blk = lcg()
- randblock = np.array([next(blk) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)
- return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
-
-
-def style_block(block, sequence):
- im = Image.new('RGB', (block.shape[1], block.shape[0]))
- draw = ImageDraw.Draw(im)
- i = 0
- for x in range(-6, im.size[0], 8):
- for yi, y in enumerate(range(-6, im.size[1], 8)):
- offset = 0
- if yi % 2 == 0:
- offset = 4
- shade = sequence[i % len(sequence)]
- i += 1
- draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
-
- fg = np.array(im).astype(np.uint8) & 0xF0
-
- return block ^ fg
-
-
-def insert_image_data_embed(image, data):
- d = 3
- data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
- data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
- data_np_high = data_np_ >> 4
- data_np_low = data_np_ & 0x0F
-
- h = image.size[1]
- next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
- next_size = next_size + ((h*d)-(next_size % (h*d)))
-
- data_np_low = np.resize(data_np_low, next_size)
- data_np_low = data_np_low.reshape((h, -1, d))
-
- data_np_high = np.resize(data_np_high, next_size)
- data_np_high = data_np_high.reshape((h, -1, d))
-
- edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
- edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
-
- data_np_low = style_block(data_np_low, sequence=edge_style)
- data_np_low = xor_block(data_np_low)
- data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
- data_np_high = xor_block(data_np_high)
-
- im_low = Image.fromarray(data_np_low, mode='RGB')
- im_high = Image.fromarray(data_np_high, mode='RGB')
-
- background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
- background.paste(im_low, (0, 0))
- background.paste(image, (im_low.size[0]+1, 0))
- background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
-
- return background
-
-
-def crop_black(img, tol=0):
- mask = (img > tol).all(2)
- mask0, mask1 = mask.any(0), mask.any(1)
- col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
- row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
- return img[row_start:row_end, col_start:col_end]
-
-
-def extract_image_data_embed(image):
- d = 3
- outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F # pylint: disable=E1121
- black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
- if black_cols[0].shape[0] < 2:
- return None
-
- data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
- data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
-
- data_block_lower = xor_block(data_block_lower)
- data_block_upper = xor_block(data_block_upper)
-
- data_block = (data_block_upper << 4) | (data_block_lower)
- data_block = data_block.flatten().tobytes()
-
- data = zlib.decompress(data_block)
- return json.loads(data, cls=EmbeddingDecoder)
-
-
-def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
- from math import cos
- image = srcimage.copy()
- fontsize = 32
- if textfont is None:
- textfont = opts.font or 'javascript/notosans-nerdfont-regular.ttf'
-
- factor = 1.5
- gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
- for y in range(image.size[1]):
- mag = 1-cos(y/image.size[1]*factor)
- mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
- gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
- image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
-
- draw = ImageDraw.Draw(image)
-
- font = ImageFont.truetype(textfont, fontsize)
- padding = 10
-
- _, _, w, _h = draw.textbbox((0, 0), title, font=font)
- fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
- font = ImageFont.truetype(textfont, fontsize)
- _, _, w, _h = draw.textbbox((0, 0), title, font=font)
- draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
-
- _, _, w, _h = draw.textbbox((0, 0), footerLeft, font=font)
- fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
- _, _, w, _h = draw.textbbox((0, 0), footerMid, font=font)
- fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
- _, _, w, _h = draw.textbbox((0, 0), footerRight, font=font)
- fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
-
- font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
-
- draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
- draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
- draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
-
- return image
-
-
-if __name__ == '__main__':
-
- testEmbed = Image.open('test_embedding.png')
- test_data = extract_image_data_embed(testEmbed)
- assert test_data is not None
-
- test_data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
- assert test_data is not None
-
- new_image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
- cap_image = caption_image_overlay(new_image, 'title', 'footerLeft', 'footerMid', 'footerRight')
-
- test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
-
- embedded_image = insert_image_data_embed(cap_image, test_embed)
-
- retrived_embed = extract_image_data_embed(embedded_image)
-
- assert str(retrived_embed) == str(test_embed)
-
- embedded_image2 = insert_image_data_embed(cap_image, retrived_embed)
-
- assert embedded_image == embedded_image2
-
- g = lcg()
- shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
-
- reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
- 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
- 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
- 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
- 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
- 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
- 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
- 204, 86, 73, 222, 44, 198, 118, 240, 97]
-
- assert shared_random == reference_random
-
- hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
-
- assert 12731374 == hunna_kay_random_sum
diff --git a/modules/ui.py b/modules/ui.py
index be52a04b0..f3f942636 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -67,10 +67,6 @@ def setup_progressbar(*args, **kwargs): # pylint: disable=unused-argument
pass
-def ordered_ui_categories():
- return ['dimensions', 'sampler', 'seed', 'denoising', 'cfg', 'checkboxes', 'accordions', 'override_settings', 'scripts'] # a1111 compatibility item, not implemented
-
-
def create_ui(startup_timer = None):
if startup_timer is None:
timer.startup = timer.Timer()
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 762b99701..44a4c816c 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -323,19 +323,6 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args = No
return refresh_button
-def create_browse_button(browse_component, elem_id):
- def browse(folder):
- # import subprocess
- if folder is not None:
- return gr.update(value = folder)
- return gr.update()
-
- browse_button = ui_components.ToolButton(value=ui_symbols.folder, elem_id=elem_id)
- browse_button.click(fn=browse, _js="async () => await browseFolder()", inputs=[browse_component], outputs=[browse_component])
- # browse_button.click(fn=browse, inputs=[browse_component], outputs=[browse_component])
- return browse_button
-
-
def create_override_inputs(tab): # pylint: disable=unused-argument
with gr.Row(elem_id=f"{tab}_override_settings_row"):
override_settings = gr.Dropdown([], value=None, label="Override settings", visible=False, elem_id=f"{tab}_override_settings", multiselect=True)
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 378f6a0db..bc00a82ab 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -69,10 +69,6 @@ def list_extensions():
debug(f'Extension installed without index: {entry}')
-def check_access():
- assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
-
-
def apply_changes(disable_list, update_list, disable_all):
if shared.cmd_opts.disable_extension_access:
shared.log.error('Extension: apply changes disallowed because public access is enabled and insecure is not specified')
@@ -126,18 +122,6 @@ def check_updates(_id_task, disable_list, search_text, sort_column):
return create_html(search_text, sort_column), "Extension update complete | Restart required"
-def make_commit_link(commit_hash, remote, text=None):
- if text is None:
- text = commit_hash[:8]
- if remote.startswith("https://github.com/"):
- if remote.endswith(".git"):
- remote = remote[:-4]
- href = remote + "/commit/" + commit_hash
- return f'{text}'
- else:
- return text
-
-
def normalize_git_url(url):
if url is None:
return ""
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 43837ce2d..eef1c3a52 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -162,9 +162,6 @@ class ExtraNetworksPage:
preview = f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
return preview
- def is_empty(self, folder):
- return any(files_cache.list_files(folder, ext_filter=['.ckpt', '.safetensors', '.pt', '.json']))
-
def create_thumb(self):
debug(f'EN create-thumb: {self.name}')
created = 0
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index ea8067fea..55b5b01d1 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -1,7 +1,7 @@
import json
import os
from modules import shared, sd_models, ui_extra_networks, files_cache
-from modules.textual_inversion.textual_inversion import Embedding
+from modules.textual_inversion import Embedding
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
diff --git a/modules/ui_javascript.py b/modules/ui_javascript.py
index 7bbe59988..6668d33c5 100644
--- a/modules/ui_javascript.py
+++ b/modules/ui_javascript.py
@@ -115,20 +115,5 @@ def reload_javascript():
gradio.routes.templates.TemplateResponse = template_response
-def setup_ui_api(app):
- from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
- from typing import List
-
- class QuicksettingsHint(BaseModel): # pylint: disable=too-few-public-methods
- name: str = Field(title="Name of the quicksettings field")
- label: str = Field(title="Label of the quicksettings field")
-
- def quicksettings_hint():
- return [QuicksettingsHint(name=k, label=v.label) for k, v in shared.opts.data_labels.items()]
-
- app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
- app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
-
-
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
diff --git a/modules/ui_sections.py b/modules/ui_sections.py
index 7edb31a1c..a0a1b07bd 100644
--- a/modules/ui_sections.py
+++ b/modules/ui_sections.py
@@ -99,20 +99,6 @@ def create_interrogate_button(tab: str, inputs: list = None, outputs: str = None
return button_interrogate
-def create_interrogate_buttons(tab): # legacy function
- button_interrogate = gr.Button(ui_symbols.int_clip, elem_id=f"{tab}_interrogate", elem_classes=['interrogate-clip'])
- button_deepbooru = gr.Button(ui_symbols.int_blip, elem_id=f"{tab}_deepbooru", elem_classes=['interrogate-blip'])
- return button_interrogate, button_deepbooru
-
-
-def create_sampler_inputs(tab, accordion=True):
- with gr.Accordion(open=False, label="Sampler", elem_id=f"{tab}_sampler", elem_classes=["small-accordion"]) if accordion else gr.Group():
- with gr.Row(elem_id=f"{tab}_row_sampler"):
- sd_samplers.set_samplers()
- steps, sampler_index = create_sampler_and_steps_selection(sd_samplers.samplers, tab)
- return steps, sampler_index
-
-
def create_batch_inputs(tab, accordion=True):
with gr.Accordion(open=False, label="Batch", elem_id=f"{tab}_batch", elem_classes=["small-accordion"]) if accordion else gr.Group():
with gr.Row(elem_id=f"{tab}_row_batch"):
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index 7a7b75430..b53f7bb0e 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -48,10 +48,6 @@ def get_value_for_setting(key):
return gr.update(value=value, **args)
-def ordered_ui_categories():
- return ['dimensions', 'sampler', 'seed', 'denoising', 'cfg', 'checkboxes', 'accordions', 'override_settings', 'scripts'] # a1111 compatibility item, not implemented
-
-
def create_setting_component(key, is_quicksettings=False):
def fun():
return shared.opts.data[key] if key in shared.opts.data else shared.opts.data_labels[key].default
@@ -88,7 +84,6 @@ def create_setting_component(key, is_quicksettings=False):
elif info.folder is not None:
with gr.Row():
res = comp(label=info.label, value=fun(), elem_id=elem_id, elem_classes="folder-selector", **args)
- # ui_common.create_browse_button(res, f"folder_{key}")
else:
try:
res = comp(label=info.label, value=fun(), elem_id=elem_id, **args)
diff --git a/modules/ui_txt2img.py b/modules/ui_txt2img.py
index a5358dd1f..ecc0678ae 100644
--- a/modules/ui_txt2img.py
+++ b/modules/ui_txt2img.py
@@ -1,19 +1,9 @@
import gradio as gr
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call
-from modules import timer, shared, ui_common, ui_sections, generation_parameters_copypaste, processing, processing_vae, devices, images
+from modules import timer, shared, ui_common, ui_sections, generation_parameters_copypaste, processing_vae, images
from modules.ui_components import ToolButton # pylint: disable=unused-import
-def calc_resolution_hires(width, height, hr_scale, hr_resize_x, hr_resize_y, hr_upscaler):
- if hr_upscaler == "None":
- return "Hires resize: None"
- p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
- p.init_hr()
- with devices.autocast():
- p.init([""], [0], [0])
- return f"Hires resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}"
-
-
def create_ui():
shared.log.debug('UI initialize: txt2img')
import modules.txt2img # pylint: disable=redefined-outer-name
diff --git a/modules/unipc/__init__.py b/modules/unipc/__init__.py
deleted file mode 100644
index e1265e3fe..000000000
--- a/modules/unipc/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .sampler import UniPCSampler
diff --git a/modules/unipc/sampler.py b/modules/unipc/sampler.py
deleted file mode 100644
index b5e116d61..000000000
--- a/modules/unipc/sampler.py
+++ /dev/null
@@ -1,191 +0,0 @@
-"""SAMPLING ONLY."""
-
-import torch
-
-from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC, get_time_steps
-from modules import shared, devices
-
-
-class UniPCSampler(object):
- def __init__(self, model, **kwargs):
- super().__init__()
- self.model = model
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
- self.before_sample = None
- self.after_sample = None
- self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
-
- self.noise_schedule = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
-
- def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
- # persist steps so we can eventually find denoising strength
- self.inflated_steps = ddim_num_steps
-
- @devices.inference_context()
- def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
- if noise is None:
- noise = torch.randn_like(x0)
-
- # first time we have all the info to get the real parameters from the ui
- # value from the hires steps slider:
- num_inference_steps = t[0] + 1
- num_inference_steps / self.inflated_steps
- self.denoise_steps = max(num_inference_steps, shared.opts.schedulers_solver_order)
-
- max(self.inflated_steps - self.denoise_steps, 0)
-
- # actual number of steps we'll run
-
- all_timesteps = get_time_steps(
- self.noise_schedule,
- shared.opts.uni_pc_skip_type,
- self.noise_schedule.T,
- 1./self.noise_schedule.total_N,
- self.inflated_steps+1,
- t.device,
- )
-
- # the rest of the timesteps will be used for denoising
- self.timesteps = all_timesteps[-(self.denoise_steps+1):]
-
- latent_timestep = (
- ( # get the timestep of our first denoise step
- self.timesteps[:1]
- # multiply by number of alphas to get int index
- * self.noise_schedule.total_N
- ).int() - 1 # minus one for 0-indexed
- ).repeat(x0.shape[0])
-
- alphas_cumprod = self.alphas_cumprod
- sqrt_alpha_prod = alphas_cumprod[latent_timestep] ** 0.5
- sqrt_alpha_prod = sqrt_alpha_prod.flatten()
- while len(sqrt_alpha_prod.shape) < len(x0.shape):
- sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
- sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[latent_timestep]) ** 0.5
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
- while len(sqrt_one_minus_alpha_prod.shape) < len(x0.shape):
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
- return (sqrt_alpha_prod * x0 + sqrt_one_minus_alpha_prod * noise)
-
- def decode(self, x_latent, conditioning, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
- use_original_steps=False, callback=None):
- # same as in .sample(), i guess
- model_type = "v" if self.model.parameterization == "v" else "noise"
-
- model_fn = model_wrapper(
- lambda x, t, c: self.model.apply_model(x, t, c),
- self.noise_schedule,
- model_type=model_type,
- guidance_type="classifier-free",
- #condition=conditioning,
- #unconditional_condition=unconditional_conditioning,
- guidance_scale=unconditional_guidance_scale,
- )
-
- self.uni_pc = UniPC(
- model_fn,
- self.noise_schedule,
- predict_x0=True,
- thresholding=False,
- variant=shared.opts.uni_pc_variant,
- condition=conditioning,
- unconditional_condition=unconditional_conditioning,
- before_sample=self.before_sample,
- after_sample=self.after_sample,
- after_update=self.after_update,
- )
-
- return self.uni_pc.sample(
- x_latent,
- steps=self.denoise_steps,
- skip_type=shared.opts.uni_pc_skip_type,
- method="multistep",
- order=shared.opts.schedulers_solver_order,
- lower_order_final=shared.opts.schedulers_use_loworder,
- denoise_to_zero=True,
- timesteps=self.timesteps,
- )
-
- def register_buffer(self, name, attr):
- if type(attr) == torch.Tensor:
- if attr.device != devices.device:
- attr = attr.to(devices.device)
- setattr(self, name, attr)
-
- def set_hooks(self, before_sample, after_sample, after_update):
- self.before_sample = before_sample
- self.after_sample = after_sample
- self.after_update = after_update
-
- @devices.inference_context()
- def sample(self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- ctmp = conditioning[list(conditioning.keys())[0]]
- while isinstance(ctmp, list):
- ctmp = ctmp[0]
- cbs = ctmp.shape[0]
- if cbs != batch_size:
- shared.log.warning(f"UniPC: got {cbs} conditionings but batch-size is {batch_size}")
-
- elif isinstance(conditioning, list):
- for ctmp in conditioning:
- if ctmp.shape[0] != batch_size:
- shared.log.warning(f"UniPC: Got {cbs} conditionings but batch-size is {batch_size}")
-
- else:
- if conditioning.shape[0] != batch_size:
- shared.log.warning(f"UniPC: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
-
- device = self.model.betas.device
- if x_T is None:
- img = torch.randn(size, device=device)
- else:
- img = x_T
-
- # SD 1.X is "noise", SD 2.X is "v"
- model_type = "v" if self.model.parameterization == "v" else "noise"
-
- model_fn = model_wrapper(
- lambda x, t, c: self.model.apply_model(x, t, c),
- self.noise_schedule,
- model_type=model_type,
- guidance_type="classifier-free",
- #condition=conditioning,
- #unconditional_condition=unconditional_conditioning,
- guidance_scale=unconditional_guidance_scale,
- )
-
- uni_pc = UniPC(model_fn, self.noise_schedule, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
- x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
-
- return x.to(device), None
diff --git a/modules/unipc/uni_pc.py b/modules/unipc/uni_pc.py
deleted file mode 100644
index 6ba3a31fa..000000000
--- a/modules/unipc/uni_pc.py
+++ /dev/null
@@ -1,779 +0,0 @@
-import torch
-import math
-import time
-from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn
-from modules import shared, devices
-
-
-class NoiseScheduleVP:
- def __init__(
- self,
- schedule='discrete',
- betas=None,
- alphas_cumprod=None,
- continuous_beta_0=0.1,
- continuous_beta_1=20.,
- ):
- if schedule not in ['discrete', 'linear', 'cosine']:
- raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
-
- self.schedule = schedule
- if schedule == 'discrete':
- if betas is not None:
- log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
- else:
- assert alphas_cumprod is not None
- log_alphas = 0.5 * torch.log(alphas_cumprod)
- self.total_N = len(log_alphas)
- self.T = 1.
- self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
- self.log_alpha_array = log_alphas.reshape((1, -1,))
- else:
- self.total_N = 1000
- self.beta_0 = continuous_beta_0
- self.beta_1 = continuous_beta_1
- self.cosine_s = 0.008
- self.cosine_beta_max = 999.
- self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
- self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
- self.schedule = schedule
- if schedule == 'cosine':
- # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
- # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
- self.T = 0.9946
- else:
- self.T = 1.
-
- def marginal_log_mean_coeff(self, t):
- """
- Compute log(alpha_t) of a given continuous-time label t in [0, T].
- """
- if self.schedule == 'discrete':
- return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
- elif self.schedule == 'linear':
- return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
- elif self.schedule == 'cosine':
- log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
- log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
- return log_alpha_t
-
- def marginal_alpha(self, t):
- """
- Compute alpha_t of a given continuous-time label t in [0, T].
- """
- return torch.exp(self.marginal_log_mean_coeff(t))
-
- def marginal_std(self, t):
- """
- Compute sigma_t of a given continuous-time label t in [0, T].
- """
- return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
-
- def marginal_lambda(self, t):
- """
- Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
- """
- log_mean_coeff = self.marginal_log_mean_coeff(t)
- log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
- return log_mean_coeff - log_std
-
- def inverse_lambda(self, lamb):
- """
- Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
- """
- if self.schedule == 'linear':
- tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
- Delta = self.beta_0**2 + tmp
- return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
- elif self.schedule == 'discrete':
- log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
- t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
- return t.reshape((-1,))
- else:
- log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
- t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
- t = t_fn(log_alpha)
- return t
-
-
-def model_wrapper(
- model,
- noise_schedule,
- model_type="noise",
- model_kwargs=None,
- guidance_type="uncond",
- #condition=None,
- #unconditional_condition=None,
- guidance_scale=1.,
- classifier_fn=None,
- classifier_kwargs=None,
-):
- """Create a wrapper function for the noise prediction model.
-
- DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
- firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
-
- We support four types of the diffusion model by setting `model_type`:
-
- 1. "noise": noise prediction model. (Trained by predicting noise).
-
- 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
-
- 3. "v": velocity prediction model. (Trained by predicting the velocity).
- The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
-
- [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
- arXiv preprint arXiv:2202.00512 (2022).
- [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
- arXiv preprint arXiv:2210.02303 (2022).
-
- 4. "score": marginal score function. (Trained by denoising score matching).
- Note that the score function and the noise prediction model follows a simple relationship:
- ```
- noise(x_t, t) = -sigma_t * score(x_t, t)
- ```
-
- We support three types of guided sampling by DPMs by setting `guidance_type`:
- 1. "uncond": unconditional sampling by DPMs.
- The input `model` has the following format:
- ``
- model(x, t_input, **model_kwargs) -> noise | x_start | v | score
- ``
-
- 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
- The input `model` has the following format:
- ``
- model(x, t_input, **model_kwargs) -> noise | x_start | v | score
- ``
-
- The input `classifier_fn` has the following format:
- ``
- classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
- ``
-
- [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
- in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
-
- 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
- The input `model` has the following format:
- ``
- model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
- ``
- And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
-
- [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
- arXiv preprint arXiv:2207.12598 (2022).
-
-
- The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
- or continuous-time labels (i.e. epsilon to T).
-
- We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
- ``
- def model_fn(x, t_continuous) -> noise:
- t_input = get_model_input_time(t_continuous)
- return noise_pred(model, x, t_input, **model_kwargs)
- ``
- where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
-
- ===============================================================
-
- Args:
- model: A diffusion model with the corresponding format described above.
- noise_schedule: A noise schedule object, such as NoiseScheduleVP.
- model_type: A `str`. The parameterization type of the diffusion model.
- "noise" or "x_start" or "v" or "score".
- model_kwargs: A `dict`. A dict for the other inputs of the model function.
- guidance_type: A `str`. The type of the guidance for sampling.
- "uncond" or "classifier" or "classifier-free".
- condition: A pytorch tensor. The condition for the guided sampling.
- Only used for "classifier" or "classifier-free" guidance type.
- unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
- Only used for "classifier-free" guidance type.
- guidance_scale: A `float`. The scale for the guided sampling.
- classifier_fn: A classifier function. Only used for the classifier guidance.
- classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
- Returns:
- A noise prediction model that accepts the noised data and the continuous time as the inputs.
- """
- model_kwargs = model_kwargs or {}
- classifier_kwargs = classifier_kwargs or {}
-
- def get_model_input_time(t_continuous):
- """
- Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
- For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
- For continuous-time DPMs, we just use `t_continuous`.
- """
- if noise_schedule.schedule == 'discrete':
- return (t_continuous - 1. / noise_schedule.total_N) * 1000.
- else:
- return t_continuous
-
- def noise_pred_fn(x, t_continuous, cond=None):
- if t_continuous.reshape((-1,)).shape[0] == 1:
- t_continuous = t_continuous.expand((x.shape[0]))
- t_input = get_model_input_time(t_continuous)
- if cond is None:
- output = model(x, t_input, None, **model_kwargs)
- else:
- output = model(x, t_input, cond, **model_kwargs)
- if model_type == "noise":
- return output
- elif model_type == "x_start":
- alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
- dims = x.dim()
- return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
- elif model_type == "v":
- alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
- dims = x.dim()
- return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
- elif model_type == "score":
- sigma_t = noise_schedule.marginal_std(t_continuous)
- dims = x.dim()
- return -expand_dims(sigma_t, dims) * output
-
- def cond_grad_fn(x, t_input, condition):
- """
- Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
- """
- with torch.enable_grad():
- x_in = x.detach().requires_grad_(True)
- log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
- return torch.autograd.grad(log_prob.sum(), x_in)[0]
-
- def model_fn(x, t_continuous, condition, unconditional_condition):
- """
- The noise predicition model function that is used for DPM-Solver.
- """
- if t_continuous.reshape((-1,)).shape[0] == 1:
- t_continuous = t_continuous.expand((x.shape[0]))
- if guidance_type == "uncond":
- return noise_pred_fn(x, t_continuous)
- elif guidance_type == "classifier":
- assert classifier_fn is not None
- t_input = get_model_input_time(t_continuous)
- cond_grad = cond_grad_fn(x, t_input, condition)
- sigma_t = noise_schedule.marginal_std(t_continuous)
- noise = noise_pred_fn(x, t_continuous)
- return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
- elif guidance_type == "classifier-free":
- if guidance_scale == 1. or unconditional_condition is None:
- return noise_pred_fn(x, t_continuous, cond=condition)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t_continuous] * 2)
- if isinstance(condition, dict):
- assert isinstance(unconditional_condition, dict)
- c_in = {}
- for k in condition:
- if isinstance(condition[k], list):
- c_in[k] = [torch.cat([
- unconditional_condition[k][i],
- condition[k][i]]) for i in range(len(condition[k]))]
- else:
- c_in[k] = torch.cat([
- unconditional_condition[k],
- condition[k]])
- elif isinstance(condition, list):
- c_in = []
- assert isinstance(unconditional_condition, list)
- for i in range(len(condition)):
- c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
- else:
- c_in = torch.cat([unconditional_condition, condition])
- noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
- return noise_uncond + guidance_scale * (noise - noise_uncond)
-
- assert model_type in ["noise", "x_start", "v"]
- assert guidance_type in ["uncond", "classifier", "classifier-free"]
- return model_fn
-
-def get_time_steps(noise_schedule, skip_type, t_T, t_0, N, device):
- """Compute the intermediate time steps for sampling.
- """
- if skip_type == 'logSNR':
- lambda_T = noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
- lambda_0 = noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
- logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
- return noise_schedule.inverse_lambda(logSNR_steps)
- elif skip_type == 'time_uniform':
- return torch.linspace(t_T, t_0, N + 1).to(device)
- elif skip_type == 'time_quadratic':
- t_order = 2
- t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
- return t
- else:
- raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
-
-class UniPC:
- def __init__(
- self,
- model_fn,
- noise_schedule,
- predict_x0=True,
- thresholding=False,
- max_val=1.,
- variant='bh1',
- condition=None,
- unconditional_condition=None,
- before_sample=None,
- after_sample=None,
- after_update=None
- ):
- """Construct a UniPC.
-
- We support both data_prediction and noise_prediction.
- """
- self.model_fn_ = model_fn
- self.noise_schedule = noise_schedule
- self.variant = variant
- self.predict_x0 = predict_x0
- self.thresholding = thresholding
- self.max_val = max_val
- self.condition = condition
- self.unconditional_condition = unconditional_condition
- self.before_sample = before_sample
- self.after_sample = after_sample
- self.after_update = after_update
-
- def dynamic_thresholding_fn(self, x0, t=None):
- """
- The dynamic thresholding method.
- """
- dims = x0.dim()
- p = self.dynamic_thresholding_ratio
- s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
- s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
- x0 = torch.clamp(x0, -s, s) / s
- return x0
-
- def model(self, x, t):
- cond = self.condition
- uncond = self.unconditional_condition
- if self.before_sample is not None:
- x, t, cond, uncond = self.before_sample(x, t, cond, uncond)
- res = self.model_fn_(x, t, cond, uncond)
- if self.after_sample is not None:
- x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res)
-
- if isinstance(res, tuple):
- # (None, pred_x0)
- res = res[1]
-
- return res
-
- def noise_prediction_fn(self, x, t):
- """
- Return the noise prediction model.
- """
- return self.model(x, t)
-
- def data_prediction_fn(self, x, t):
- """
- Return the data prediction model (with thresholding).
- """
- noise = self.noise_prediction_fn(x, t)
- dims = x.dim()
- alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
- x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
- if self.thresholding:
- p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
- s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
- s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
- x0 = torch.clamp(x0, -s, s) / s
- return x0
-
- def model_fn(self, x, t):
- """
- Convert the model to the noise prediction model or the data prediction model.
- """
- if self.predict_x0:
- return self.data_prediction_fn(x, t)
- else:
- return self.noise_prediction_fn(x, t)
-
- def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
- """
- Get the order of each step for sampling by the singlestep DPM-Solver.
- """
- if order == 3:
- K = steps // 3 + 1
- if steps % 3 == 0:
- orders = [3,] * (K - 2) + [2, 1]
- elif steps % 3 == 1:
- orders = [3,] * (K - 1) + [1]
- else:
- orders = [3,] * (K - 1) + [2]
- elif order == 2:
- if steps % 2 == 0:
- K = steps // 2
- orders = [2,] * K
- else:
- K = steps // 2 + 1
- orders = [2,] * (K - 1) + [1]
- elif order == 1:
- K = steps
- orders = [1,] * steps
- else:
- raise ValueError("'order' must be '1' or '2' or '3'.")
- if skip_type == 'logSNR':
- # To reproduce the results in DPM-Solver paper
- timesteps_outer = get_time_steps(self.noise_schedule, skip_type, t_T, t_0, K, device)
- else:
- timesteps_outer = get_time_steps(self.noise_schedule, skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
- return timesteps_outer, orders
-
- def denoise_to_zero_fn(self, x, s):
- """
- Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
- """
- return self.data_prediction_fn(x, s)
-
- def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
- if len(t.shape) == 0:
- t = t.view(-1)
- if 'bh' in self.variant:
- return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
- else:
- assert self.variant == 'vary_coeff'
- return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
-
- def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
- ns = self.noise_schedule
- assert order <= len(model_prev_list)
-
- # first compute rks
- t_prev_0 = t_prev_list[-1]
- lambda_prev_0 = ns.marginal_lambda(t_prev_0)
- lambda_t = ns.marginal_lambda(t)
- model_prev_0 = model_prev_list[-1]
- sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
- log_alpha_t = ns.marginal_log_mean_coeff(t)
- alpha_t = torch.exp(log_alpha_t)
-
- h = lambda_t - lambda_prev_0
-
- rks = []
- D1s = []
- for i in range(1, order):
- t_prev_i = t_prev_list[-(i + 1)]
- model_prev_i = model_prev_list[-(i + 1)]
- lambda_prev_i = ns.marginal_lambda(t_prev_i)
- rk = (lambda_prev_i - lambda_prev_0) / h
- rks.append(rk)
- D1s.append((model_prev_i - model_prev_0) / rk)
-
- rks.append(1.)
- rks = torch.tensor(rks, device=x.device)
-
- K = len(rks)
- # build C matrix
- C = []
-
- col = torch.ones_like(rks)
- for k in range(1, K + 1):
- C.append(col)
- col = col * rks / (k + 1)
- C = torch.stack(C, dim=1)
-
- if len(D1s) > 0:
- D1s = torch.stack(D1s, dim=1) # (B, K)
- C_inv_p = torch.linalg.inv(C[:-1, :-1])
- A_p = C_inv_p
-
- if use_corrector:
- C_inv = torch.linalg.inv(C)
- A_c = C_inv
-
- hh = -h if self.predict_x0 else h
- h_phi_1 = torch.expm1(hh)
- h_phi_ks = []
- factorial_k = 1
- h_phi_k = h_phi_1
- for k in range(1, K + 2):
- h_phi_ks.append(h_phi_k)
- h_phi_k = h_phi_k / hh - 1 / factorial_k
- factorial_k *= (k + 1)
-
- model_t = None
- if self.predict_x0:
- x_t_ = (
- sigma_t / sigma_prev_0 * x
- - alpha_t * h_phi_1 * model_prev_0
- )
- # now predictor
- x_t = x_t_
- if len(D1s) > 0:
- # compute the residuals for predictor
- for k in range(K - 1):
- x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
- # now corrector
- if use_corrector:
- model_t = self.model_fn(x_t, t)
- D1_t = (model_t - model_prev_0)
- x_t = x_t_
- k = 0
- for k in range(K - 1):
- x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
- x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
- else:
- log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
- x_t_ = (
- (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
- - (sigma_t * h_phi_1) * model_prev_0
- )
- # now predictor
- x_t = x_t_
- if len(D1s) > 0:
- # compute the residuals for predictor
- for k in range(K - 1):
- x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
- # now corrector
- if use_corrector:
- model_t = self.model_fn(x_t, t)
- D1_t = (model_t - model_prev_0)
- x_t = x_t_
- k = 0
- for k in range(K - 1):
- x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
- x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
- return x_t, model_t
-
- def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
- ns = self.noise_schedule
- assert order <= len(model_prev_list)
- dims = x.dim()
-
- # first compute rks
- t_prev_0 = t_prev_list[-1]
- lambda_prev_0 = ns.marginal_lambda(t_prev_0)
- lambda_t = ns.marginal_lambda(t)
- model_prev_0 = model_prev_list[-1]
- sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
- log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
- alpha_t = torch.exp(log_alpha_t)
-
- h = lambda_t - lambda_prev_0
-
- rks = []
- D1s = []
- for i in range(1, order):
- t_prev_i = t_prev_list[-(i + 1)]
- model_prev_i = model_prev_list[-(i + 1)]
- lambda_prev_i = ns.marginal_lambda(t_prev_i)
- rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
- rks.append(rk)
- D1s.append((model_prev_i - model_prev_0) / rk)
-
- rks.append(1.)
- rks = torch.tensor(rks, device=x.device)
-
- R = []
- b = []
-
- hh = -h[0] if self.predict_x0 else h[0]
- h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
- h_phi_k = h_phi_1 / hh - 1
-
- factorial_i = 1
-
- if self.variant == 'bh1':
- B_h = hh
- elif self.variant == 'bh2':
- B_h = torch.expm1(hh)
- else:
- raise NotImplementedError
-
- for i in range(1, order + 1):
- R.append(torch.pow(rks, i - 1))
- b.append(h_phi_k * factorial_i / B_h)
- factorial_i *= (i + 1)
- h_phi_k = h_phi_k / hh - 1 / factorial_i
-
- R = torch.stack(R)
- b = torch.tensor(b, device=x.device)
-
- # now predictor
- use_predictor = len(D1s) > 0 and x_t is None
- if len(D1s) > 0:
- D1s = torch.stack(D1s, dim=1) # (B, K)
- if x_t is None:
- # for order 2, we use a simplified version
- if order == 2:
- rhos_p = torch.tensor([0.5], device=b.device)
- else:
- rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
- else:
- D1s = None
-
- if use_corrector:
- # for order 1, we use a simplified version
- if order == 1:
- rhos_c = torch.tensor([0.5], device=b.device)
- else:
- rhos_c = torch.linalg.solve(R, b)
-
- model_t = None
- if self.predict_x0:
- x_t_ = (
- expand_dims(sigma_t / sigma_prev_0, dims) * x
- - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
- )
-
- if x_t is None:
- if use_predictor:
- pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
- else:
- pred_res = 0
- x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
-
- if use_corrector:
- model_t = self.model_fn(x_t, t)
- if D1s is not None:
- corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
- else:
- corr_res = 0
- D1_t = (model_t - model_prev_0)
- x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
- else:
- x_t_ = (
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
- )
- if x_t is None:
- if use_predictor:
- pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
- else:
- pred_res = 0
- x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
-
- if use_corrector:
- model_t = self.model_fn(x_t, t)
- if D1s is not None:
- corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
- else:
- corr_res = 0
- D1_t = (model_t - model_prev_0)
- x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
- return x_t, model_t
-
-
- def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
- method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
- atol=0.0078, rtol=0.05, corrector=False, timesteps=None,
- ):
- t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
- t_T = self.noise_schedule.T if t_start is None else t_start
- device = x.device
- if method == 'multistep':
- if timesteps is None:
- timesteps = get_time_steps(self.noise_schedule, skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
- assert steps >= order, "UniPC order must be < sampling steps"
- assert timesteps.shape[0] - 1 == steps
- with Progress(TextColumn('[cyan]{task.description}'), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), TimeElapsedColumn(), console=shared.console) as progress:
- task = progress.add_task(description="Initializing", total=steps)
- t = time.time()
- with devices.inference_context():
- vec_t = timesteps[0].expand((x.shape[0]))
- model_prev_list = [self.model_fn(x, vec_t)]
- t_prev_list = [vec_t]
- # Init the first `order` values by lower order multistep DPM-Solver.
- for init_order in range(1, order):
- vec_t = timesteps[init_order].expand(x.shape[0])
- x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
- if model_x is None:
- model_x = self.model_fn(x, vec_t)
- if self.after_update is not None:
- self.after_update(x, model_x)
- model_prev_list.append(model_x)
- t_prev_list.append(vec_t)
- progress.update(task, advance=1, description=f"Progress {round(len(vec_t) * init_order / (time.time() - t), 2)}it/s")
- # for step in trange(order, steps + 1):
- for step in range(order, steps + 1):
- vec_t = timesteps[step].expand(x.shape[0])
- if lower_order_final:
- step_order = min(order, steps + 1 - step)
- else:
- step_order = order
- if step == steps:
- use_corrector = False
- else:
- use_corrector = True
- x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
- if self.after_update is not None:
- self.after_update(x, model_x)
- for i in range(order - 1):
- t_prev_list[i] = t_prev_list[i + 1]
- model_prev_list[i] = model_prev_list[i + 1]
- t_prev_list[-1] = vec_t
- # We do not need to evaluate the final model value.
- if step < steps:
- if model_x is None:
- model_x = self.model_fn(x, vec_t)
- model_prev_list[-1] = model_x
- progress.update(task, advance=1, description=f"Progress {round(len(vec_t) * step / (time.time() - t), 2)}it/s")
- else:
- raise NotImplementedError
- if denoise_to_zero:
- x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
- return x
-
-
-#############################################################
-# other utility functions
-#############################################################
-
-def interpolate_fn(x, xp, yp):
- """
- A piecewise linear function y = f(x), using xp and yp as keypoints.
- We implement f(x) in a differentiable way (i.e. applicable for autograd).
- The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
-
- Args:
- x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
- xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
- yp: PyTorch tensor with shape [C, K].
- Returns:
- The function values f(x), with shape [N, C].
- """
- N, K = x.shape[0], xp.shape[1]
- all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
- sorted_all_x, x_indices = torch.sort(all_x, dim=2)
- x_idx = torch.argmin(x_indices, dim=2)
- cand_start_idx = x_idx - 1
- start_idx = torch.where(
- torch.eq(x_idx, 0),
- torch.tensor(1, device=x.device),
- torch.where(
- torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
- ),
- )
- end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
- start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
- end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
- start_idx2 = torch.where(
- torch.eq(x_idx, 0),
- torch.tensor(0, device=x.device),
- torch.where(
- torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
- ),
- )
- y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
- start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
- end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
- cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
- return cand
-
-
-def expand_dims(v, dims):
- """
- Expand the tensor `v` to the dim `dims`.
-
- Args:
- `v`: a PyTorch tensor with shape [N].
- `dim`: a `int`.
- Returns:
- a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
- """
- return v[(...,) + (None,)*(dims - 1)]
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
index 05f007ba0..1fbf95768 100644
--- a/scripts/prompt_matrix.py
+++ b/scripts/prompt_matrix.py
@@ -2,40 +2,10 @@ import math
import gradio as gr
from modules import images, scripts_manager
from modules.processing import process_images
-from modules.shared import opts, state, log
+from modules.shared import opts, log
import modules.sd_samplers
-def draw_xy_grid(xs, ys, x_label, y_label, cell):
- res = []
-
- ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
- hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
-
- first_processed = None
-
- state.job_count = len(xs) * len(ys)
-
- for iy, y in enumerate(ys):
- for ix, x in enumerate(xs):
- state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
-
- processed, _t = cell(x, y)
- if first_processed is None:
- first_processed = processed
-
- res.append(processed.images[0])
-
- if images.check_grid_size(res):
- grid = images.image_grid(res, rows=len(ys))
- grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
- first_processed.images = [grid]
- else:
- first_processed.images = res
-
- return first_processed
-
-
class Script(scripts_manager.Script):
def title(self):
return "Prompt matrix"
diff --git a/webui.py b/webui.py
index 12d1b285a..236b220f2 100644
--- a/webui.py
+++ b/webui.py
@@ -31,7 +31,7 @@ import modules.upscaler
import modules.upscaler_simple
import modules.extra_networks
import modules.ui_extra_networks
-import modules.textual_inversion.textual_inversion
+import modules.textual_inversion
import modules.script_callbacks
import modules.api.middleware