mirror of https://github.com/vladmandic/automatic
parent
748ec564ad
commit
a5b77b8ee2
|
|
@ -33,7 +33,6 @@ ignore-paths=/usr/lib/.*$,
|
|||
modules/taesd,
|
||||
modules/teacache,
|
||||
modules/todo,
|
||||
modules/unipc,
|
||||
pipelines/flex2,
|
||||
pipelines/hidream,
|
||||
pipelines/meissonic,
|
||||
|
|
|
|||
|
|
@ -159,13 +159,6 @@ def get_gpu_info():
|
|||
return { 'error': ex }
|
||||
|
||||
|
||||
def extract_device_id(args, name): # pylint: disable=redefined-outer-name
|
||||
for x in range(len(args)):
|
||||
if name in args[x]:
|
||||
return args[x + 1]
|
||||
return None
|
||||
|
||||
|
||||
def get_cuda_device_string():
|
||||
from modules.shared import cmd_opts
|
||||
if backend == 'ipex':
|
||||
|
|
@ -196,13 +189,6 @@ def get_optimal_device():
|
|||
return torch.device(get_optimal_device_name())
|
||||
|
||||
|
||||
def get_device_for(task): # pylint: disable=unused-argument
|
||||
# if task in cmd_opts.use_cpu:
|
||||
# log.debug(f'Forcing CPU for task: {task}')
|
||||
# return cpu
|
||||
return get_optimal_device()
|
||||
|
||||
|
||||
def torch_gc(force:bool=False, fast:bool=False, reason:str=None):
|
||||
def get_stats():
|
||||
mem_dict = memstats.memory_stats()
|
||||
|
|
@ -583,14 +569,6 @@ def set_cuda_params():
|
|||
log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upcast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} tunable={tunable} fp16={"pass" if fp16_ok else "fail"} bf16={"pass" if bf16_ok else "fail"} optimization="{opts.cross_attention_optimization}"')
|
||||
|
||||
|
||||
def cond_cast_unet(tensor):
|
||||
return tensor.to(dtype_unet) if unet_needs_upcast else tensor
|
||||
|
||||
|
||||
def cond_cast_float(tensor):
|
||||
return tensor.float() if unet_needs_upcast else tensor
|
||||
|
||||
|
||||
def randn(seed, shape=None):
|
||||
torch.manual_seed(seed)
|
||||
if backend == 'ipex':
|
||||
|
|
|
|||
|
|
@ -15,12 +15,6 @@ def install(suppress=[]):
|
|||
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s')
|
||||
|
||||
|
||||
def print_error_explanation(message):
|
||||
lines = message.strip().split("\n")
|
||||
for line in lines:
|
||||
log.error(line)
|
||||
|
||||
|
||||
def display(e: Exception, task: str, suppress=[]):
|
||||
log.error(f"{task or 'error'}: {type(e).__name__}")
|
||||
console = get_console()
|
||||
|
|
|
|||
|
|
@ -138,15 +138,6 @@ def should_skip(param):
|
|||
return skip
|
||||
|
||||
|
||||
def bind_buttons(buttons, image_component, send_generate_info):
|
||||
"""old function for backwards compatibility; do not use this, use register_paste_params_button"""
|
||||
for tabname, button in buttons.items():
|
||||
source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
|
||||
source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
|
||||
bindings = ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=image_component, source_tabname=source_tabname)
|
||||
register_paste_params_button(bindings)
|
||||
|
||||
|
||||
def register_paste_params_button(binding: ParamBinding):
|
||||
registered_param_bindings.append(binding)
|
||||
|
||||
|
|
@ -194,17 +185,6 @@ def send_image(x):
|
|||
return image
|
||||
|
||||
|
||||
def send_image_and_dimensions(x):
|
||||
image = x if isinstance(x, Image.Image) else image_from_url_text(x)
|
||||
if shared.opts.send_size and isinstance(image, Image.Image):
|
||||
w = image.width
|
||||
h = image.height
|
||||
else:
|
||||
w = gr.update()
|
||||
h = gr.update()
|
||||
return image, w, h
|
||||
|
||||
|
||||
def create_override_settings_dict(text_pairs):
|
||||
res = {}
|
||||
params = {}
|
||||
|
|
|
|||
|
|
@ -28,40 +28,6 @@ def unquote(text):
|
|||
return text
|
||||
|
||||
|
||||
# disabled by default can be enabled if needed
|
||||
def check_lora(params):
|
||||
try:
|
||||
from modules.lora import lora_load
|
||||
from modules.errors import log # pylint: disable=redefined-outer-name
|
||||
except Exception:
|
||||
return
|
||||
loras = [s.strip() for s in params.get('LoRA hashes', '').split(',')]
|
||||
found = []
|
||||
missing = []
|
||||
for l in loras:
|
||||
lora = lora_load.available_network_hash_lookup.get(l, None)
|
||||
if lora is not None:
|
||||
found.append(lora.name)
|
||||
else:
|
||||
missing.append(l)
|
||||
loras = [s.strip() for s in params.get('LoRA networks', '').split(',')]
|
||||
for l in loras:
|
||||
lora = lora_load.available_network_aliases.get(l, None)
|
||||
if lora is not None:
|
||||
found.append(lora.name)
|
||||
else:
|
||||
missing.append(l)
|
||||
# networks.available_network_aliases.get(name, None)
|
||||
loras = re_lora.findall(params.get('Prompt', ''))
|
||||
for l in loras:
|
||||
lora = lora_load.available_network_aliases.get(l, None)
|
||||
if lora is not None:
|
||||
found.append(lora.name)
|
||||
else:
|
||||
missing.append(l)
|
||||
log.debug(f'LoRA: found={list(set(found))} missing={list(set(missing))}')
|
||||
|
||||
|
||||
def parse(infotext):
|
||||
if not isinstance(infotext, str):
|
||||
return {}
|
||||
|
|
@ -115,7 +81,6 @@ def parse(infotext):
|
|||
params[key] = val
|
||||
debug(f'Param parsed: type={type(params[key])} "{key}"={params[key]} raw="{val}"')
|
||||
|
||||
# check_lora(params)
|
||||
return params
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -84,10 +84,6 @@ def reset_stats():
|
|||
pass
|
||||
|
||||
|
||||
def memory_cache():
|
||||
return mem
|
||||
|
||||
|
||||
def ram_stats():
|
||||
try:
|
||||
process = psutil.Process(os.getpid())
|
||||
|
|
|
|||
|
|
@ -160,11 +160,8 @@ def list_models():
|
|||
def update_model_hashes():
|
||||
txt = []
|
||||
lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.hash is None]
|
||||
# shared.log.info(f'Models list: short hash missing for {len(lst)} out of {len(checkpoints_list)} models')
|
||||
for ckpt in lst:
|
||||
ckpt.hash = model_hash(ckpt.filename)
|
||||
# txt.append(f'Calculated short hash: <b>{ckpt.title}</b> {ckpt.hash}')
|
||||
# txt.append(f'Updated short hashes for <b>{len(lst)}</b> out of <b>{len(checkpoints_list)}</b> models')
|
||||
lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.sha256 is None or ckpt.shorthash is None]
|
||||
shared.log.info(f'Models list: hash missing={len(lst)} total={len(checkpoints_list)}')
|
||||
for ckpt in lst:
|
||||
|
|
@ -214,15 +211,6 @@ def get_closet_checkpoint_match(s: str) -> CheckpointInfo:
|
|||
checkpoint_info = CheckpointInfo(s)
|
||||
return checkpoint_info
|
||||
|
||||
# reference search
|
||||
"""
|
||||
found = sorted([info for info in shared.reference_models.values() if os.path.basename(info['path']).lower().startswith(s.lower())], key=lambda x: len(x['path']))
|
||||
if found and len(found) == 1:
|
||||
checkpoint_info = CheckpointInfo(found[0]['path']) # create a virutal model info
|
||||
checkpoint_info.type = 'huggingface'
|
||||
return checkpoint_info
|
||||
"""
|
||||
|
||||
# huggingface search
|
||||
if shared.opts.sd_checkpoint_autodownload and s.count('/') == 1:
|
||||
modelloader.hf_login()
|
||||
|
|
@ -251,13 +239,10 @@ def model_hash(filename):
|
|||
try:
|
||||
with open(filename, "rb") as file:
|
||||
import hashlib
|
||||
# t0 = time.time()
|
||||
m = hashlib.sha256()
|
||||
file.seek(0x100000)
|
||||
m.update(file.read(0x10000))
|
||||
shorthash = m.hexdigest()[0:8]
|
||||
# t1 = time.time()
|
||||
# shared.log.debug(f'Calculating short hash: {filename} hash={shorthash} time={(t1-t0):.2f}')
|
||||
return shorthash
|
||||
except FileNotFoundError:
|
||||
return 'NOFILE'
|
||||
|
|
@ -280,14 +265,11 @@ def select_checkpoint(op='model'):
|
|||
shared.log.info(" or use --ckpt-dir <path-to-folder> to specify folder with sd models")
|
||||
shared.log.info(" or use --ckpt <path-to-checkpoint> to force using specific model")
|
||||
return None
|
||||
# checkpoint_info = next(iter(checkpoints_list.values()))
|
||||
if model_checkpoint is not None:
|
||||
if model_checkpoint != 'model.safetensors' and model_checkpoint != 'stabilityai/stable-diffusion-xl-base-1.0':
|
||||
shared.log.info(f'Load {op}: search="{model_checkpoint}" not found')
|
||||
else:
|
||||
shared.log.info("Selecting first available checkpoint")
|
||||
# shared.log.warning(f"Loading fallback checkpoint: {checkpoint_info.title}")
|
||||
# shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
|
||||
else:
|
||||
shared.log.info(f'Load {op}: select="{checkpoint_info.title if checkpoint_info is not None else None}"')
|
||||
return checkpoint_info
|
||||
|
|
@ -367,8 +349,6 @@ def read_metadata_from_safetensors(filename):
|
|||
t1 = time.time()
|
||||
global sd_metadata_timer # pylint: disable=global-statement
|
||||
sd_metadata_timer += (t1 - t0)
|
||||
# except Exception as e:
|
||||
# shared.log.error(f"Error reading metadata from: {filename} {e}")
|
||||
return res
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -620,7 +620,7 @@ def load_diffuser(checkpoint_info=None, timer=None, op='model', revision=None):
|
|||
if debug_load:
|
||||
shared.log.trace(f'Model components: {list(get_signature(sd_model).values())}')
|
||||
|
||||
from modules.textual_inversion import textual_inversion
|
||||
from modules import textual_inversion
|
||||
sd_model.embedding_db = textual_inversion.EmbeddingDatabase()
|
||||
sd_model.embedding_db.add_embedding_dir(shared.opts.embeddings_dir)
|
||||
sd_model.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
||||
|
|
|
|||
|
|
@ -319,9 +319,6 @@ class DiffusionSampler:
|
|||
return
|
||||
|
||||
self.sampler = sampler
|
||||
if name == 'DC Solver':
|
||||
if not hasattr(self.sampler, 'dc_ratios'):
|
||||
pass
|
||||
|
||||
# shared.log.debug_log(f'Sampler: class="{self.sampler.__class__.__name__}" config={self.sampler.config}')
|
||||
self.sampler.name = name
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
import glob
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
from modules import shared, errors, paths, devices, sd_models, sd_detect
|
||||
|
||||
|
|
@ -12,35 +11,13 @@ loaded_vae_file = None
|
|||
checkpoint_info = None
|
||||
vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE'))
|
||||
debug = os.environ.get('SD_LOAD_DEBUG', None) is not None
|
||||
unspecified = object()
|
||||
|
||||
|
||||
def get_base_vae(model):
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
||||
return base_vae
|
||||
return None
|
||||
|
||||
|
||||
def store_base_vae(model):
|
||||
global base_vae, checkpoint_info # pylint: disable=global-statement
|
||||
if checkpoint_info != model.sd_checkpoint_info:
|
||||
assert not loaded_vae_file, "Trying to store non-base VAE!"
|
||||
base_vae = deepcopy(model.first_stage_model.state_dict())
|
||||
checkpoint_info = model.sd_checkpoint_info
|
||||
|
||||
|
||||
def delete_base_vae():
|
||||
global base_vae, checkpoint_info # pylint: disable=global-statement
|
||||
base_vae = None
|
||||
checkpoint_info = None
|
||||
|
||||
|
||||
def restore_base_vae(model):
|
||||
global loaded_vae_file # pylint: disable=global-statement
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
|
||||
shared.log.info("Restoring base VAE")
|
||||
_load_vae_dict(model, base_vae)
|
||||
loaded_vae_file = None
|
||||
delete_base_vae()
|
||||
def load_vae_dict(filename):
|
||||
vae_ckpt = sd_models.read_state_dict(filename, what='vae')
|
||||
vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
return vae_dict_1
|
||||
|
||||
|
||||
def get_filename(filepath):
|
||||
|
|
@ -114,35 +91,6 @@ def resolve_vae(checkpoint_file):
|
|||
return None, None
|
||||
|
||||
|
||||
def load_vae_dict(filename):
|
||||
vae_ckpt = sd_models.read_state_dict(filename, what='vae')
|
||||
vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
return vae_dict_1
|
||||
|
||||
|
||||
def load_vae(model, vae_file=None, vae_source="unknown-source"):
|
||||
global loaded_vae_file # pylint: disable=global-statement
|
||||
if vae_file:
|
||||
try:
|
||||
if not os.path.isfile(vae_file):
|
||||
shared.log.error(f"VAE not found: model={vae_file} source={vae_source}")
|
||||
return
|
||||
store_base_vae(model)
|
||||
vae_dict_1 = load_vae_dict(vae_file)
|
||||
_load_vae_dict(model, vae_dict_1)
|
||||
except Exception as e:
|
||||
shared.log.error(f"Load VAE failed: model={vae_file} source={vae_source} {e}")
|
||||
if debug:
|
||||
errors.display(e, 'VAE')
|
||||
restore_base_vae(model)
|
||||
vae_opt = get_filename(vae_file)
|
||||
if vae_opt not in vae_dict:
|
||||
vae_dict[vae_opt] = vae_file
|
||||
elif loaded_vae_file:
|
||||
restore_base_vae(model)
|
||||
loaded_vae_file = vae_file
|
||||
|
||||
|
||||
def apply_vae_config(model_file, vae_file, sd_model):
|
||||
def get_vae_config():
|
||||
config_file = os.path.join(paths.sd_configs_path, os.path.splitext(os.path.basename(model_file))[0] + '_vae.json')
|
||||
|
|
@ -219,20 +167,6 @@ def load_vae_diffusers(model_file, vae_file=None, vae_source="unknown-source"):
|
|||
return None
|
||||
|
||||
|
||||
# don't call this from outside
|
||||
def _load_vae_dict(model, vae_dict_1):
|
||||
model.first_stage_model.load_state_dict(vae_dict_1)
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
|
||||
def clear_loaded_vae():
|
||||
global loaded_vae_file # pylint: disable=global-statement
|
||||
loaded_vae_file = None
|
||||
|
||||
|
||||
unspecified = object()
|
||||
|
||||
|
||||
def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
|
|
|||
|
|
@ -3,9 +3,7 @@ import os
|
|||
import time
|
||||
import torch
|
||||
import safetensors.torch
|
||||
from PIL import Image
|
||||
from modules import shared, devices, errors
|
||||
from modules.textual_inversion.image_embedding import embedding_from_b64, extract_image_data_embed
|
||||
from modules.files_cache import directory_files, directory_mtime, extension_filter
|
||||
|
||||
|
||||
|
|
@ -241,9 +239,6 @@ class EmbeddingDatabase:
|
|||
def add_embedding_dir(self, path):
|
||||
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
|
||||
|
||||
def clear_embedding_dirs(self):
|
||||
self.embedding_dirs.clear()
|
||||
|
||||
def register_embedding(self, embedding, model):
|
||||
self.word_embeddings[embedding.name] = embedding
|
||||
if hasattr(model, 'cond_stage_model'):
|
||||
|
|
@ -306,45 +301,6 @@ class EmbeddingDatabase:
|
|||
errors.display(e, f'Load embedding: name="{embedding.name}" file="{embedding.filename}"')
|
||||
return
|
||||
|
||||
def load_from_file(self, path, filename):
|
||||
name, ext = os.path.splitext(filename)
|
||||
ext = ext.upper()
|
||||
|
||||
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||
if '.preview' in filename.lower():
|
||||
return
|
||||
embed_image = Image.open(path)
|
||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||
else:
|
||||
data = extract_image_data_embed(embed_image)
|
||||
if not data: # if data is None, means this is not an embeding, just a preview image
|
||||
return
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
elif ext in ['.SAFETENSORS']:
|
||||
data = safetensors.torch.load_file(path, device="cpu")
|
||||
else:
|
||||
return
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||
if len(data.keys()) != 1:
|
||||
self.skipped_embeddings[name] = Embedding(None, name=name, filename=path)
|
||||
return
|
||||
emb = next(iter(data.values()))
|
||||
if len(emb.shape) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
else:
|
||||
raise RuntimeError(f"Couldn't identify {filename} as textual inversion embedding")
|
||||
|
||||
|
||||
def load_from_dir(self, embdir):
|
||||
if not shared.sd_loaded:
|
||||
shared.log.info('Skipping embeddings load: model not loaded')
|
||||
|
|
@ -387,14 +343,3 @@ class EmbeddingDatabase:
|
|||
self.previously_displayed_embeddings = displayed_embeddings
|
||||
t1 = time.time()
|
||||
shared.log.info(f"Network load: type=embeddings loaded={len(self.word_embeddings)} skipped={len(self.skipped_embeddings)} time={t1-t0:.2f}")
|
||||
|
||||
|
||||
def find_embedding_at_position(self, tokens, offset):
|
||||
token = tokens[offset]
|
||||
possible_matches = self.ids_lookup.get(token, None)
|
||||
if possible_matches is None:
|
||||
return None, None
|
||||
for ids, embedding in possible_matches:
|
||||
if tokens[offset:offset + len(ids)] == ids:
|
||||
return embedding, len(ids)
|
||||
return None, None
|
||||
|
|
@ -1,213 +0,0 @@
|
|||
import base64
|
||||
import json
|
||||
import zlib
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import torch
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
class EmbeddingEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, torch.Tensor):
|
||||
return {'TORCHTENSOR': o.cpu().detach().numpy().tolist()}
|
||||
return json.JSONEncoder.default(self, o)
|
||||
|
||||
|
||||
class EmbeddingDecoder(json.JSONDecoder):
|
||||
def __init__(self, *args, **kwargs):
|
||||
json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
|
||||
|
||||
def object_hook(self, d): # pylint: disable=E0202
|
||||
if 'TORCHTENSOR' in d:
|
||||
return torch.from_numpy(np.array(d['TORCHTENSOR']))
|
||||
return d
|
||||
|
||||
|
||||
def embedding_to_b64(data):
|
||||
d = json.dumps(data, cls=EmbeddingEncoder)
|
||||
return base64.b64encode(d.encode())
|
||||
|
||||
|
||||
def embedding_from_b64(data):
|
||||
d = base64.b64decode(data)
|
||||
return json.loads(d, cls=EmbeddingDecoder)
|
||||
|
||||
|
||||
def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
|
||||
while True:
|
||||
seed = (a * seed + c) % m
|
||||
yield seed % 255
|
||||
|
||||
|
||||
def xor_block(block):
|
||||
blk = lcg()
|
||||
randblock = np.array([next(blk) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)
|
||||
return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
|
||||
|
||||
|
||||
def style_block(block, sequence):
|
||||
im = Image.new('RGB', (block.shape[1], block.shape[0]))
|
||||
draw = ImageDraw.Draw(im)
|
||||
i = 0
|
||||
for x in range(-6, im.size[0], 8):
|
||||
for yi, y in enumerate(range(-6, im.size[1], 8)):
|
||||
offset = 0
|
||||
if yi % 2 == 0:
|
||||
offset = 4
|
||||
shade = sequence[i % len(sequence)]
|
||||
i += 1
|
||||
draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
|
||||
|
||||
fg = np.array(im).astype(np.uint8) & 0xF0
|
||||
|
||||
return block ^ fg
|
||||
|
||||
|
||||
def insert_image_data_embed(image, data):
|
||||
d = 3
|
||||
data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
|
||||
data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
|
||||
data_np_high = data_np_ >> 4
|
||||
data_np_low = data_np_ & 0x0F
|
||||
|
||||
h = image.size[1]
|
||||
next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
|
||||
next_size = next_size + ((h*d)-(next_size % (h*d)))
|
||||
|
||||
data_np_low = np.resize(data_np_low, next_size)
|
||||
data_np_low = data_np_low.reshape((h, -1, d))
|
||||
|
||||
data_np_high = np.resize(data_np_high, next_size)
|
||||
data_np_high = data_np_high.reshape((h, -1, d))
|
||||
|
||||
edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
|
||||
edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
|
||||
|
||||
data_np_low = style_block(data_np_low, sequence=edge_style)
|
||||
data_np_low = xor_block(data_np_low)
|
||||
data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
|
||||
data_np_high = xor_block(data_np_high)
|
||||
|
||||
im_low = Image.fromarray(data_np_low, mode='RGB')
|
||||
im_high = Image.fromarray(data_np_high, mode='RGB')
|
||||
|
||||
background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
|
||||
background.paste(im_low, (0, 0))
|
||||
background.paste(image, (im_low.size[0]+1, 0))
|
||||
background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
|
||||
|
||||
return background
|
||||
|
||||
|
||||
def crop_black(img, tol=0):
|
||||
mask = (img > tol).all(2)
|
||||
mask0, mask1 = mask.any(0), mask.any(1)
|
||||
col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
|
||||
row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
|
||||
return img[row_start:row_end, col_start:col_end]
|
||||
|
||||
|
||||
def extract_image_data_embed(image):
|
||||
d = 3
|
||||
outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F # pylint: disable=E1121
|
||||
black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
|
||||
if black_cols[0].shape[0] < 2:
|
||||
return None
|
||||
|
||||
data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
|
||||
data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
|
||||
|
||||
data_block_lower = xor_block(data_block_lower)
|
||||
data_block_upper = xor_block(data_block_upper)
|
||||
|
||||
data_block = (data_block_upper << 4) | (data_block_lower)
|
||||
data_block = data_block.flatten().tobytes()
|
||||
|
||||
data = zlib.decompress(data_block)
|
||||
return json.loads(data, cls=EmbeddingDecoder)
|
||||
|
||||
|
||||
def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
|
||||
from math import cos
|
||||
image = srcimage.copy()
|
||||
fontsize = 32
|
||||
if textfont is None:
|
||||
textfont = opts.font or 'javascript/notosans-nerdfont-regular.ttf'
|
||||
|
||||
factor = 1.5
|
||||
gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
|
||||
for y in range(image.size[1]):
|
||||
mag = 1-cos(y/image.size[1]*factor)
|
||||
mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
|
||||
gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
|
||||
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
font = ImageFont.truetype(textfont, fontsize)
|
||||
padding = 10
|
||||
|
||||
_, _, w, _h = draw.textbbox((0, 0), title, font=font)
|
||||
fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
|
||||
font = ImageFont.truetype(textfont, fontsize)
|
||||
_, _, w, _h = draw.textbbox((0, 0), title, font=font)
|
||||
draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
|
||||
|
||||
_, _, w, _h = draw.textbbox((0, 0), footerLeft, font=font)
|
||||
fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
|
||||
_, _, w, _h = draw.textbbox((0, 0), footerMid, font=font)
|
||||
fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
|
||||
_, _, w, _h = draw.textbbox((0, 0), footerRight, font=font)
|
||||
fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
|
||||
|
||||
font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
|
||||
|
||||
draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
|
||||
draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
|
||||
draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
testEmbed = Image.open('test_embedding.png')
|
||||
test_data = extract_image_data_embed(testEmbed)
|
||||
assert test_data is not None
|
||||
|
||||
test_data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
|
||||
assert test_data is not None
|
||||
|
||||
new_image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
|
||||
cap_image = caption_image_overlay(new_image, 'title', 'footerLeft', 'footerMid', 'footerRight')
|
||||
|
||||
test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
|
||||
|
||||
embedded_image = insert_image_data_embed(cap_image, test_embed)
|
||||
|
||||
retrived_embed = extract_image_data_embed(embedded_image)
|
||||
|
||||
assert str(retrived_embed) == str(test_embed)
|
||||
|
||||
embedded_image2 = insert_image_data_embed(cap_image, retrived_embed)
|
||||
|
||||
assert embedded_image == embedded_image2
|
||||
|
||||
g = lcg()
|
||||
shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
|
||||
|
||||
reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
|
||||
95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
|
||||
160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
|
||||
38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
|
||||
30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
|
||||
41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
|
||||
66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
|
||||
204, 86, 73, 222, 44, 198, 118, 240, 97]
|
||||
|
||||
assert shared_random == reference_random
|
||||
|
||||
hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
|
||||
|
||||
assert 12731374 == hunna_kay_random_sum
|
||||
|
|
@ -67,10 +67,6 @@ def setup_progressbar(*args, **kwargs): # pylint: disable=unused-argument
|
|||
pass
|
||||
|
||||
|
||||
def ordered_ui_categories():
|
||||
return ['dimensions', 'sampler', 'seed', 'denoising', 'cfg', 'checkboxes', 'accordions', 'override_settings', 'scripts'] # a1111 compatibility item, not implemented
|
||||
|
||||
|
||||
def create_ui(startup_timer = None):
|
||||
if startup_timer is None:
|
||||
timer.startup = timer.Timer()
|
||||
|
|
|
|||
|
|
@ -323,19 +323,6 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args = No
|
|||
return refresh_button
|
||||
|
||||
|
||||
def create_browse_button(browse_component, elem_id):
|
||||
def browse(folder):
|
||||
# import subprocess
|
||||
if folder is not None:
|
||||
return gr.update(value = folder)
|
||||
return gr.update()
|
||||
|
||||
browse_button = ui_components.ToolButton(value=ui_symbols.folder, elem_id=elem_id)
|
||||
browse_button.click(fn=browse, _js="async () => await browseFolder()", inputs=[browse_component], outputs=[browse_component])
|
||||
# browse_button.click(fn=browse, inputs=[browse_component], outputs=[browse_component])
|
||||
return browse_button
|
||||
|
||||
|
||||
def create_override_inputs(tab): # pylint: disable=unused-argument
|
||||
with gr.Row(elem_id=f"{tab}_override_settings_row"):
|
||||
override_settings = gr.Dropdown([], value=None, label="Override settings", visible=False, elem_id=f"{tab}_override_settings", multiselect=True)
|
||||
|
|
|
|||
|
|
@ -69,10 +69,6 @@ def list_extensions():
|
|||
debug(f'Extension installed without index: {entry}')
|
||||
|
||||
|
||||
def check_access():
|
||||
assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
|
||||
|
||||
|
||||
def apply_changes(disable_list, update_list, disable_all):
|
||||
if shared.cmd_opts.disable_extension_access:
|
||||
shared.log.error('Extension: apply changes disallowed because public access is enabled and insecure is not specified')
|
||||
|
|
@ -126,18 +122,6 @@ def check_updates(_id_task, disable_list, search_text, sort_column):
|
|||
return create_html(search_text, sort_column), "Extension update complete | Restart required"
|
||||
|
||||
|
||||
def make_commit_link(commit_hash, remote, text=None):
|
||||
if text is None:
|
||||
text = commit_hash[:8]
|
||||
if remote.startswith("https://github.com/"):
|
||||
if remote.endswith(".git"):
|
||||
remote = remote[:-4]
|
||||
href = remote + "/commit/" + commit_hash
|
||||
return f'<a href="{href}" target="_blank">{text}</a>'
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
def normalize_git_url(url):
|
||||
if url is None:
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -162,9 +162,6 @@ class ExtraNetworksPage:
|
|||
preview = f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
|
||||
return preview
|
||||
|
||||
def is_empty(self, folder):
|
||||
return any(files_cache.list_files(folder, ext_filter=['.ckpt', '.safetensors', '.pt', '.json']))
|
||||
|
||||
def create_thumb(self):
|
||||
debug(f'EN create-thumb: {self.name}')
|
||||
created = 0
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import os
|
||||
from modules import shared, sd_models, ui_extra_networks, files_cache
|
||||
from modules.textual_inversion.textual_inversion import Embedding
|
||||
from modules.textual_inversion import Embedding
|
||||
|
||||
|
||||
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||
|
|
|
|||
|
|
@ -115,20 +115,5 @@ def reload_javascript():
|
|||
gradio.routes.templates.TemplateResponse = template_response
|
||||
|
||||
|
||||
def setup_ui_api(app):
|
||||
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
|
||||
from typing import List
|
||||
|
||||
class QuicksettingsHint(BaseModel): # pylint: disable=too-few-public-methods
|
||||
name: str = Field(title="Name of the quicksettings field")
|
||||
label: str = Field(title="Label of the quicksettings field")
|
||||
|
||||
def quicksettings_hint():
|
||||
return [QuicksettingsHint(name=k, label=v.label) for k, v in shared.opts.data_labels.items()]
|
||||
|
||||
app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
|
||||
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
|
||||
|
||||
|
||||
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
|
||||
shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
|
||||
|
|
|
|||
|
|
@ -99,20 +99,6 @@ def create_interrogate_button(tab: str, inputs: list = None, outputs: str = None
|
|||
return button_interrogate
|
||||
|
||||
|
||||
def create_interrogate_buttons(tab): # legacy function
|
||||
button_interrogate = gr.Button(ui_symbols.int_clip, elem_id=f"{tab}_interrogate", elem_classes=['interrogate-clip'])
|
||||
button_deepbooru = gr.Button(ui_symbols.int_blip, elem_id=f"{tab}_deepbooru", elem_classes=['interrogate-blip'])
|
||||
return button_interrogate, button_deepbooru
|
||||
|
||||
|
||||
def create_sampler_inputs(tab, accordion=True):
|
||||
with gr.Accordion(open=False, label="Sampler", elem_id=f"{tab}_sampler", elem_classes=["small-accordion"]) if accordion else gr.Group():
|
||||
with gr.Row(elem_id=f"{tab}_row_sampler"):
|
||||
sd_samplers.set_samplers()
|
||||
steps, sampler_index = create_sampler_and_steps_selection(sd_samplers.samplers, tab)
|
||||
return steps, sampler_index
|
||||
|
||||
|
||||
def create_batch_inputs(tab, accordion=True):
|
||||
with gr.Accordion(open=False, label="Batch", elem_id=f"{tab}_batch", elem_classes=["small-accordion"]) if accordion else gr.Group():
|
||||
with gr.Row(elem_id=f"{tab}_row_batch"):
|
||||
|
|
|
|||
|
|
@ -48,10 +48,6 @@ def get_value_for_setting(key):
|
|||
return gr.update(value=value, **args)
|
||||
|
||||
|
||||
def ordered_ui_categories():
|
||||
return ['dimensions', 'sampler', 'seed', 'denoising', 'cfg', 'checkboxes', 'accordions', 'override_settings', 'scripts'] # a1111 compatibility item, not implemented
|
||||
|
||||
|
||||
def create_setting_component(key, is_quicksettings=False):
|
||||
def fun():
|
||||
return shared.opts.data[key] if key in shared.opts.data else shared.opts.data_labels[key].default
|
||||
|
|
@ -88,7 +84,6 @@ def create_setting_component(key, is_quicksettings=False):
|
|||
elif info.folder is not None:
|
||||
with gr.Row():
|
||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, elem_classes="folder-selector", **args)
|
||||
# ui_common.create_browse_button(res, f"folder_{key}")
|
||||
else:
|
||||
try:
|
||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **args)
|
||||
|
|
|
|||
|
|
@ -1,19 +1,9 @@
|
|||
import gradio as gr
|
||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call
|
||||
from modules import timer, shared, ui_common, ui_sections, generation_parameters_copypaste, processing, processing_vae, devices, images
|
||||
from modules import timer, shared, ui_common, ui_sections, generation_parameters_copypaste, processing_vae, images
|
||||
from modules.ui_components import ToolButton # pylint: disable=unused-import
|
||||
|
||||
|
||||
def calc_resolution_hires(width, height, hr_scale, hr_resize_x, hr_resize_y, hr_upscaler):
|
||||
if hr_upscaler == "None":
|
||||
return "Hires resize: None"
|
||||
p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
|
||||
p.init_hr()
|
||||
with devices.autocast():
|
||||
p.init([""], [0], [0])
|
||||
return f"Hires resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
|
||||
|
||||
|
||||
def create_ui():
|
||||
shared.log.debug('UI initialize: txt2img')
|
||||
import modules.txt2img # pylint: disable=redefined-outer-name
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
from .sampler import UniPCSampler
|
||||
|
|
@ -1,191 +0,0 @@
|
|||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
|
||||
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC, get_time_steps
|
||||
from modules import shared, devices
|
||||
|
||||
|
||||
class UniPCSampler(object):
|
||||
def __init__(self, model, **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
||||
self.before_sample = None
|
||||
self.after_sample = None
|
||||
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
||||
|
||||
self.noise_schedule = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
# persist steps so we can eventually find denoising strength
|
||||
self.inflated_steps = ddim_num_steps
|
||||
|
||||
@devices.inference_context()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
|
||||
# first time we have all the info to get the real parameters from the ui
|
||||
# value from the hires steps slider:
|
||||
num_inference_steps = t[0] + 1
|
||||
num_inference_steps / self.inflated_steps
|
||||
self.denoise_steps = max(num_inference_steps, shared.opts.schedulers_solver_order)
|
||||
|
||||
max(self.inflated_steps - self.denoise_steps, 0)
|
||||
|
||||
# actual number of steps we'll run
|
||||
|
||||
all_timesteps = get_time_steps(
|
||||
self.noise_schedule,
|
||||
shared.opts.uni_pc_skip_type,
|
||||
self.noise_schedule.T,
|
||||
1./self.noise_schedule.total_N,
|
||||
self.inflated_steps+1,
|
||||
t.device,
|
||||
)
|
||||
|
||||
# the rest of the timesteps will be used for denoising
|
||||
self.timesteps = all_timesteps[-(self.denoise_steps+1):]
|
||||
|
||||
latent_timestep = (
|
||||
( # get the timestep of our first denoise step
|
||||
self.timesteps[:1]
|
||||
# multiply by number of alphas to get int index
|
||||
* self.noise_schedule.total_N
|
||||
).int() - 1 # minus one for 0-indexed
|
||||
).repeat(x0.shape[0])
|
||||
|
||||
alphas_cumprod = self.alphas_cumprod
|
||||
sqrt_alpha_prod = alphas_cumprod[latent_timestep] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(x0.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[latent_timestep]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(x0.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
return (sqrt_alpha_prod * x0 + sqrt_one_minus_alpha_prod * noise)
|
||||
|
||||
def decode(self, x_latent, conditioning, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
use_original_steps=False, callback=None):
|
||||
# same as in .sample(), i guess
|
||||
model_type = "v" if self.model.parameterization == "v" else "noise"
|
||||
|
||||
model_fn = model_wrapper(
|
||||
lambda x, t, c: self.model.apply_model(x, t, c),
|
||||
self.noise_schedule,
|
||||
model_type=model_type,
|
||||
guidance_type="classifier-free",
|
||||
#condition=conditioning,
|
||||
#unconditional_condition=unconditional_conditioning,
|
||||
guidance_scale=unconditional_guidance_scale,
|
||||
)
|
||||
|
||||
self.uni_pc = UniPC(
|
||||
model_fn,
|
||||
self.noise_schedule,
|
||||
predict_x0=True,
|
||||
thresholding=False,
|
||||
variant=shared.opts.uni_pc_variant,
|
||||
condition=conditioning,
|
||||
unconditional_condition=unconditional_conditioning,
|
||||
before_sample=self.before_sample,
|
||||
after_sample=self.after_sample,
|
||||
after_update=self.after_update,
|
||||
)
|
||||
|
||||
return self.uni_pc.sample(
|
||||
x_latent,
|
||||
steps=self.denoise_steps,
|
||||
skip_type=shared.opts.uni_pc_skip_type,
|
||||
method="multistep",
|
||||
order=shared.opts.schedulers_solver_order,
|
||||
lower_order_final=shared.opts.schedulers_use_loworder,
|
||||
denoise_to_zero=True,
|
||||
timesteps=self.timesteps,
|
||||
)
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != devices.device:
|
||||
attr = attr.to(devices.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def set_hooks(self, before_sample, after_sample, after_update):
|
||||
self.before_sample = before_sample
|
||||
self.after_sample = after_sample
|
||||
self.after_update = after_update
|
||||
|
||||
@devices.inference_context()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
shared.log.warning(f"UniPC: got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
shared.log.warning(f"UniPC: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
shared.log.warning(f"UniPC: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
|
||||
device = self.model.betas.device
|
||||
if x_T is None:
|
||||
img = torch.randn(size, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
# SD 1.X is "noise", SD 2.X is "v"
|
||||
model_type = "v" if self.model.parameterization == "v" else "noise"
|
||||
|
||||
model_fn = model_wrapper(
|
||||
lambda x, t, c: self.model.apply_model(x, t, c),
|
||||
self.noise_schedule,
|
||||
model_type=model_type,
|
||||
guidance_type="classifier-free",
|
||||
#condition=conditioning,
|
||||
#unconditional_condition=unconditional_conditioning,
|
||||
guidance_scale=unconditional_guidance_scale,
|
||||
)
|
||||
|
||||
uni_pc = UniPC(model_fn, self.noise_schedule, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
|
||||
x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
|
||||
|
||||
return x.to(device), None
|
||||
|
|
@ -1,779 +0,0 @@
|
|||
import torch
|
||||
import math
|
||||
import time
|
||||
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn
|
||||
from modules import shared, devices
|
||||
|
||||
|
||||
class NoiseScheduleVP:
|
||||
def __init__(
|
||||
self,
|
||||
schedule='discrete',
|
||||
betas=None,
|
||||
alphas_cumprod=None,
|
||||
continuous_beta_0=0.1,
|
||||
continuous_beta_1=20.,
|
||||
):
|
||||
if schedule not in ['discrete', 'linear', 'cosine']:
|
||||
raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
|
||||
|
||||
self.schedule = schedule
|
||||
if schedule == 'discrete':
|
||||
if betas is not None:
|
||||
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
|
||||
else:
|
||||
assert alphas_cumprod is not None
|
||||
log_alphas = 0.5 * torch.log(alphas_cumprod)
|
||||
self.total_N = len(log_alphas)
|
||||
self.T = 1.
|
||||
self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
|
||||
self.log_alpha_array = log_alphas.reshape((1, -1,))
|
||||
else:
|
||||
self.total_N = 1000
|
||||
self.beta_0 = continuous_beta_0
|
||||
self.beta_1 = continuous_beta_1
|
||||
self.cosine_s = 0.008
|
||||
self.cosine_beta_max = 999.
|
||||
self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
||||
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
|
||||
self.schedule = schedule
|
||||
if schedule == 'cosine':
|
||||
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
|
||||
# Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
|
||||
self.T = 0.9946
|
||||
else:
|
||||
self.T = 1.
|
||||
|
||||
def marginal_log_mean_coeff(self, t):
|
||||
"""
|
||||
Compute log(alpha_t) of a given continuous-time label t in [0, T].
|
||||
"""
|
||||
if self.schedule == 'discrete':
|
||||
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
|
||||
elif self.schedule == 'linear':
|
||||
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
||||
elif self.schedule == 'cosine':
|
||||
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
|
||||
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
|
||||
return log_alpha_t
|
||||
|
||||
def marginal_alpha(self, t):
|
||||
"""
|
||||
Compute alpha_t of a given continuous-time label t in [0, T].
|
||||
"""
|
||||
return torch.exp(self.marginal_log_mean_coeff(t))
|
||||
|
||||
def marginal_std(self, t):
|
||||
"""
|
||||
Compute sigma_t of a given continuous-time label t in [0, T].
|
||||
"""
|
||||
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
|
||||
|
||||
def marginal_lambda(self, t):
|
||||
"""
|
||||
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
||||
"""
|
||||
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
||||
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
||||
return log_mean_coeff - log_std
|
||||
|
||||
def inverse_lambda(self, lamb):
|
||||
"""
|
||||
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
|
||||
"""
|
||||
if self.schedule == 'linear':
|
||||
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
||||
Delta = self.beta_0**2 + tmp
|
||||
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
||||
elif self.schedule == 'discrete':
|
||||
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
|
||||
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
|
||||
return t.reshape((-1,))
|
||||
else:
|
||||
log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
||||
t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
||||
t = t_fn(log_alpha)
|
||||
return t
|
||||
|
||||
|
||||
def model_wrapper(
|
||||
model,
|
||||
noise_schedule,
|
||||
model_type="noise",
|
||||
model_kwargs=None,
|
||||
guidance_type="uncond",
|
||||
#condition=None,
|
||||
#unconditional_condition=None,
|
||||
guidance_scale=1.,
|
||||
classifier_fn=None,
|
||||
classifier_kwargs=None,
|
||||
):
|
||||
"""Create a wrapper function for the noise prediction model.
|
||||
|
||||
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
|
||||
firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
|
||||
|
||||
We support four types of the diffusion model by setting `model_type`:
|
||||
|
||||
1. "noise": noise prediction model. (Trained by predicting noise).
|
||||
|
||||
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
|
||||
|
||||
3. "v": velocity prediction model. (Trained by predicting the velocity).
|
||||
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
|
||||
|
||||
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
|
||||
arXiv preprint arXiv:2202.00512 (2022).
|
||||
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
||||
arXiv preprint arXiv:2210.02303 (2022).
|
||||
|
||||
4. "score": marginal score function. (Trained by denoising score matching).
|
||||
Note that the score function and the noise prediction model follows a simple relationship:
|
||||
```
|
||||
noise(x_t, t) = -sigma_t * score(x_t, t)
|
||||
```
|
||||
|
||||
We support three types of guided sampling by DPMs by setting `guidance_type`:
|
||||
1. "uncond": unconditional sampling by DPMs.
|
||||
The input `model` has the following format:
|
||||
``
|
||||
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
||||
``
|
||||
|
||||
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
|
||||
The input `model` has the following format:
|
||||
``
|
||||
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
||||
``
|
||||
|
||||
The input `classifier_fn` has the following format:
|
||||
``
|
||||
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
|
||||
``
|
||||
|
||||
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
|
||||
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
|
||||
|
||||
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
|
||||
The input `model` has the following format:
|
||||
``
|
||||
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
||||
``
|
||||
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
||||
|
||||
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
||||
arXiv preprint arXiv:2207.12598 (2022).
|
||||
|
||||
|
||||
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
||||
or continuous-time labels (i.e. epsilon to T).
|
||||
|
||||
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
|
||||
``
|
||||
def model_fn(x, t_continuous) -> noise:
|
||||
t_input = get_model_input_time(t_continuous)
|
||||
return noise_pred(model, x, t_input, **model_kwargs)
|
||||
``
|
||||
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
||||
|
||||
===============================================================
|
||||
|
||||
Args:
|
||||
model: A diffusion model with the corresponding format described above.
|
||||
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
||||
model_type: A `str`. The parameterization type of the diffusion model.
|
||||
"noise" or "x_start" or "v" or "score".
|
||||
model_kwargs: A `dict`. A dict for the other inputs of the model function.
|
||||
guidance_type: A `str`. The type of the guidance for sampling.
|
||||
"uncond" or "classifier" or "classifier-free".
|
||||
condition: A pytorch tensor. The condition for the guided sampling.
|
||||
Only used for "classifier" or "classifier-free" guidance type.
|
||||
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
|
||||
Only used for "classifier-free" guidance type.
|
||||
guidance_scale: A `float`. The scale for the guided sampling.
|
||||
classifier_fn: A classifier function. Only used for the classifier guidance.
|
||||
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
|
||||
Returns:
|
||||
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
||||
"""
|
||||
model_kwargs = model_kwargs or {}
|
||||
classifier_kwargs = classifier_kwargs or {}
|
||||
|
||||
def get_model_input_time(t_continuous):
|
||||
"""
|
||||
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
||||
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
|
||||
For continuous-time DPMs, we just use `t_continuous`.
|
||||
"""
|
||||
if noise_schedule.schedule == 'discrete':
|
||||
return (t_continuous - 1. / noise_schedule.total_N) * 1000.
|
||||
else:
|
||||
return t_continuous
|
||||
|
||||
def noise_pred_fn(x, t_continuous, cond=None):
|
||||
if t_continuous.reshape((-1,)).shape[0] == 1:
|
||||
t_continuous = t_continuous.expand((x.shape[0]))
|
||||
t_input = get_model_input_time(t_continuous)
|
||||
if cond is None:
|
||||
output = model(x, t_input, None, **model_kwargs)
|
||||
else:
|
||||
output = model(x, t_input, cond, **model_kwargs)
|
||||
if model_type == "noise":
|
||||
return output
|
||||
elif model_type == "x_start":
|
||||
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
||||
dims = x.dim()
|
||||
return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
|
||||
elif model_type == "v":
|
||||
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
||||
dims = x.dim()
|
||||
return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
|
||||
elif model_type == "score":
|
||||
sigma_t = noise_schedule.marginal_std(t_continuous)
|
||||
dims = x.dim()
|
||||
return -expand_dims(sigma_t, dims) * output
|
||||
|
||||
def cond_grad_fn(x, t_input, condition):
|
||||
"""
|
||||
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
|
||||
"""
|
||||
with torch.enable_grad():
|
||||
x_in = x.detach().requires_grad_(True)
|
||||
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
|
||||
return torch.autograd.grad(log_prob.sum(), x_in)[0]
|
||||
|
||||
def model_fn(x, t_continuous, condition, unconditional_condition):
|
||||
"""
|
||||
The noise predicition model function that is used for DPM-Solver.
|
||||
"""
|
||||
if t_continuous.reshape((-1,)).shape[0] == 1:
|
||||
t_continuous = t_continuous.expand((x.shape[0]))
|
||||
if guidance_type == "uncond":
|
||||
return noise_pred_fn(x, t_continuous)
|
||||
elif guidance_type == "classifier":
|
||||
assert classifier_fn is not None
|
||||
t_input = get_model_input_time(t_continuous)
|
||||
cond_grad = cond_grad_fn(x, t_input, condition)
|
||||
sigma_t = noise_schedule.marginal_std(t_continuous)
|
||||
noise = noise_pred_fn(x, t_continuous)
|
||||
return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
|
||||
elif guidance_type == "classifier-free":
|
||||
if guidance_scale == 1. or unconditional_condition is None:
|
||||
return noise_pred_fn(x, t_continuous, cond=condition)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t_continuous] * 2)
|
||||
if isinstance(condition, dict):
|
||||
assert isinstance(unconditional_condition, dict)
|
||||
c_in = {}
|
||||
for k in condition:
|
||||
if isinstance(condition[k], list):
|
||||
c_in[k] = [torch.cat([
|
||||
unconditional_condition[k][i],
|
||||
condition[k][i]]) for i in range(len(condition[k]))]
|
||||
else:
|
||||
c_in[k] = torch.cat([
|
||||
unconditional_condition[k],
|
||||
condition[k]])
|
||||
elif isinstance(condition, list):
|
||||
c_in = []
|
||||
assert isinstance(unconditional_condition, list)
|
||||
for i in range(len(condition)):
|
||||
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
|
||||
else:
|
||||
c_in = torch.cat([unconditional_condition, condition])
|
||||
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
||||
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
||||
|
||||
assert model_type in ["noise", "x_start", "v"]
|
||||
assert guidance_type in ["uncond", "classifier", "classifier-free"]
|
||||
return model_fn
|
||||
|
||||
def get_time_steps(noise_schedule, skip_type, t_T, t_0, N, device):
|
||||
"""Compute the intermediate time steps for sampling.
|
||||
"""
|
||||
if skip_type == 'logSNR':
|
||||
lambda_T = noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
||||
lambda_0 = noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
||||
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
||||
return noise_schedule.inverse_lambda(logSNR_steps)
|
||||
elif skip_type == 'time_uniform':
|
||||
return torch.linspace(t_T, t_0, N + 1).to(device)
|
||||
elif skip_type == 'time_quadratic':
|
||||
t_order = 2
|
||||
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
||||
return t
|
||||
else:
|
||||
raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
|
||||
|
||||
class UniPC:
|
||||
def __init__(
|
||||
self,
|
||||
model_fn,
|
||||
noise_schedule,
|
||||
predict_x0=True,
|
||||
thresholding=False,
|
||||
max_val=1.,
|
||||
variant='bh1',
|
||||
condition=None,
|
||||
unconditional_condition=None,
|
||||
before_sample=None,
|
||||
after_sample=None,
|
||||
after_update=None
|
||||
):
|
||||
"""Construct a UniPC.
|
||||
|
||||
We support both data_prediction and noise_prediction.
|
||||
"""
|
||||
self.model_fn_ = model_fn
|
||||
self.noise_schedule = noise_schedule
|
||||
self.variant = variant
|
||||
self.predict_x0 = predict_x0
|
||||
self.thresholding = thresholding
|
||||
self.max_val = max_val
|
||||
self.condition = condition
|
||||
self.unconditional_condition = unconditional_condition
|
||||
self.before_sample = before_sample
|
||||
self.after_sample = after_sample
|
||||
self.after_update = after_update
|
||||
|
||||
def dynamic_thresholding_fn(self, x0, t=None):
|
||||
"""
|
||||
The dynamic thresholding method.
|
||||
"""
|
||||
dims = x0.dim()
|
||||
p = self.dynamic_thresholding_ratio
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
return x0
|
||||
|
||||
def model(self, x, t):
|
||||
cond = self.condition
|
||||
uncond = self.unconditional_condition
|
||||
if self.before_sample is not None:
|
||||
x, t, cond, uncond = self.before_sample(x, t, cond, uncond)
|
||||
res = self.model_fn_(x, t, cond, uncond)
|
||||
if self.after_sample is not None:
|
||||
x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res)
|
||||
|
||||
if isinstance(res, tuple):
|
||||
# (None, pred_x0)
|
||||
res = res[1]
|
||||
|
||||
return res
|
||||
|
||||
def noise_prediction_fn(self, x, t):
|
||||
"""
|
||||
Return the noise prediction model.
|
||||
"""
|
||||
return self.model(x, t)
|
||||
|
||||
def data_prediction_fn(self, x, t):
|
||||
"""
|
||||
Return the data prediction model (with thresholding).
|
||||
"""
|
||||
noise = self.noise_prediction_fn(x, t)
|
||||
dims = x.dim()
|
||||
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
||||
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
|
||||
if self.thresholding:
|
||||
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
return x0
|
||||
|
||||
def model_fn(self, x, t):
|
||||
"""
|
||||
Convert the model to the noise prediction model or the data prediction model.
|
||||
"""
|
||||
if self.predict_x0:
|
||||
return self.data_prediction_fn(x, t)
|
||||
else:
|
||||
return self.noise_prediction_fn(x, t)
|
||||
|
||||
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
||||
"""
|
||||
Get the order of each step for sampling by the singlestep DPM-Solver.
|
||||
"""
|
||||
if order == 3:
|
||||
K = steps // 3 + 1
|
||||
if steps % 3 == 0:
|
||||
orders = [3,] * (K - 2) + [2, 1]
|
||||
elif steps % 3 == 1:
|
||||
orders = [3,] * (K - 1) + [1]
|
||||
else:
|
||||
orders = [3,] * (K - 1) + [2]
|
||||
elif order == 2:
|
||||
if steps % 2 == 0:
|
||||
K = steps // 2
|
||||
orders = [2,] * K
|
||||
else:
|
||||
K = steps // 2 + 1
|
||||
orders = [2,] * (K - 1) + [1]
|
||||
elif order == 1:
|
||||
K = steps
|
||||
orders = [1,] * steps
|
||||
else:
|
||||
raise ValueError("'order' must be '1' or '2' or '3'.")
|
||||
if skip_type == 'logSNR':
|
||||
# To reproduce the results in DPM-Solver paper
|
||||
timesteps_outer = get_time_steps(self.noise_schedule, skip_type, t_T, t_0, K, device)
|
||||
else:
|
||||
timesteps_outer = get_time_steps(self.noise_schedule, skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
|
||||
return timesteps_outer, orders
|
||||
|
||||
def denoise_to_zero_fn(self, x, s):
|
||||
"""
|
||||
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
||||
"""
|
||||
return self.data_prediction_fn(x, s)
|
||||
|
||||
def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
|
||||
if len(t.shape) == 0:
|
||||
t = t.view(-1)
|
||||
if 'bh' in self.variant:
|
||||
return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
||||
else:
|
||||
assert self.variant == 'vary_coeff'
|
||||
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
||||
|
||||
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
||||
ns = self.noise_schedule
|
||||
assert order <= len(model_prev_list)
|
||||
|
||||
# first compute rks
|
||||
t_prev_0 = t_prev_list[-1]
|
||||
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
|
||||
lambda_t = ns.marginal_lambda(t)
|
||||
model_prev_0 = model_prev_list[-1]
|
||||
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
||||
log_alpha_t = ns.marginal_log_mean_coeff(t)
|
||||
alpha_t = torch.exp(log_alpha_t)
|
||||
|
||||
h = lambda_t - lambda_prev_0
|
||||
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
t_prev_i = t_prev_list[-(i + 1)]
|
||||
model_prev_i = model_prev_list[-(i + 1)]
|
||||
lambda_prev_i = ns.marginal_lambda(t_prev_i)
|
||||
rk = (lambda_prev_i - lambda_prev_0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((model_prev_i - model_prev_0) / rk)
|
||||
|
||||
rks.append(1.)
|
||||
rks = torch.tensor(rks, device=x.device)
|
||||
|
||||
K = len(rks)
|
||||
# build C matrix
|
||||
C = []
|
||||
|
||||
col = torch.ones_like(rks)
|
||||
for k in range(1, K + 1):
|
||||
C.append(col)
|
||||
col = col * rks / (k + 1)
|
||||
C = torch.stack(C, dim=1)
|
||||
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1) # (B, K)
|
||||
C_inv_p = torch.linalg.inv(C[:-1, :-1])
|
||||
A_p = C_inv_p
|
||||
|
||||
if use_corrector:
|
||||
C_inv = torch.linalg.inv(C)
|
||||
A_c = C_inv
|
||||
|
||||
hh = -h if self.predict_x0 else h
|
||||
h_phi_1 = torch.expm1(hh)
|
||||
h_phi_ks = []
|
||||
factorial_k = 1
|
||||
h_phi_k = h_phi_1
|
||||
for k in range(1, K + 2):
|
||||
h_phi_ks.append(h_phi_k)
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_k
|
||||
factorial_k *= (k + 1)
|
||||
|
||||
model_t = None
|
||||
if self.predict_x0:
|
||||
x_t_ = (
|
||||
sigma_t / sigma_prev_0 * x
|
||||
- alpha_t * h_phi_1 * model_prev_0
|
||||
)
|
||||
# now predictor
|
||||
x_t = x_t_
|
||||
if len(D1s) > 0:
|
||||
# compute the residuals for predictor
|
||||
for k in range(K - 1):
|
||||
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
|
||||
# now corrector
|
||||
if use_corrector:
|
||||
model_t = self.model_fn(x_t, t)
|
||||
D1_t = (model_t - model_prev_0)
|
||||
x_t = x_t_
|
||||
k = 0
|
||||
for k in range(K - 1):
|
||||
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
|
||||
x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
|
||||
else:
|
||||
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
||||
x_t_ = (
|
||||
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
|
||||
- (sigma_t * h_phi_1) * model_prev_0
|
||||
)
|
||||
# now predictor
|
||||
x_t = x_t_
|
||||
if len(D1s) > 0:
|
||||
# compute the residuals for predictor
|
||||
for k in range(K - 1):
|
||||
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
|
||||
# now corrector
|
||||
if use_corrector:
|
||||
model_t = self.model_fn(x_t, t)
|
||||
D1_t = (model_t - model_prev_0)
|
||||
x_t = x_t_
|
||||
k = 0
|
||||
for k in range(K - 1):
|
||||
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
|
||||
x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
|
||||
return x_t, model_t
|
||||
|
||||
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
|
||||
ns = self.noise_schedule
|
||||
assert order <= len(model_prev_list)
|
||||
dims = x.dim()
|
||||
|
||||
# first compute rks
|
||||
t_prev_0 = t_prev_list[-1]
|
||||
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
|
||||
lambda_t = ns.marginal_lambda(t)
|
||||
model_prev_0 = model_prev_list[-1]
|
||||
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
||||
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
||||
alpha_t = torch.exp(log_alpha_t)
|
||||
|
||||
h = lambda_t - lambda_prev_0
|
||||
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
t_prev_i = t_prev_list[-(i + 1)]
|
||||
model_prev_i = model_prev_list[-(i + 1)]
|
||||
lambda_prev_i = ns.marginal_lambda(t_prev_i)
|
||||
rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
|
||||
rks.append(rk)
|
||||
D1s.append((model_prev_i - model_prev_0) / rk)
|
||||
|
||||
rks.append(1.)
|
||||
rks = torch.tensor(rks, device=x.device)
|
||||
|
||||
R = []
|
||||
b = []
|
||||
|
||||
hh = -h[0] if self.predict_x0 else h[0]
|
||||
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
||||
h_phi_k = h_phi_1 / hh - 1
|
||||
|
||||
factorial_i = 1
|
||||
|
||||
if self.variant == 'bh1':
|
||||
B_h = hh
|
||||
elif self.variant == 'bh2':
|
||||
B_h = torch.expm1(hh)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
for i in range(1, order + 1):
|
||||
R.append(torch.pow(rks, i - 1))
|
||||
b.append(h_phi_k * factorial_i / B_h)
|
||||
factorial_i *= (i + 1)
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
|
||||
R = torch.stack(R)
|
||||
b = torch.tensor(b, device=x.device)
|
||||
|
||||
# now predictor
|
||||
use_predictor = len(D1s) > 0 and x_t is None
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1) # (B, K)
|
||||
if x_t is None:
|
||||
# for order 2, we use a simplified version
|
||||
if order == 2:
|
||||
rhos_p = torch.tensor([0.5], device=b.device)
|
||||
else:
|
||||
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
|
||||
else:
|
||||
D1s = None
|
||||
|
||||
if use_corrector:
|
||||
# for order 1, we use a simplified version
|
||||
if order == 1:
|
||||
rhos_c = torch.tensor([0.5], device=b.device)
|
||||
else:
|
||||
rhos_c = torch.linalg.solve(R, b)
|
||||
|
||||
model_t = None
|
||||
if self.predict_x0:
|
||||
x_t_ = (
|
||||
expand_dims(sigma_t / sigma_prev_0, dims) * x
|
||||
- expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
|
||||
)
|
||||
|
||||
if x_t is None:
|
||||
if use_predictor:
|
||||
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
|
||||
|
||||
if use_corrector:
|
||||
model_t = self.model_fn(x_t, t)
|
||||
if D1s is not None:
|
||||
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = (model_t - model_prev_0)
|
||||
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
|
||||
else:
|
||||
x_t_ = (
|
||||
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
||||
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
|
||||
)
|
||||
if x_t is None:
|
||||
if use_predictor:
|
||||
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
|
||||
|
||||
if use_corrector:
|
||||
model_t = self.model_fn(x_t, t)
|
||||
if D1s is not None:
|
||||
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = (model_t - model_prev_0)
|
||||
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
|
||||
return x_t, model_t
|
||||
|
||||
|
||||
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
||||
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
||||
atol=0.0078, rtol=0.05, corrector=False, timesteps=None,
|
||||
):
|
||||
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
||||
t_T = self.noise_schedule.T if t_start is None else t_start
|
||||
device = x.device
|
||||
if method == 'multistep':
|
||||
if timesteps is None:
|
||||
timesteps = get_time_steps(self.noise_schedule, skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
||||
assert steps >= order, "UniPC order must be < sampling steps"
|
||||
assert timesteps.shape[0] - 1 == steps
|
||||
with Progress(TextColumn('[cyan]{task.description}'), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), TimeElapsedColumn(), console=shared.console) as progress:
|
||||
task = progress.add_task(description="Initializing", total=steps)
|
||||
t = time.time()
|
||||
with devices.inference_context():
|
||||
vec_t = timesteps[0].expand((x.shape[0]))
|
||||
model_prev_list = [self.model_fn(x, vec_t)]
|
||||
t_prev_list = [vec_t]
|
||||
# Init the first `order` values by lower order multistep DPM-Solver.
|
||||
for init_order in range(1, order):
|
||||
vec_t = timesteps[init_order].expand(x.shape[0])
|
||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
||||
if model_x is None:
|
||||
model_x = self.model_fn(x, vec_t)
|
||||
if self.after_update is not None:
|
||||
self.after_update(x, model_x)
|
||||
model_prev_list.append(model_x)
|
||||
t_prev_list.append(vec_t)
|
||||
progress.update(task, advance=1, description=f"Progress {round(len(vec_t) * init_order / (time.time() - t), 2)}it/s")
|
||||
# for step in trange(order, steps + 1):
|
||||
for step in range(order, steps + 1):
|
||||
vec_t = timesteps[step].expand(x.shape[0])
|
||||
if lower_order_final:
|
||||
step_order = min(order, steps + 1 - step)
|
||||
else:
|
||||
step_order = order
|
||||
if step == steps:
|
||||
use_corrector = False
|
||||
else:
|
||||
use_corrector = True
|
||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
||||
if self.after_update is not None:
|
||||
self.after_update(x, model_x)
|
||||
for i in range(order - 1):
|
||||
t_prev_list[i] = t_prev_list[i + 1]
|
||||
model_prev_list[i] = model_prev_list[i + 1]
|
||||
t_prev_list[-1] = vec_t
|
||||
# We do not need to evaluate the final model value.
|
||||
if step < steps:
|
||||
if model_x is None:
|
||||
model_x = self.model_fn(x, vec_t)
|
||||
model_prev_list[-1] = model_x
|
||||
progress.update(task, advance=1, description=f"Progress {round(len(vec_t) * step / (time.time() - t), 2)}it/s")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if denoise_to_zero:
|
||||
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
||||
return x
|
||||
|
||||
|
||||
#############################################################
|
||||
# other utility functions
|
||||
#############################################################
|
||||
|
||||
def interpolate_fn(x, xp, yp):
|
||||
"""
|
||||
A piecewise linear function y = f(x), using xp and yp as keypoints.
|
||||
We implement f(x) in a differentiable way (i.e. applicable for autograd).
|
||||
The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
|
||||
|
||||
Args:
|
||||
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
|
||||
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
|
||||
yp: PyTorch tensor with shape [C, K].
|
||||
Returns:
|
||||
The function values f(x), with shape [N, C].
|
||||
"""
|
||||
N, K = x.shape[0], xp.shape[1]
|
||||
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
|
||||
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
|
||||
x_idx = torch.argmin(x_indices, dim=2)
|
||||
cand_start_idx = x_idx - 1
|
||||
start_idx = torch.where(
|
||||
torch.eq(x_idx, 0),
|
||||
torch.tensor(1, device=x.device),
|
||||
torch.where(
|
||||
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
||||
),
|
||||
)
|
||||
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
|
||||
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
|
||||
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
|
||||
start_idx2 = torch.where(
|
||||
torch.eq(x_idx, 0),
|
||||
torch.tensor(0, device=x.device),
|
||||
torch.where(
|
||||
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
||||
),
|
||||
)
|
||||
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
|
||||
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
|
||||
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
|
||||
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
||||
return cand
|
||||
|
||||
|
||||
def expand_dims(v, dims):
|
||||
"""
|
||||
Expand the tensor `v` to the dim `dims`.
|
||||
|
||||
Args:
|
||||
`v`: a PyTorch tensor with shape [N].
|
||||
`dim`: a `int`.
|
||||
Returns:
|
||||
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
|
||||
"""
|
||||
return v[(...,) + (None,)*(dims - 1)]
|
||||
|
|
@ -2,40 +2,10 @@ import math
|
|||
import gradio as gr
|
||||
from modules import images, scripts_manager
|
||||
from modules.processing import process_images
|
||||
from modules.shared import opts, state, log
|
||||
from modules.shared import opts, log
|
||||
import modules.sd_samplers
|
||||
|
||||
|
||||
def draw_xy_grid(xs, ys, x_label, y_label, cell):
|
||||
res = []
|
||||
|
||||
ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
|
||||
hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
|
||||
|
||||
first_processed = None
|
||||
|
||||
state.job_count = len(xs) * len(ys)
|
||||
|
||||
for iy, y in enumerate(ys):
|
||||
for ix, x in enumerate(xs):
|
||||
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
|
||||
|
||||
processed, _t = cell(x, y)
|
||||
if first_processed is None:
|
||||
first_processed = processed
|
||||
|
||||
res.append(processed.images[0])
|
||||
|
||||
if images.check_grid_size(res):
|
||||
grid = images.image_grid(res, rows=len(ys))
|
||||
grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
|
||||
first_processed.images = [grid]
|
||||
else:
|
||||
first_processed.images = res
|
||||
|
||||
return first_processed
|
||||
|
||||
|
||||
class Script(scripts_manager.Script):
|
||||
def title(self):
|
||||
return "Prompt matrix"
|
||||
|
|
|
|||
2
webui.py
2
webui.py
|
|
@ -31,7 +31,7 @@ import modules.upscaler
|
|||
import modules.upscaler_simple
|
||||
import modules.extra_networks
|
||||
import modules.ui_extra_networks
|
||||
import modules.textual_inversion.textual_inversion
|
||||
import modules.textual_inversion
|
||||
import modules.script_callbacks
|
||||
import modules.api.middleware
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue