remove dead code

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4039/head
Vladimir Mandic 2025-07-05 16:47:25 -04:00
parent 748ec564ad
commit a5b77b8ee2
26 changed files with 10 additions and 1536 deletions

View File

@ -33,7 +33,6 @@ ignore-paths=/usr/lib/.*$,
modules/taesd,
modules/teacache,
modules/todo,
modules/unipc,
pipelines/flex2,
pipelines/hidream,
pipelines/meissonic,

View File

@ -159,13 +159,6 @@ def get_gpu_info():
return { 'error': ex }
def extract_device_id(args, name): # pylint: disable=redefined-outer-name
    """Return the argument that follows the first entry of `args` containing `name`.

    Args:
        args: sequence of command line tokens.
        name: text to look for; matched with `in`, so for string tokens this is a substring match.

    Returns:
        The token right after the first match, or None when nothing matches or the
        match is the last token and therefore has no value after it.
    """
    # pair every token with its successor; the last token has no successor and so can never match
    for token, following in zip(args, args[1:]):
        if name in token:
            return following
    return None
def get_cuda_device_string():
from modules.shared import cmd_opts
if backend == 'ipex':
@ -196,13 +189,6 @@ def get_optimal_device():
return torch.device(get_optimal_device_name())
def get_device_for(task): # pylint: disable=unused-argument
    """Return the device to use for `task`.

    The task name is not consulted: the result is always whatever get_optimal_device() returns.
    """
    selected = get_optimal_device()
    return selected
def torch_gc(force:bool=False, fast:bool=False, reason:str=None):
def get_stats():
mem_dict = memstats.memory_stats()
@ -583,14 +569,6 @@ def set_cuda_params():
log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upcast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} tunable={tunable} fp16={"pass" if fp16_ok else "fail"} bf16={"pass" if bf16_ok else "fail"} optimization="{opts.cross_attention_optimization}"')
def cond_cast_unet(tensor):
    """Convert `tensor` to `dtype_unet` when `unet_needs_upcast` is set, otherwise return it untouched."""
    if unet_needs_upcast:
        return tensor.to(dtype_unet)
    return tensor
def cond_cast_float(tensor):
    """Return `tensor.float()` when `unet_needs_upcast` is set, otherwise return `tensor` untouched."""
    if unet_needs_upcast:
        return tensor.float()
    return tensor
def randn(seed, shape=None):
torch.manual_seed(seed)
if backend == 'ipex':

View File

@ -15,12 +15,6 @@ def install(suppress=[]):
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s')
def print_error_explanation(message):
    """Write `message` to the error log one line at a time, ignoring surrounding whitespace."""
    trimmed = message.strip()
    for entry in trimmed.split("\n"):
        log.error(entry)
def display(e: Exception, task: str, suppress=[]):
log.error(f"{task or 'error'}: {type(e).__name__}")
console = get_console()

View File

@ -138,15 +138,6 @@ def should_skip(param):
return skip
def bind_buttons(buttons, image_component, send_generate_info):
    """old function for backwards compatibility; do not use this, use register_paste_params_button

    Registers one ParamBinding per `(tabname, button)` entry of `buttons`. `send_generate_info`
    is used as the source text component when it is a gradio component and as the source
    tab name when it is a string.
    """
    # the source does not depend on the button, so it is resolved once
    text_source = None
    tab_source = None
    if isinstance(send_generate_info, gr.components.Component):
        text_source = send_generate_info
    if isinstance(send_generate_info, str):
        tab_source = send_generate_info
    for tabname, button in buttons.items():
        register_paste_params_button(ParamBinding(
            paste_button=button,
            tabname=tabname,
            source_text_component=text_source,
            source_image_component=image_component,
            source_tabname=tab_source,
        ))
def register_paste_params_button(binding: ParamBinding):
registered_param_bindings.append(binding)
@ -194,17 +185,6 @@ def send_image(x):
return image
def send_image_and_dimensions(x):
    """Resolve `x` to an image and return it together with its width and height.

    `x` is used as-is when it is already a PIL image, otherwise it is passed through
    image_from_url_text(). The dimensions are only reported when `shared.opts.send_size`
    is set and an image was obtained; in every other case a no-op `gr.update()` is
    returned for both of them.
    """
    if isinstance(x, Image.Image):
        image = x
    else:
        image = image_from_url_text(x)
    if not (shared.opts.send_size and isinstance(image, Image.Image)):
        return image, gr.update(), gr.update()
    return image, image.width, image.height
def create_override_settings_dict(text_pairs):
res = {}
params = {}

View File

@ -28,40 +28,6 @@ def unquote(text):
return text
# disabled by default can be enabled if needed
def check_lora(params):
    """Log which LoRA networks referenced by parsed infotext `params` are available.

    Three sources are checked: the comma separated 'LoRA hashes' field (against the hash
    lookup table), the comma separated 'LoRA networks' field and the matches of `re_lora`
    in 'Prompt' (both against the alias table). The outcome is only written to the debug
    log; nothing is returned.

    Args:
        params: dict of parsed infotext values; missing keys are treated as empty.
    """
    try:
        from modules.lora import lora_load
        from modules.errors import log # pylint: disable=redefined-outer-name
    except Exception:
        # best-effort diagnostic: without the LoRA modules there is nothing to check
        return

    def split_field(key):
        # drop blank items so an absent or empty field does not report '' as a missing network
        return [item.strip() for item in params.get(key, '').split(',') if item.strip()]

    sources = [
        (split_field('LoRA hashes'), lora_load.available_network_hash_lookup),
        (split_field('LoRA networks'), lora_load.available_network_aliases),
        (re_lora.findall(params.get('Prompt', '')), lora_load.available_network_aliases),
    ]
    found = set()
    missing = set()
    for names, table in sources:
        for name in names:
            network = table.get(name, None)
            if network is not None:
                found.add(network.name)
            else:
                missing.add(name)
    log.debug(f'LoRA: found={list(found)} missing={list(missing)}')
def parse(infotext):
if not isinstance(infotext, str):
return {}
@ -115,7 +81,6 @@ def parse(infotext):
params[key] = val
debug(f'Param parsed: type={type(params[key])} "{key}"={params[key]} raw="{val}"')
# check_lora(params)
return params

View File

@ -84,10 +84,6 @@ def reset_stats():
pass
def memory_cache():
    """Return the module-level `mem` object."""
    cached = mem
    return cached
def ram_stats():
try:
process = psutil.Process(os.getpid())

View File

@ -160,11 +160,8 @@ def list_models():
def update_model_hashes():
txt = []
lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.hash is None]
# shared.log.info(f'Models list: short hash missing for {len(lst)} out of {len(checkpoints_list)} models')
for ckpt in lst:
ckpt.hash = model_hash(ckpt.filename)
# txt.append(f'Calculated short hash: <b>{ckpt.title}</b> {ckpt.hash}')
# txt.append(f'Updated short hashes for <b>{len(lst)}</b> out of <b>{len(checkpoints_list)}</b> models')
lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.sha256 is None or ckpt.shorthash is None]
shared.log.info(f'Models list: hash missing={len(lst)} total={len(checkpoints_list)}')
for ckpt in lst:
@ -214,15 +211,6 @@ def get_closet_checkpoint_match(s: str) -> CheckpointInfo:
checkpoint_info = CheckpointInfo(s)
return checkpoint_info
# reference search
"""
found = sorted([info for info in shared.reference_models.values() if os.path.basename(info['path']).lower().startswith(s.lower())], key=lambda x: len(x['path']))
if found and len(found) == 1:
checkpoint_info = CheckpointInfo(found[0]['path']) # create a virutal model info
checkpoint_info.type = 'huggingface'
return checkpoint_info
"""
# huggingface search
if shared.opts.sd_checkpoint_autodownload and s.count('/') == 1:
modelloader.hf_login()
@ -251,13 +239,10 @@ def model_hash(filename):
try:
with open(filename, "rb") as file:
import hashlib
# t0 = time.time()
m = hashlib.sha256()
file.seek(0x100000)
m.update(file.read(0x10000))
shorthash = m.hexdigest()[0:8]
# t1 = time.time()
# shared.log.debug(f'Calculating short hash: {filename} hash={shorthash} time={(t1-t0):.2f}')
return shorthash
except FileNotFoundError:
return 'NOFILE'
@ -280,14 +265,11 @@ def select_checkpoint(op='model'):
shared.log.info(" or use --ckpt-dir <path-to-folder> to specify folder with sd models")
shared.log.info(" or use --ckpt <path-to-checkpoint> to force using specific model")
return None
# checkpoint_info = next(iter(checkpoints_list.values()))
if model_checkpoint is not None:
if model_checkpoint != 'model.safetensors' and model_checkpoint != 'stabilityai/stable-diffusion-xl-base-1.0':
shared.log.info(f'Load {op}: search="{model_checkpoint}" not found')
else:
shared.log.info("Selecting first available checkpoint")
# shared.log.warning(f"Loading fallback checkpoint: {checkpoint_info.title}")
# shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
else:
shared.log.info(f'Load {op}: select="{checkpoint_info.title if checkpoint_info is not None else None}"')
return checkpoint_info
@ -367,8 +349,6 @@ def read_metadata_from_safetensors(filename):
t1 = time.time()
global sd_metadata_timer # pylint: disable=global-statement
sd_metadata_timer += (t1 - t0)
# except Exception as e:
# shared.log.error(f"Error reading metadata from: {filename} {e}")
return res

View File

@ -620,7 +620,7 @@ def load_diffuser(checkpoint_info=None, timer=None, op='model', revision=None):
if debug_load:
shared.log.trace(f'Model components: {list(get_signature(sd_model).values())}')
from modules.textual_inversion import textual_inversion
from modules import textual_inversion
sd_model.embedding_db = textual_inversion.EmbeddingDatabase()
sd_model.embedding_db.add_embedding_dir(shared.opts.embeddings_dir)
sd_model.embedding_db.load_textual_inversion_embeddings(force_reload=True)

View File

@ -319,9 +319,6 @@ class DiffusionSampler:
return
self.sampler = sampler
if name == 'DC Solver':
if not hasattr(self.sampler, 'dc_ratios'):
pass
# shared.log.debug_log(f'Sampler: class="{self.sampler.__class__.__name__}" config={self.sampler.config}')
self.sampler.name = name

View File

@ -1,6 +1,5 @@
import os
import glob
from copy import deepcopy
import torch
from modules import shared, errors, paths, devices, sd_models, sd_detect
@ -12,35 +11,13 @@ loaded_vae_file = None
checkpoint_info = None
vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE'))
debug = os.environ.get('SD_LOAD_DEBUG', None) is not None
unspecified = object()
def get_base_vae(model):
    """Return the stored base VAE state dict when it belongs to `model`.

    Args:
        model: object exposing `sd_checkpoint_info`; may be None or otherwise falsy.

    Returns:
        The module-level `base_vae` when one is stored and the stored `checkpoint_info`
        equals `model.sd_checkpoint_info`, otherwise None.
    """
    # test the model first: reading sd_checkpoint_info from a missing model would raise
    if not model or base_vae is None:
        return None
    if checkpoint_info == model.sd_checkpoint_info:
        return base_vae
    return None
def store_base_vae(model):
    """Snapshot the VAE weights of `model` as the base VAE.

    Does nothing when the stored `checkpoint_info` already equals the model's. Otherwise a
    deep copy of `model.first_stage_model.state_dict()` is kept in `base_vae` and the
    model's checkpoint info is remembered. Asserts that no VAE file is currently loaded.
    """
    global base_vae, checkpoint_info # pylint: disable=global-statement
    current = model.sd_checkpoint_info
    if checkpoint_info != current:
        assert not loaded_vae_file, "Trying to store non-base VAE!"
        snapshot = deepcopy(model.first_stage_model.state_dict())
        base_vae = snapshot
        checkpoint_info = model.sd_checkpoint_info
def delete_base_vae():
    """Forget the stored base VAE and the checkpoint info it was taken from."""
    global base_vae, checkpoint_info # pylint: disable=global-statement
    base_vae = checkpoint_info = None
def restore_base_vae(model):
    """Load the stored base VAE back into `model`, then drop the stored copy.

    Only acts when a base VAE is stored and its checkpoint info equals the model's; in
    that case `loaded_vae_file` is also reset to None.
    """
    global loaded_vae_file # pylint: disable=global-statement
    can_restore = base_vae is not None and checkpoint_info == model.sd_checkpoint_info
    if not can_restore:
        return
    shared.log.info("Restoring base VAE")
    _load_vae_dict(model, base_vae)
    loaded_vae_file = None
    delete_base_vae()
def load_vae_dict(filename):
    """Read the VAE state dict stored in `filename`.

    Entries whose key starts with "loss" and entries listed in `vae_ignore_keys` are left out.
    """
    state = sd_models.read_state_dict(filename, what='vae')
    kept = {}
    for key, value in state.items():
        if key[0:4] == "loss" or key in vae_ignore_keys:
            continue
        kept[key] = value
    return kept
def get_filename(filepath):
@ -114,35 +91,6 @@ def resolve_vae(checkpoint_file):
return None, None
def load_vae_dict(filename):
    """Read the VAE state dict stored in `filename`.

    Entries whose key starts with "loss" and entries listed in `vae_ignore_keys` are left out.
    """
    state = sd_models.read_state_dict(filename, what='vae')
    kept = {}
    for key, value in state.items():
        if key[0:4] == "loss" or key in vae_ignore_keys:
            continue
        kept[key] = value
    return kept
# Load the VAE weights stored in `vae_file` into `model` and remember the file in the
# module-level `loaded_vae_file`.
# `vae_source` is only used in log messages.
# NOTE(review): indentation was lost in this view, so the exact nesting of the
# try/except/if/elif branches below cannot be confirmed here - verify against the real file.
def load_vae(model, vae_file=None, vae_source="unknown-source"):
global loaded_vae_file # pylint: disable=global-statement
if vae_file:
try:
# a path that is not an existing file is logged and the function returns without a value
if not os.path.isfile(vae_file):
shared.log.error(f"VAE not found: model={vae_file} source={vae_source}")
return
# snapshot the model's current VAE before replacing its weights
store_base_vae(model)
vae_dict_1 = load_vae_dict(vae_file)
_load_vae_dict(model, vae_dict_1)
except Exception as e:
shared.log.error(f"Load VAE failed: model={vae_file} source={vae_source} {e}")
# full error details are only shown when the module-level `debug` flag is set
if debug:
errors.display(e, 'VAE')
restore_base_vae(model)
# make sure the file is listed in `vae_dict` under the name get_filename() derives from it
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
# no file requested while one is loaded: go back to the stored base VAE
elif loaded_vae_file:
restore_base_vae(model)
loaded_vae_file = vae_file
def apply_vae_config(model_file, vae_file, sd_model):
def get_vae_config():
config_file = os.path.join(paths.sd_configs_path, os.path.splitext(os.path.basename(model_file))[0] + '_vae.json')
@ -219,20 +167,6 @@ def load_vae_diffusers(model_file, vae_file=None, vae_source="unknown-source"):
return None
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
    """Load `vae_dict_1` into `model.first_stage_model` and convert that module to `devices.dtype_vae`."""
    vae = model.first_stage_model
    vae.load_state_dict(vae_dict_1)
    vae.to(devices.dtype_vae)
def clear_loaded_vae():
    """Reset the module-level `loaded_vae_file` marker to None."""
    global loaded_vae_file # pylint: disable=global-statement
    nothing_loaded = None
    loaded_vae_file = nothing_loaded
unspecified = object()
def reload_vae_weights(sd_model=None, vae_file=unspecified):
if not sd_model:
sd_model = shared.sd_model

View File

@ -3,9 +3,7 @@ import os
import time
import torch
import safetensors.torch
from PIL import Image
from modules import shared, devices, errors
from modules.textual_inversion.image_embedding import embedding_from_b64, extract_image_data_embed
from modules.files_cache import directory_files, directory_mtime, extension_filter
@ -241,9 +239,6 @@ class EmbeddingDatabase:
def add_embedding_dir(self, path):
    """Register `path` in `self.embedding_dirs`, keyed by the path itself."""
    entry = DirWithTextualInversionEmbeddings(path)
    self.embedding_dirs[path] = entry
def clear_embedding_dirs(self):
    """Remove every entry from `self.embedding_dirs` in place."""
    registered = self.embedding_dirs
    registered.clear()
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
if hasattr(model, 'cond_stage_model'):
@ -306,45 +301,6 @@ class EmbeddingDatabase:
errors.display(e, f'Load embedding: name="{embedding.name}" file="{embedding.filename}"')
return
def load_from_file(self, path, filename):
    """Read an embedding file and check that it has a recognised layout.

    Supported containers, chosen by the extension of `filename`:
    images (.png/.webp/.jxl/.avif) carrying an embedded payload, torch files (.bin/.pt)
    and .safetensors files. Files with any other extension, image files whose name
    contains '.preview' and images without a payload are ignored.

    Args:
        path: location of the file on disk.
        filename: name of the file; decides the container type and the embedding name.

    Raises:
        RuntimeError: the file was read but holds neither a 'string_to_param' entry nor a
            non-empty dict whose first value is a tensor.
        AssertionError: a 'string_to_param' entry holds more than one term.
    """
    name, ext = os.path.splitext(filename)
    ext = ext.upper()
    if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
        if '.preview' in filename.lower():
            return
        # the context manager releases the file handle as soon as the payload is extracted
        with Image.open(path) as embed_image:
            if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
            else:
                data = extract_image_data_embed(embed_image)
        if not data: # if data is None, means this is not an embedding, just a preview image
            return
    elif ext in ['.BIN', '.PT']:
        # NOTE(security): torch.load unpickles the file; only use it on embeddings from trusted sources
        data = torch.load(path, map_location="cpu")
    elif ext in ['.SAFETENSORS']:
        data = safetensors.torch.load_file(path, device="cpu")
    else:
        return

    # NOTE(review): `emb` is computed below but never returned or stored - confirm whether that is intended
    # textual inversion embeddings
    if 'string_to_param' in data:
        param_dict = data['string_to_param']
        param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
        emb = next(iter(param_dict.items()))[1]
    # diffuser concepts; an empty dict falls through to the error below instead of raising StopIteration
    elif isinstance(data, dict) and data and isinstance(next(iter(data.values())), torch.Tensor):
        if len(data.keys()) != 1:
            self.skipped_embeddings[name] = Embedding(None, name=name, filename=path)
            return
        emb = next(iter(data.values()))
        if len(emb.shape) == 1:
            emb = emb.unsqueeze(0)
    else:
        raise RuntimeError(f"Couldn't identify {filename} as textual inversion embedding")
def load_from_dir(self, embdir):
if not shared.sd_loaded:
shared.log.info('Skipping embeddings load: model not loaded')
@ -387,14 +343,3 @@ class EmbeddingDatabase:
self.previously_displayed_embeddings = displayed_embeddings
t1 = time.time()
shared.log.info(f"Network load: type=embeddings loaded={len(self.word_embeddings)} skipped={len(self.skipped_embeddings)} time={t1-t0:.2f}")
def find_embedding_at_position(self, tokens, offset):
    """Find the embedding whose token ids start at `tokens[offset]`.

    Candidates are taken from `self.ids_lookup` using the token at `offset` and tried in
    their stored order.

    Returns:
        `(embedding, length)` for the first candidate whose ids equal the tokens starting
        at `offset`, or `(None, None)` when there is no such candidate.
    """
    candidates = self.ids_lookup.get(tokens[offset], None)
    if candidates is not None:
        for token_ids, candidate in candidates:
            span = len(token_ids)
            if tokens[offset:offset + span] == token_ids:
                return candidate, span
    return None, None

View File

@ -1,213 +0,0 @@
import base64
import json
import zlib
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch
from modules.shared import opts
class EmbeddingEncoder(json.JSONEncoder):
    """JSON encoder that writes torch tensors as `{'TORCHTENSOR': <nested list>}`."""

    def default(self, o):
        """Serialise tensors as nested lists; defer everything else to the base encoder."""
        if not isinstance(o, torch.Tensor):
            return super().default(o)
        values = o.cpu().detach().numpy().tolist()
        return {'TORCHTENSOR': values}
class EmbeddingDecoder(json.JSONDecoder):
    """JSON decoder that turns `{'TORCHTENSOR': ...}` objects back into torch tensors."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, d): # pylint: disable=E0202
        """Rebuild a tensor from a tagged object; return any other object unchanged."""
        if 'TORCHTENSOR' not in d:
            return d
        values = np.array(d['TORCHTENSOR'])
        return torch.from_numpy(values)
def embedding_to_b64(data):
    """Serialise `data` to JSON with EmbeddingEncoder and return the result base64-encoded as bytes."""
    serialized = json.dumps(data, cls=EmbeddingEncoder).encode()
    return base64.b64encode(serialized)
def embedding_from_b64(data):
    """Decode base64 `data` and parse the resulting JSON with EmbeddingDecoder."""
    serialized = base64.b64decode(data)
    return json.loads(serialized, cls=EmbeddingDecoder)
def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
    """Endless linear congruential generator.

    Each step advances the state to `(a * state + c) % m`, starting from `seed`, and
    yields that state modulo 255.
    """
    state = seed
    while True:
        state = (a * state + c) % m
        yield state % 255
def xor_block(block):
    """XOR `block` (as uint8) with the low four bits of a fresh lcg() stream of the same shape."""
    stream = lcg()
    count = np.prod(block.shape)
    noise = np.array([next(stream) for _ in range(count)]).astype(np.uint8).reshape(block.shape)
    return np.bitwise_xor(block.astype(np.uint8), noise & 0x0F)
def style_block(block, sequence):
    """XOR `block` with the high four bits of a dotted grey pattern.

    The pattern is an RGB image as wide as `block.shape[1]` and as tall as `block.shape[0]`,
    filled with 6 pixel circles on an 8 pixel grid. Every other row is shifted by 4 pixels
    and the grey level of each circle is taken from `sequence` in turn, wrapping around.
    """
    canvas = Image.new('RGB', (block.shape[1], block.shape[0]))
    pen = ImageDraw.Draw(canvas)
    drawn = 0
    for left in range(-6, canvas.size[0], 8):
        for row, top in enumerate(range(-6, canvas.size[1], 8)):
            shift = 4 if row % 2 == 0 else 0
            shade = sequence[drawn % len(sequence)]
            drawn += 1
            pen.ellipse((left+shift, top, left+6+shift, top+6), fill=(shade, shade, shade))
    pattern = np.array(canvas).astype(np.uint8) & 0xF0
    return block ^ pattern
def insert_image_data_embed(image, data):
    """Return a copy of `image` with `data` embedded in two strips at its sides.

    `data` is serialised to JSON with EmbeddingEncoder and zlib-compressed. The low and
    high four bits of every byte are laid out as two image-height blocks, passed through
    style_block() and xor_block(), and pasted left (low) and right (high) of the image
    with a one pixel black gap on each side.
    """
    channels = 3
    payload = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
    packed = np.frombuffer(payload, np.uint8).copy()
    high_nibbles = packed >> 4
    low_nibbles = packed & 0x0F
    height = image.size[1]
    # grow the length to the next multiple of the height, then of height * channels
    count = low_nibbles.shape[0]
    padded = count + (height-(count % height))
    padded = padded + ((height*channels)-(padded % (height*channels)))
    low_block = np.resize(low_nibbles, padded).reshape((height, -1, channels))
    high_block = np.resize(high_nibbles, padded).reshape((height, -1, channels))
    # grey levels of the pattern are taken from the first row of the first embedding tensor
    edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
    edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
    low_block = xor_block(style_block(low_block, sequence=edge_style))
    high_block = xor_block(style_block(high_block, sequence=edge_style[::-1]))
    im_low = Image.fromarray(low_block, mode='RGB')
    im_high = Image.fromarray(high_block, mode='RGB')
    total_width = image.size[0]+im_low.size[0]+im_high.size[0]+2
    background = Image.new('RGB', (total_width, image.size[1]), (0, 0, 0))
    background.paste(im_low, (0, 0))
    background.paste(image, (im_low.size[0]+1, 0))
    background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
    return background
def crop_black(img, tol=0):
    """Crop `img` to the bounding box of pixels whose channels are all above `tol`.

    When no pixel qualifies the whole array is returned.
    """
    lit = (img > tol).all(2)
    lit_cols = lit.any(0)
    lit_rows = lit.any(1)
    left = lit_cols.argmax()
    right = lit.shape[1] - lit_cols[::-1].argmax()
    top = lit_rows.argmax()
    bottom = lit.shape[0] - lit_rows[::-1].argmax()
    return img[top:bottom, left:right]
def extract_image_data_embed(image):
    """Recover the data embedded in the side strips of `image`.

    Returns:
        The decoded object, or None when the cropped image has fewer than two all-zero columns.
    """
    channels = 3
    pixels = np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], channels).astype(np.uint8) # pylint: disable=E1121
    outarr = crop_black(pixels) & 0x0F
    black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
    if black_cols[0].shape[0] < 2:
        return None
    first_black = black_cols[0].min()
    last_black = black_cols[0].max()
    # low four bits sit left of the first black column, high four bits right of the last one
    lower = xor_block(outarr[:, :first_black, :].astype(np.uint8))
    upper = xor_block(outarr[:, last_black+1:, :].astype(np.uint8))
    combined = (upper << 4) | (lower)
    raw = zlib.decompress(combined.flatten().tobytes())
    return json.loads(raw, cls=EmbeddingDecoder)
def _fit_font_size(draw, text, font, base_size, available, limit=72):
    """Scale `base_size` so `text` drawn with `font` is about `available` pixels wide, clamped to 1..`limit`."""
    _, _, width, _ = draw.textbbox((0, 0), text, font=font)
    if width <= 0: # empty text has no extent: use the cap instead of dividing by zero
        return limit
    return max(1, min(int(base_size * (available / width)), limit))


def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
    """Return a copy of `srcimage` with a title and three footer texts drawn over it.

    The copy is darkened towards the top and bottom edge, the title is drawn top-left
    sized to about three quarters of the width, and the footers are drawn bottom-left,
    bottom-centre and bottom-right in one shared size so each fits a third of the width.

    Args:
        srcimage: PIL image to caption; it is not modified.
        title: text drawn at the top-left corner.
        footerLeft, footerMid, footerRight: texts drawn along the bottom edge.
        textfont: path of a TrueType font; defaults to `opts.font` or the bundled font.

    Returns:
        A new RGBA image.
    """
    from math import cos
    image = srcimage.copy()
    fontsize = 32
    if textfont is None:
        textfont = opts.font or 'javascript/notosans-nerdfont-regular.ttf'
    # vertical gradient: transparent in the middle, darker towards the top and bottom edge
    factor = 1.5
    gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
    for y in range(image.size[1]):
        mag = 1-cos(y/image.size[1]*factor)
        mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
        gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
    image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype(textfont, fontsize)
    padding = 10
    # title: measured at the base size, then rescaled to the available width
    fontsize = _fit_font_size(draw, title, font, fontsize, (image.size[0]*0.75)-(padding*4))
    font = ImageFont.truetype(textfont, fontsize)
    draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
    # footers: measured with the title font, drawn in the smallest of the three fitted sizes
    footer_width = (image.size[0]/3)-(padding)
    footer_size = min(_fit_font_size(draw, text, font, fontsize, footer_width) for text in (footerLeft, footerMid, footerRight))
    font = ImageFont.truetype(textfont, footer_size)
    draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
    draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
    draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
    return image
if __name__ == '__main__':
    # self-test; expects a 'test_embedding.png' file in the working directory
    source_image = Image.open('test_embedding.png')
    decoded = extract_image_data_embed(source_image)
    assert decoded is not None
    decoded = embedding_from_b64(source_image.text['sd-ti-embedding'])
    assert decoded is not None

    # round trip: embed random data into a captioned image and read it back
    blank = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
    captioned = caption_image_overlay(blank, 'title', 'footerLeft', 'footerMid', 'footerRight')
    sample = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
    first_pass = insert_image_data_embed(captioned, sample)
    recovered = extract_image_data_embed(first_pass)
    assert str(recovered) == str(sample)
    second_pass = insert_image_data_embed(captioned, recovered)
    assert first_pass == second_pass

    # the generator must keep producing the reference sequence
    generator = lcg()
    first_hundred = np.array([next(generator) for _ in range(100)]).astype(np.uint8).tolist()
    expected_hundred = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
                        95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
                        160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
                        38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
                        30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
                        41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
                        66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
                        204, 86, 73, 222, 44, 198, 118, 240, 97]
    assert first_hundred == expected_hundred
    next_hundred_thousand_sum = sum(np.array([next(generator) for _ in range(100000)]).astype(np.uint8).tolist())
    assert 12731374 == next_hundred_thousand_sum

View File

@ -67,10 +67,6 @@ def setup_progressbar(*args, **kwargs): # pylint: disable=unused-argument
pass
def ordered_ui_categories():
    """Return the fixed list of UI category names."""
    # a1111 compatibility item, not implemented
    categories = [
        'dimensions',
        'sampler',
        'seed',
        'denoising',
        'cfg',
        'checkboxes',
        'accordions',
        'override_settings',
        'scripts',
    ]
    return categories
def create_ui(startup_timer = None):
if startup_timer is None:
timer.startup = timer.Timer()

View File

@ -323,19 +323,6 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args = No
return refresh_button
def create_browse_button(browse_component, elem_id):
    """Create a folder tool button that writes the chosen folder into `browse_component`.

    Clicking the button runs the `browseFolder()` javascript function; a result of None
    leaves the component unchanged.
    """
    def browse(folder):
        return gr.update() if folder is None else gr.update(value = folder)

    button = ui_components.ToolButton(value=ui_symbols.folder, elem_id=elem_id)
    button.click(
        fn=browse,
        _js="async () => await browseFolder()",
        inputs=[browse_component],
        outputs=[browse_component],
    )
    return button
def create_override_inputs(tab): # pylint: disable=unused-argument
with gr.Row(elem_id=f"{tab}_override_settings_row"):
override_settings = gr.Dropdown([], value=None, label="Override settings", visible=False, elem_id=f"{tab}_override_settings", multiselect=True)

View File

@ -69,10 +69,6 @@ def list_extensions():
debug(f'Extension installed without index: {entry}')
def check_access():
    """Assert that extension access has not been disabled via `shared.cmd_opts.disable_extension_access`."""
    access_disabled = shared.cmd_opts.disable_extension_access
    assert not access_disabled, "extension access disabled because of command line flags"
def apply_changes(disable_list, update_list, disable_all):
if shared.cmd_opts.disable_extension_access:
shared.log.error('Extension: apply changes disallowed because public access is enabled and insecure is not specified')
@ -126,18 +122,6 @@ def check_updates(_id_task, disable_list, search_text, sort_column):
return create_html(search_text, sort_column), "Extension update complete | Restart required"
def make_commit_link(commit_hash, remote, text=None):
    """Return an HTML link to `commit_hash` on GitHub, or plain text for other remotes.

    Args:
        commit_hash: full commit hash.
        remote: URL of the repository; a trailing ".git" is ignored.
        text: link text; defaults to the first 8 characters of `commit_hash`.

    Returns:
        An `<a>` element opening the commit page in a new tab when `remote` starts with
        "https://github.com/", otherwise `text` alone. `text` is inserted as given, so
        callers remain free to pass markup.
    """
    import html # local import: only needed here

    if text is None:
        text = commit_hash[:8]
    if not remote.startswith("https://github.com/"):
        return text
    if remote.endswith(".git"):
        remote = remote[:-4]
    # escape the attribute value so a quote in the remote or the hash cannot break out of href
    href = html.escape(remote + "/commit/" + commit_hash, quote=True)
    return f'<a href="{href}" target="_blank">{text}</a>'
def normalize_git_url(url):
if url is None:
return ""

View File

@ -162,9 +162,6 @@ class ExtraNetworksPage:
preview = f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
return preview
def is_empty(self, folder):
    """Return True when `folder` yields no .ckpt, .safetensors, .pt or .json file.

    The previous body returned the un-negated `any(...)`, i.e. True for a folder that
    does contain such files, which contradicts the method name.
    """
    return not any(files_cache.list_files(folder, ext_filter=['.ckpt', '.safetensors', '.pt', '.json']))
def create_thumb(self):
debug(f'EN create-thumb: {self.name}')
created = 0

View File

@ -1,7 +1,7 @@
import json
import os
from modules import shared, sd_models, ui_extra_networks, files_cache
from modules.textual_inversion.textual_inversion import Embedding
from modules.textual_inversion import Embedding
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):

View File

@ -115,20 +115,5 @@ def reload_javascript():
gradio.routes.templates.TemplateResponse = template_response
def setup_ui_api(app):
    """Register the internal UI helper routes on `app`.

    Adds GET "/internal/quicksettings-hint", which lists name and label of every entry in
    `shared.opts.data_labels`, and GET "/internal/ping", which returns an empty object.
    """
    from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
    from typing import List

    class QuicksettingsHint(BaseModel): # pylint: disable=too-few-public-methods
        name: str = Field(title="Name of the quicksettings field")
        label: str = Field(title="Label of the quicksettings field")

    def quicksettings_hint():
        hints = []
        for key, info in shared.opts.data_labels.items():
            hints.append(QuicksettingsHint(name=key, label=info.label))
        return hints

    app.add_api_route(
        "/internal/quicksettings-hint",
        quicksettings_hint,
        methods=["GET"],
        response_model=List[QuicksettingsHint],
    )
    app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse

View File

@ -99,20 +99,6 @@ def create_interrogate_button(tab: str, inputs: list = None, outputs: str = None
return button_interrogate
def create_interrogate_buttons(tab): # legacy function
    """Create the two interrogate buttons for `tab` and return them as `(interrogate, deepbooru)`."""
    clip_button = gr.Button(ui_symbols.int_clip, elem_id=f"{tab}_interrogate", elem_classes=['interrogate-clip'])
    blip_button = gr.Button(ui_symbols.int_blip, elem_id=f"{tab}_deepbooru", elem_classes=['interrogate-blip'])
    return clip_button, blip_button
def create_sampler_inputs(tab, accordion=True):
    """Build the sampler selection row for `tab` and return `(steps, sampler_index)`.

    The row is placed in a collapsed "Sampler" accordion, or in a plain group when
    `accordion` is false.
    """
    if accordion:
        container = gr.Accordion(open=False, label="Sampler", elem_id=f"{tab}_sampler", elem_classes=["small-accordion"])
    else:
        container = gr.Group()
    with container:
        with gr.Row(elem_id=f"{tab}_row_sampler"):
            sd_samplers.set_samplers()
            steps, sampler_index = create_sampler_and_steps_selection(sd_samplers.samplers, tab)
    return steps, sampler_index
def create_batch_inputs(tab, accordion=True):
with gr.Accordion(open=False, label="Batch", elem_id=f"{tab}_batch", elem_classes=["small-accordion"]) if accordion else gr.Group():
with gr.Row(elem_id=f"{tab}_row_batch"):

View File

@ -48,10 +48,6 @@ def get_value_for_setting(key):
return gr.update(value=value, **args)
def ordered_ui_categories():
    """Return the fixed list of UI category names."""
    # a1111 compatibility item, not implemented
    categories = [
        'dimensions',
        'sampler',
        'seed',
        'denoising',
        'cfg',
        'checkboxes',
        'accordions',
        'override_settings',
        'scripts',
    ]
    return categories
def create_setting_component(key, is_quicksettings=False):
def fun():
return shared.opts.data[key] if key in shared.opts.data else shared.opts.data_labels[key].default
@ -88,7 +84,6 @@ def create_setting_component(key, is_quicksettings=False):
elif info.folder is not None:
with gr.Row():
res = comp(label=info.label, value=fun(), elem_id=elem_id, elem_classes="folder-selector", **args)
# ui_common.create_browse_button(res, f"folder_{key}")
else:
try:
res = comp(label=info.label, value=fun(), elem_id=elem_id, **args)

View File

@ -1,19 +1,9 @@
import gradio as gr
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call
from modules import timer, shared, ui_common, ui_sections, generation_parameters_copypaste, processing, processing_vae, devices, images
from modules import timer, shared, ui_common, ui_sections, generation_parameters_copypaste, processing_vae, images
from modules.ui_components import ToolButton # pylint: disable=unused-import
def calc_resolution_hires(width, height, hr_scale, hr_resize_x, hr_resize_y, hr_upscaler):
    """Return the HTML status line describing the hires resize for the given settings.

    With the upscaler set to "None" a fixed text is returned. Otherwise a txt2img
    processing object is initialised with the settings and the source and target
    resolution are read back from it.
    """
    if hr_upscaler == "None":
        return "Hires resize: None"
    p = processing.StableDiffusionProcessingTxt2Img(
        width=width,
        height=height,
        enable_hr=True,
        hr_scale=hr_scale,
        hr_resize_x=hr_resize_x,
        hr_resize_y=hr_resize_y,
    )
    p.init_hr()
    with devices.autocast():
        p.init([""], [0], [0])
    return f"Hires resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
def create_ui():
shared.log.debug('UI initialize: txt2img')
import modules.txt2img # pylint: disable=redefined-outer-name

View File

@ -1 +0,0 @@
from .sampler import UniPCSampler

View File

@ -1,191 +0,0 @@
"""SAMPLING ONLY."""
import torch
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC, get_time_steps
from modules import shared, devices
class UniPCSampler(object):
def __init__(self, model, **kwargs):
    """Wrap `model` for UniPC sampling.

    Keeps a float32 copy of `model.alphas_cumprod` and builds the discrete noise
    schedule from it.

    Args:
        model: object exposing `alphas_cumprod` and `device`.
        **kwargs: accepted and ignored.
    """
    super().__init__()
    self.model = model
    # every hook starts out as None so decode() can read all three before set_hooks() is called;
    # NOTE(review): assumes UniPC accepts None for after_update as it does for the other two hooks - confirm
    self.before_sample = None
    self.after_sample = None
    self.after_update = None
    alphas_cumprod = model.alphas_cumprod.clone().detach().to(torch.float32).to(model.device)
    self.register_buffer('alphas_cumprod', alphas_cumprod)
    self.noise_schedule = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
    """Remember `ddim_num_steps` as `self.inflated_steps`; the other arguments are not used."""
    # persist steps so we can eventually find denoising strength
    requested_steps = ddim_num_steps
    self.inflated_steps = requested_steps
@devices.inference_context()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
    """Add noise to `x0` at the level of the first denoising step and prepare the timesteps.

    Stores `self.denoise_steps` and `self.timesteps` for a later decode() call.

    Args:
        x0: tensor to add noise to.
        t: indexable whose first item plus one is the requested number of steps.
        use_original_steps: not used.
        noise: noise tensor shaped like `x0`; drawn with torch.randn_like when None.

    Returns:
        `sqrt(alpha) * x0 + sqrt(1 - alpha) * noise`, with alpha taken from
        `self.alphas_cumprod` at the first denoising timestep.
    """
    if noise is None:
        noise = torch.randn_like(x0)

    # first time we have all the info to get the real parameters from the ui
    # value from the hires steps slider:
    num_inference_steps = t[0] + 1
    # never run fewer steps than the solver order
    self.denoise_steps = max(num_inference_steps, shared.opts.schedulers_solver_order)

    # actual number of steps we'll run
    all_timesteps = get_time_steps(
        self.noise_schedule,
        shared.opts.uni_pc_skip_type,
        self.noise_schedule.T,
        1./self.noise_schedule.total_N,
        self.inflated_steps+1,
        t.device,
    )

    # the rest of the timesteps will be used for denoising
    self.timesteps = all_timesteps[-(self.denoise_steps+1):]

    latent_timestep = (
        ( # get the timestep of our first denoise step
            self.timesteps[:1]
            # multiply by number of alphas to get int index
            * self.noise_schedule.total_N
        ).int() - 1 # minus one for 0-indexed
    ).repeat(x0.shape[0])

    alphas_cumprod = self.alphas_cumprod
    # append trailing dimensions until both factors have as many dimensions as x0
    sqrt_alpha_prod = alphas_cumprod[latent_timestep] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    while len(sqrt_alpha_prod.shape) < len(x0.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[latent_timestep]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(x0.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

    return (sqrt_alpha_prod * x0 + sqrt_one_minus_alpha_prod * noise)
def decode(self, x_latent, conditioning, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
           use_original_steps=False, callback=None):
    """Denoise `x_latent` with UniPC over the timesteps stored by stochastic_encode().

    Builds a classifier-free guidance model wrapper, stores the solver as `self.uni_pc`
    and returns the result of its sample() call. `t_start`, `use_original_steps` and
    `callback` are not used.
    """
    # same as in .sample(), i guess
    if self.model.parameterization == "v":
        model_type = "v"
    else:
        model_type = "noise"

    model_fn = model_wrapper(
        lambda x, t, c: self.model.apply_model(x, t, c),
        self.noise_schedule,
        model_type=model_type,
        guidance_type="classifier-free",
        guidance_scale=unconditional_guidance_scale,
    )

    solver_options = {
        'predict_x0': True,
        'thresholding': False,
        'variant': shared.opts.uni_pc_variant,
        'condition': conditioning,
        'unconditional_condition': unconditional_conditioning,
        'before_sample': self.before_sample,
        'after_sample': self.after_sample,
        'after_update': self.after_update,
    }
    self.uni_pc = UniPC(model_fn, self.noise_schedule, **solver_options)

    denoised = self.uni_pc.sample(
        x_latent,
        steps=self.denoise_steps,
        skip_type=shared.opts.uni_pc_skip_type,
        method="multistep",
        order=shared.opts.schedulers_solver_order,
        lower_order_final=shared.opts.schedulers_use_loworder,
        denoise_to_zero=True,
        timesteps=self.timesteps,
    )
    return denoised
def register_buffer(self, name, attr):
    """Store `attr` on the instance under `name`.

    Objects whose type is exactly `torch.Tensor` are moved to
    `devices.device` first when they live on a different device; anything
    else is stored unchanged.
    """
    # exact type comparison (not isinstance): tensor subclasses are stored as-is
    if type(attr) is torch.Tensor and attr.device != devices.device:
        attr = attr.to(devices.device)
    setattr(self, name, attr)
def set_hooks(self, before_sample, after_sample, after_update):
    """Remember the three sampling callbacks so later solver runs can use them."""
    for hook_name, hook in (
        ("before_sample", before_sample),
        ("after_sample", after_sample),
        ("after_update", after_update),
    ):
        setattr(self, hook_name, hook)
@devices.inference_context()
def sample(self,
           S,
           batch_size,
           shape,
           conditioning=None,
           callback=None,
           normals_sequence=None,
           img_callback=None,
           quantize_x0=False,
           eta=0.,
           mask=None,
           x0=None,
           temperature=1.,
           noise_dropout=0.,
           score_corrector=None,
           corrector_kwargs=None,
           verbose=True,
           x_T=None,
           log_every_t=100,
           unconditional_guidance_scale=1.,
           unconditional_conditioning=None,
           # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
           **kwargs
           ):
    """Run a full multistep UniPC sampling pass.

    Args:
        S: number of solver steps.
        batch_size: number of samples to generate.
        shape: `(C, H, W)` of a single sample.
        conditioning: tensor, list of tensors, or dict of tensors / nested
            lists of tensors; only used here for a batch-size sanity check
            and then handed to the solver.
        x_T: optional starting noise; when omitted, Gaussian noise of shape
            `(batch_size, C, H, W)` is drawn on the device of
            `self.model.betas`.
        unconditional_guidance_scale: classifier-free guidance scale.
        unconditional_conditioning: conditioning for the unconditional branch.

    The remaining parameters are accepted but not read by this method.

    Returns:
        `(samples, None)` with `samples` on the device of `self.model.betas`.
    """
    if conditioning is not None:
        if isinstance(conditioning, dict):
            ctmp = conditioning[list(conditioning.keys())[0]]
            while isinstance(ctmp, list):
                ctmp = ctmp[0]
            cbs = ctmp.shape[0]
            if cbs != batch_size:
                shared.log.warning(f"UniPC: got {cbs} conditionings but batch-size is {batch_size}")
        elif isinstance(conditioning, list):
            for ctmp in conditioning:
                # report this entry's own batch size; the previous code
                # formatted `cbs`, which is never assigned in this branch
                if ctmp.shape[0] != batch_size:
                    shared.log.warning(f"UniPC: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
        else:
            if conditioning.shape[0] != batch_size:
                shared.log.warning(f"UniPC: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

    # sampling
    C, H, W = shape
    size = (batch_size, C, H, W)
    device = self.model.betas.device
    img = torch.randn(size, device=device) if x_T is None else x_T

    # SD 1.X is "noise", SD 2.X is "v"
    model_type = "v" if self.model.parameterization == "v" else "noise"
    model_fn = model_wrapper(
        lambda x, t, c: self.model.apply_model(x, t, c),
        self.noise_schedule,
        model_type=model_type,
        guidance_type="classifier-free",
        guidance_scale=unconditional_guidance_scale,
    )
    uni_pc = UniPC(
        model_fn,
        self.noise_schedule,
        predict_x0=True,
        thresholding=False,
        variant=shared.opts.uni_pc_variant,
        condition=conditioning,
        unconditional_condition=unconditional_conditioning,
        before_sample=self.before_sample,
        after_sample=self.after_sample,
        after_update=self.after_update,
    )
    x = uni_pc.sample(
        img,
        steps=S,
        skip_type=shared.opts.uni_pc_skip_type,
        method="multistep",
        order=shared.opts.uni_pc_order,
        lower_order_final=shared.opts.uni_pc_lower_order_final,
    )
    return x.to(device), None

View File

@ -1,779 +0,0 @@
import torch
import math
import time
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn
from modules import shared, devices
class NoiseScheduleVP:
    """Variance-preserving noise schedule.

    Maps a continuous time label `t` to log(alpha_t), alpha_t, sigma_t and the
    half-logSNR lambda_t = log(alpha_t) - log(sigma_t), and maps lambda back
    to `t`. Three schedules are supported: 'discrete' (tabulated from `betas`
    or `alphas_cumprod`), 'linear' and 'cosine'.
    """

    def __init__(
        self,
        schedule='discrete',
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.,
    ):
        """Build the schedule.

        Args:
            schedule: 'discrete', 'linear' or 'cosine'.
            betas: beta values of the discrete schedule (takes precedence
                over `alphas_cumprod`).
            alphas_cumprod: cumulative alpha products of the discrete
                schedule; used only when `betas` is None.
            continuous_beta_0: beta at t = 0 for the continuous schedules.
            continuous_beta_1: beta at t = 1 for the continuous schedules.

        Raises:
            ValueError: for an unknown schedule, or for a discrete schedule
                given neither `betas` nor `alphas_cumprod`.
        """
        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            elif alphas_cumprod is not None:
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            else:
                # explicit error instead of `assert`, which disappears under `python -O`
                raise ValueError("The 'discrete' noise schedule needs either 'betas' or 'alphas_cumprod'")
            self.total_N = len(log_alphas)
            self.T = 1.
            # time labels 1/N, 2/N, ..., 1 paired with the tabulated log(alpha) values
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape((1, -1,))
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            if schedule == 'cosine':
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """Compute log(alpha_t) of a given continuous-time label t in [0, T]."""
        if self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
        if self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        if self.schedule == 'cosine':
            log_cos = torch.log(torch.cos((t + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
            return log_cos - self.cosine_log_alpha_0
        return None

    def marginal_alpha(self, t):
        """Compute alpha_t of a given continuous-time label t in [0, T]."""
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """Compute sigma_t of a given continuous-time label t in [0, T]."""
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]."""
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t."""
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        if self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            # both tables are flipped before the lookup of log_alpha -> t
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
        # cosine schedule
        log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
        return torch.arccos(torch.exp(log_alpha + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs=None,
    guidance_type="uncond",
    guidance_scale=1.,
    classifier_fn=None,
    classifier_kwargs=None,
):
    """Create a wrapper function for the noise prediction model.

    The solver works on continuous-time diffusion ODEs. For models trained on
    discrete-time labels, the model function is wrapped into a noise
    prediction model that accepts the continuous time as the input.

    Four parameterizations are supported through `model_type`:
        1. "noise": noise prediction model. (Trained by predicting noise).
        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
        3. "v": velocity prediction model. (Trained by predicting the velocity).
            The "v" prediction derivation is detailed in Appendix D of [1], and is used in Imagen-Video [2].
            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
                arXiv preprint arXiv:2202.00512 (2022).
            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
                arXiv preprint arXiv:2210.02303 (2022).
        4. "score": marginal score function. (Trained by denoising score matching).
            The score function and the noise prediction model follow a simple relationship:
            ```
                noise(x_t, t) = -sigma_t * score(x_t, t)
            ```

    Three kinds of guided sampling are supported through `guidance_type`:
        1. "uncond": unconditional sampling.
        2. "classifier": classifier guidance sampling [3] with an extra classifier:
            ``
                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
            ``
            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
        3. "classifier-free": classifier-free guidance sampling [4] by a conditional model.
            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
                arXiv preprint arXiv:2207.12598 (2022).

    In every mode the wrapped `model` is invoked as
    ``
        model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
    ``
    with `cond=None` for "uncond" and "classifier" guidance.

    The `t_input` is the time label of the model, which may be discrete-time
    labels or continuous-time labels (i.e. epsilon to T).

    The conditions are NOT bound when the wrapper is created; they are passed
    on every call of the returned function:
    ``
        model_fn(x, t_continuous, condition, unconditional_condition) -> noise
    ``
    where `t_continuous` is the continuous time label (i.e. epsilon to T).

    ===============================================================

    Args:
        model: A diffusion model with the corresponding format described above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
                    "noise" or "x_start" or "v" or "score".
        model_kwargs: A `dict`. A dict for the other inputs of the model function.
        guidance_type: A `str`. The type of the guidance for sampling.
                    "uncond" or "classifier" or "classifier-free".
        guidance_scale: A `float`. The scale for the guided sampling.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
    Returns:
        A noise prediction model that accepts the noised data, the continuous
        time, the condition and the unconditional condition as inputs.
    Raises:
        ValueError: for an unknown `model_type` or `guidance_type`, or (at call
            time) when "classifier" guidance is used without `classifier_fn`.
    """
    # validate up front; "score" is accepted because noise_pred_fn implements it
    if model_type not in ("noise", "x_start", "v", "score"):
        raise ValueError(f"Unsupported model_type {model_type}, need to be 'noise' or 'x_start' or 'v' or 'score'")
    if guidance_type not in ("uncond", "classifier", "classifier-free"):
        raise ValueError(f"Unsupported guidance_type {guidance_type}, need to be 'uncond' or 'classifier' or 'classifier-free'")

    model_kwargs = model_kwargs or {}
    classifier_kwargs = classifier_kwargs or {}

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
        return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        """Evaluate the model and convert its output to a noise prediction."""
        # a single time value is broadcast over the batch
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        dims = x.dim()
        if model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        if model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        # model_type == "score"
        sigma_t = noise_schedule.marginal_std(t_continuous)
        return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input, condition):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous, condition, unconditional_condition):
        """
        The noise prediction model function that is used by the solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        if guidance_type == "classifier":
            if classifier_fn is None:
                raise ValueError("guidance_type 'classifier' requires 'classifier_fn'")
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input, condition)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        # guidance_type == "classifier-free"
        if guidance_scale == 1. or unconditional_condition is None:
            return noise_pred_fn(x, t_continuous, cond=condition)
        # run the unconditional and conditional halves as one doubled batch
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t_continuous] * 2)
        if isinstance(condition, dict):
            assert isinstance(unconditional_condition, dict)
            c_in = {}
            for k in condition:
                if isinstance(condition[k], list):
                    c_in[k] = [
                        torch.cat([unconditional_condition[k][i], condition[k][i]])
                        for i in range(len(condition[k]))
                    ]
                else:
                    c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
        elif isinstance(condition, list):
            assert isinstance(unconditional_condition, list)
            c_in = [torch.cat([unconditional_condition[i], condition[i]]) for i in range(len(condition))]
        else:
            c_in = torch.cat([unconditional_condition, condition])
        noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
        return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn
def get_time_steps(noise_schedule, skip_type, t_T, t_0, N, device):
    """Return the N + 1 time points from `t_T` to `t_0` used for sampling.

    `skip_type` selects the spacing: 'logSNR' (uniform in the schedule's
    lambda, mapped back to time), 'time_uniform' (uniform in time) or
    'time_quadratic' (uniform in sqrt(time)). Any other value raises
    ValueError. `noise_schedule` is only consulted for 'logSNR'.
    """
    num_points = N + 1
    if skip_type == 'time_uniform':
        return torch.linspace(t_T, t_0, num_points).to(device)
    if skip_type == 'time_quadratic':
        power = 2
        root_grid = torch.linspace(t_T**(1. / power), t_0**(1. / power), num_points)
        return root_grid.pow(power).to(device)
    if skip_type == 'logSNR':
        lambda_start = noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
        lambda_stop = noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
        lambda_grid = torch.linspace(lambda_start.cpu().item(), lambda_stop.cpu().item(), num_points).to(device)
        return noise_schedule.inverse_lambda(lambda_grid)
    raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
class UniPC:
    def __init__(
        self,
        model_fn,
        noise_schedule,
        predict_x0=True,
        thresholding=False,
        max_val=1.,
        variant='bh1',
        condition=None,
        unconditional_condition=None,
        before_sample=None,
        after_sample=None,
        after_update=None
    ):
        """Set up a UniPC solver.

        Both data prediction (`predict_x0=True`) and noise prediction are
        supported. The constructor only stores its arguments; the raw model
        callable is kept as `model_fn_`.
        """
        # model and schedule
        self.model_fn_, self.noise_schedule = model_fn, noise_schedule
        # solver behaviour
        self.variant, self.predict_x0 = variant, predict_x0
        self.thresholding, self.max_val = thresholding, max_val
        # conditioning forwarded to the model on every evaluation
        self.condition, self.unconditional_condition = condition, unconditional_condition
        # optional callbacks
        self.before_sample, self.after_sample, self.after_update = before_sample, after_sample, after_update
def dynamic_thresholding_fn(self, x0, t=None):
    """
    The dynamic thresholding method.

    Clamps `x0` per sample to [-s, s] and divides by s, where s is the `p`
    quantile of |x0| over all non-batch dimensions, floored at a maximum
    value. `t` is accepted but not used.

    `dynamic_thresholding_ratio` and `thresholding_max_val` are not set by
    the constructor, so they are looked up defensively: when absent, the
    values used by `data_prediction_fn` apply (quantile 0.995 and
    `self.max_val`).
    """
    dims = x0.dim()
    p = getattr(self, 'dynamic_thresholding_ratio', 0.995)
    max_val = getattr(self, 'thresholding_max_val', self.max_val)
    s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
    s = expand_dims(torch.maximum(s, max_val * torch.ones_like(s).to(s.device)), dims)
    x0 = torch.clamp(x0, -s, s) / s
    return x0
def model(self, x, t):
    """Evaluate the wrapped model at `(x, t)`, running the optional hooks.

    `before_sample` may replace the inputs and conditions; `after_sample`
    may replace the model output. A tuple output is reduced to its second
    element.
    """
    cond, uncond = self.condition, self.unconditional_condition
    if self.before_sample is not None:
        x, t, cond, uncond = self.before_sample(x, t, cond, uncond)

    output = self.model_fn_(x, t, cond, uncond)

    if self.after_sample is not None:
        # the hook returns five values; only the (possibly replaced) output is used afterwards
        x, t, cond, uncond, output = self.after_sample(x, t, cond, uncond, output)

    # (None, pred_x0)
    return output[1] if isinstance(output, tuple) else output
def noise_prediction_fn(self, x, t):
    """Evaluate the noise prediction model at `(x, t)` via `self.model`."""
    predicted_noise = self.model(x, t)
    return predicted_noise
def data_prediction_fn(self, x, t):
    """Predict x0 from the noise prediction at `(x, t)`, optionally thresholded.

    x0 = (x - sigma_t * noise) / alpha_t. With `self.thresholding` set, x0 is
    clamped per sample to [-s, s] and divided by s, where s is the 0.995
    quantile of |x0| floored at `self.max_val`.
    """
    schedule = self.noise_schedule
    eps = self.noise_prediction_fn(x, t)
    ndim = x.dim()
    alpha_t = schedule.marginal_alpha(t)
    sigma_t = schedule.marginal_std(t)
    x0 = (x - expand_dims(sigma_t, ndim) * eps) / expand_dims(alpha_t, ndim)
    if not self.thresholding:
        return x0

    quantile = 0.995  # A hyperparameter in the paper of "Imagen" [1].
    magnitudes = torch.abs(x0).reshape((x0.shape[0], -1))
    s = torch.quantile(magnitudes, quantile, dim=1)
    floor = self.max_val * torch.ones_like(s).to(s.device)
    s = expand_dims(torch.maximum(s, floor), ndim)
    return torch.clamp(x0, -s, s) / s
def model_fn(self, x, t):
    """Dispatch to the data prediction model when `predict_x0` is set, otherwise to the noise prediction model."""
    predictor = self.data_prediction_fn if self.predict_x0 else self.noise_prediction_fn
    return predictor(x, t)
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
    """Split `steps` model evaluations into singlestep updates of order <= `order`.

    Returns `(timesteps_outer, orders)`: the list of per-update orders (which
    sums to `steps`) and the time points delimiting the updates. Raises
    ValueError unless `order` is 1, 2 or 3.
    """
    if order not in (1, 2, 3):
        raise ValueError("'order' must be '1' or '2' or '3'.")

    if order == 1:
        K = steps
        orders = [1] * steps
    elif order == 2:
        pairs, leftover = divmod(steps, 2)
        K = pairs + leftover
        orders = [2] * pairs + [1] * leftover
    else:
        triples, leftover = divmod(steps, 3)
        K = triples + 1
        if leftover == 0:
            # no remainder: the last full triple is replaced by a 2 + 1 tail
            orders = [3] * (triples - 1) + [2, 1]
        else:
            orders = [3] * triples + [leftover]

    if skip_type == 'logSNR':
        # To reproduce the results in DPM-Solver paper
        timesteps_outer = get_time_steps(self.noise_schedule, skip_type, t_T, t_0, K, device)
    else:
        fine_grid = get_time_steps(self.noise_schedule, skip_type, t_T, t_0, steps, device)
        boundaries = torch.cumsum(torch.tensor([0,] + orders), 0).to(device)
        timesteps_outer = fine_grid[boundaries]
    return timesteps_outer, orders
def denoise_to_zero_fn(self, x, s):
    """Final denoising step: return the data prediction at time `s`.

    Equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
    """
    denoised = self.data_prediction_fn(x, s)
    return denoised
def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
    """Route one multistep update to the implementation matching `self.variant`.

    Variants containing 'bh' use the B(h) update; the only other accepted
    variant is 'vary_coeff'. A 0-dim `t` is reshaped to a 1-element vector
    first.
    """
    if len(t.shape) == 0:
        t = t.view(-1)
    if 'bh' not in self.variant:
        assert self.variant == 'vary_coeff'
        update = self.multistep_uni_pc_vary_update
    else:
        update = self.multistep_uni_pc_bh_update
    return update(x, model_prev_list, t_prev_list, t, order, **kwargs)
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
"""
One multistep UniPC update for the 'vary_coeff' variant (see multistep_uni_pc_update).
Args:
x: the sample the update starts from; it is combined with the newest history entry.
model_prev_list: previous model outputs, newest last; at least `order` entries are required.
t_prev_list: time labels matching `model_prev_list`, newest last.
t: target time of this update.
order: number of history entries used.
use_corrector: when True, `self.model_fn` is evaluated at the predicted sample and a corrected sample is computed.
Returns:
(x_t, model_t): the updated sample and the model output computed for the corrector;
model_t stays None when the corrector is not used.
"""
ns = self.noise_schedule
assert order <= len(model_prev_list)
# first compute rks
t_prev_0 = t_prev_list[-1]
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
lambda_t = ns.marginal_lambda(t)
model_prev_0 = model_prev_list[-1]
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
log_alpha_t = ns.marginal_log_mean_coeff(t)
alpha_t = torch.exp(log_alpha_t)
# step size in lambda (half-logSNR) from the newest history time to the target time
h = lambda_t - lambda_prev_0
rks = []
D1s = []
# for every older history entry: its lambda offset relative to h (rk) and the
# model-output difference to the newest entry, scaled by 1 / rk
for i in range(1, order):
t_prev_i = t_prev_list[-(i + 1)]
model_prev_i = model_prev_list[-(i + 1)]
lambda_prev_i = ns.marginal_lambda(t_prev_i)
rk = (lambda_prev_i - lambda_prev_0) / h
rks.append(rk)
D1s.append((model_prev_i - model_prev_0) / rk)
# the target time itself has ratio (lambda_t - lambda_prev_0) / h = 1
rks.append(1.)
rks = torch.tensor(rks, device=x.device)
K = len(rks)
# build C matrix
# column k (1-based) of C holds rks ** (k - 1) / k!
C = []
col = torch.ones_like(rks)
for k in range(1, K + 1):
C.append(col)
col = col * rks / (k + 1)
C = torch.stack(C, dim=1)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1) # (B, K)
# predictor coefficients: inverse of C without its last row and column
C_inv_p = torch.linalg.inv(C[:-1, :-1])
A_p = C_inv_p
if use_corrector:
# corrector coefficients: inverse of the full C
C_inv = torch.linalg.inv(C)
A_c = C_inv
# the step sign is flipped when the model predicts x0
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh)
# h_phi_ks[0] is expm1(hh); every further entry is previous / hh - 1 / factorial_k
h_phi_ks = []
factorial_k = 1
h_phi_k = h_phi_1
for k in range(1, K + 2):
h_phi_ks.append(h_phi_k)
h_phi_k = h_phi_k / hh - 1 / factorial_k
factorial_k *= (k + 1)
model_t = None
if self.predict_x0:
# first-order term shared by predictor and corrector (data-prediction form)
x_t_ = (
sigma_t / sigma_prev_0 * x
- alpha_t * h_phi_1 * model_prev_0
)
# now predictor
x_t = x_t_
if len(D1s) > 0:
# compute the residuals for predictor
# the einsum string 'bkchw' requires D1s to be 5-D (4-D model outputs stacked on dim 1)
for k in range(K - 1):
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
# now corrector
if use_corrector:
model_t = self.model_fn(x_t, t)
D1_t = (model_t - model_prev_0)
x_t = x_t_
# NOTE(review): `k` is initialised to 0 before the loop and read again in the
# `A_c[k][-1]` term below; confirm which row of A_c is intended there
k = 0
for k in range(K - 1):
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
else:
# noise-prediction form: same structure with sigma_t in place of alpha_t
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
x_t_ = (
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
- (sigma_t * h_phi_1) * model_prev_0
)
# now predictor
x_t = x_t_
if len(D1s) > 0:
# compute the residuals for predictor
for k in range(K - 1):
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
# now corrector
if use_corrector:
model_t = self.model_fn(x_t, t)
D1_t = (model_t - model_prev_0)
x_t = x_t_
# NOTE(review): same reuse of `k` after the loop as in the data-prediction branch
k = 0
for k in range(K - 1):
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
return x_t, model_t
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
"""
One multistep UniPC update for the 'bh1' / 'bh2' variants (see multistep_uni_pc_update).
Args:
x: the sample the update starts from; it is combined with the newest history entry.
model_prev_list: previous model outputs, newest last; at least `order` entries are required.
t_prev_list: time labels matching `model_prev_list`, newest last.
t: target time of this update.
order: number of history entries used.
x_t: optional precomputed predictor result; when given, the predictor stage is skipped.
use_corrector: when True, `self.model_fn` is evaluated at the predicted sample and a corrected sample is computed.
Returns:
(x_t, model_t): the updated sample and the model output computed for the corrector;
model_t stays None when the corrector is not used.
Raises:
NotImplementedError: when `self.variant` is neither 'bh1' nor 'bh2'.
"""
ns = self.noise_schedule
assert order <= len(model_prev_list)
dims = x.dim()
# first compute rks
t_prev_0 = t_prev_list[-1]
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
lambda_t = ns.marginal_lambda(t)
model_prev_0 = model_prev_list[-1]
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
alpha_t = torch.exp(log_alpha_t)
# step size in lambda (half-logSNR) from the newest history time to the target time
h = lambda_t - lambda_prev_0
rks = []
D1s = []
for i in range(1, order):
t_prev_i = t_prev_list[-(i + 1)]
model_prev_i = model_prev_list[-(i + 1)]
lambda_prev_i = ns.marginal_lambda(t_prev_i)
# only element 0 of the ratio is kept — presumably all entries of the time vectors are equal; TODO confirm
rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
rks.append(rk)
D1s.append((model_prev_i - model_prev_0) / rk)
# the target time itself has ratio (lambda_t - lambda_prev_0) / h = 1
rks.append(1.)
rks = torch.tensor(rks, device=x.device)
R = []
b = []
# scalar step (element 0 of h); the sign is flipped when the model predicts x0
hh = -h[0] if self.predict_x0 else h[0]
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
# B(h) depends on the variant: hh for 'bh1', expm1(hh) for 'bh2'
if self.variant == 'bh1':
B_h = hh
elif self.variant == 'bh2':
B_h = torch.expm1(hh)
else:
raise NotImplementedError
# row i (1-based) of R is rks ** (i - 1); b[i - 1] is the current h_phi_k times i! divided by B_h
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= (i + 1)
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=x.device)
# now predictor
# the predictor needs history differences and is skipped when the caller supplies x_t
use_predictor = len(D1s) > 0 and x_t is None
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1) # (B, K)
if x_t is None:
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], device=b.device)
else:
# predictor weights: solve the system without its last row and column
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
else:
D1s = None
if use_corrector:
# for order 1, we use a simplified version
if order == 1:
rhos_c = torch.tensor([0.5], device=b.device)
else:
# corrector weights: solve the full system
rhos_c = torch.linalg.solve(R, b)
model_t = None
if self.predict_x0:
# first-order term shared by predictor and corrector (data-prediction form)
x_t_ = (
expand_dims(sigma_t / sigma_prev_0, dims) * x
- expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
)
if x_t is None:
if use_predictor:
# the einsum string 'bkchw' requires D1s to be 5-D (4-D model outputs stacked on dim 1)
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
if use_corrector:
model_t = self.model_fn(x_t, t)
if D1s is not None:
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
else:
corr_res = 0
# the last corrector weight applies to the difference of the new model output
D1_t = (model_t - model_prev_0)
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
else:
# noise-prediction form: same structure with sigma_t in place of alpha_t
x_t_ = (
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
)
if x_t is None:
if use_predictor:
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
if use_corrector:
model_t = self.model_fn(x_t, t)
if D1s is not None:
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = (model_t - model_prev_0)
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
return x_t, model_t
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
atol=0.0078, rtol=0.05, corrector=False, timesteps=None,
):
"""
Run the UniPC solver on `x` and return the final sample.
Args:
x: starting sample; its device is used for the time grid.
steps: number of solver steps; `timesteps` must hold steps + 1 entries.
t_start: starting time, defaults to `noise_schedule.T`.
t_end: ending time, defaults to 1 / `noise_schedule.total_N`.
order: maximum solver order.
skip_type: spacing passed to `get_time_steps` when `timesteps` is not given.
method: only 'multistep' is implemented; any other value (including the default
'singlestep') raises NotImplementedError.
lower_order_final: when True, the order is reduced to min(order, steps + 1 - step) near the end.
denoise_to_zero: when True, one extra data prediction at `t_end` is applied to the result.
timesteps: optional precomputed time grid that replaces `get_time_steps`.
`solver_type`, `atol`, `rtol` and `corrector` are accepted but not referenced in this body.
"""
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
if method == 'multistep':
if timesteps is None:
timesteps = get_time_steps(self.noise_schedule, skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
# NOTE(review): the condition allows order == steps while the message says "<" — confirm which is intended
assert steps >= order, "UniPC order must be < sampling steps"
assert timesteps.shape[0] - 1 == steps
with Progress(TextColumn('[cyan]{task.description}'), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), TimeElapsedColumn(), console=shared.console) as progress:
task = progress.add_task(description="Initializing", total=steps)
# wall-clock start used for the it/s figure in the progress description
t = time.time()
with devices.inference_context():
vec_t = timesteps[0].expand((x.shape[0]))
# histories of model outputs and times, newest last
model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t]
# Init the first `order` values by lower order multistep DPM-Solver.
for init_order in range(1, order):
vec_t = timesteps[init_order].expand(x.shape[0])
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
if model_x is None:
model_x = self.model_fn(x, vec_t)
if self.after_update is not None:
self.after_update(x, model_x)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
progress.update(task, advance=1, description=f"Progress {round(len(vec_t) * init_order / (time.time() - t), 2)}it/s")
# for step in trange(order, steps + 1):
for step in range(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
# the corrector is skipped on the last step
if step == steps:
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
if self.after_update is not None:
self.after_update(x, model_x)
# shift the histories one slot towards the front to make room for the newest entry
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
progress.update(task, advance=1, description=f"Progress {round(len(vec_t) * step / (time.time() - t), 2)}it/s")
else:
raise NotImplementedError
if denoise_to_zero:
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
return x
#############################################################
# other utility functions
#############################################################
def interpolate_fn(x, xp, yp):
    """
    Piecewise-linear interpolation y = f(x) through the keypoints (xp, yp).

    Built only from differentiable tensor operations (i.e. applicable for
    autograd). For x outside the range of xp, the outermost segment is
    extended linearly.

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels.
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]
    device = x.device

    # put each query in front of its keypoints and sort them together
    merged = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    merged_sorted, source_idx = torch.sort(merged, dim=2)
    # the query was at source index 0, so the argmin locates it in the sorted row
    rank = torch.argmin(source_idx, dim=2)
    below = rank - 1

    at_left_edge = torch.eq(rank, 0)
    at_right_edge = torch.eq(rank, num_keys)
    interior = torch.where(at_right_edge, torch.tensor(num_keys - 2, device=device), below)

    # bracketing keypoints, addressed inside the sorted (query + keypoints) row
    lo_sorted = torch.where(at_left_edge, torch.tensor(1, device=device), interior)
    hi_sorted = torch.where(torch.eq(lo_sorted, below), lo_sorted + 2, lo_sorted + 1)
    x_lo = torch.gather(merged_sorted, dim=2, index=lo_sorted.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(merged_sorted, dim=2, index=hi_sorted.unsqueeze(2)).squeeze(2)

    # the same bracket, addressed inside the keypoint table
    lo_key = torch.where(at_left_edge, torch.tensor(0, device=device), interior)
    yp_batched = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(yp_batched, dim=2, index=lo_key.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(yp_batched, dim=2, index=(lo_key + 1).unsqueeze(2)).squeeze(2)

    return y_lo + (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
def expand_dims(v, dims):
    """
    Append singleton axes to `v` until it has `dims` dimensions.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`, the total number of dimensions wanted.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    new_axes = (None,) * (dims - 1)
    index = (Ellipsis, *new_axes)
    return v[index]

View File

@ -2,40 +2,10 @@ import math
import gradio as gr
from modules import images, scripts_manager
from modules.processing import process_images
from modules.shared import opts, state, log
from modules.shared import opts, log
import modules.sd_samplers
def draw_xy_grid(xs, ys, x_label, y_label, cell):
    """Render one image per (x, y) pair and combine them into an annotated grid.

    `cell(x, y)` is called row by row and must return a pair whose first item
    exposes `.images`; only the first image of each result is kept. The
    processed object of the very first cell is reused as the return value:
    its `.images` becomes `[grid]` when `images.check_grid_size` accepts the
    collected images, otherwise the plain list of collected images.
    """
    collected = []
    row_annotations = [[images.GridAnnotation(y_label(y))] for y in ys]
    column_annotations = [[images.GridAnnotation(x_label(x))] for x in xs]
    result = None
    state.job_count = len(xs) * len(ys)

    for iy, y in enumerate(ys):
        for ix, x in enumerate(xs):
            state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
            processed, _t = cell(x, y)
            result = processed if result is None else result
            collected.append(processed.images[0])

    if not images.check_grid_size(collected):
        result.images = collected
        return result

    grid = images.image_grid(collected, rows=len(ys))
    grid = images.draw_grid_annotations(grid, collected[0].width, collected[0].height, column_annotations, row_annotations)
    result.images = [grid]
    return result
class Script(scripts_manager.Script):
def title(self):
return "Prompt matrix"

View File

@ -31,7 +31,7 @@ import modules.upscaler
import modules.upscaler_simple
import modules.extra_networks
import modules.ui_extra_networks
import modules.textual_inversion.textual_inversion
import modules.textual_inversion
import modules.script_callbacks
import modules.api.middleware