stable-diffusion-webui-dump.../scripts/dumpunet/layer_prompt/prompt.py

211 lines
8.0 KiB
Python

# Thanks for kohya.S
# https://note.com/kohya_ss/n/n93b7c01b0547
import sys
from collections import defaultdict
from typing import Callable, Any
import torch
from torch import Tensor
from ldm.modules.diffusionmodules.util import timestep_embedding # type: ignore
from modules.processing import StableDiffusionProcessing
from modules import devices, prompt_parser
from modules.prompt_parser import get_learned_conditioning as glc
from modules.prompt_parser import ScheduledPromptConditioning
glc : Callable[[Any,list[str],int],list[list[ScheduledPromptConditioning]]]
from scripts.dumpunet import layerinfo
from scripts.dumpunet.layer_prompt.generator import LayerPromptGenerator, LayerPromptEraseGenerator, LayerPrompts
from scripts.dumpunet.layer_prompt.parser import BadPromptError
class LayerPrompt:
# original forward function
o_fw: Callable|None
# hooker
fw: Callable|None
model: Any
remove_layer_prompts: bool
last_batch_num: int
def __init__(self, runner, enabled: bool, prompts_to_stdout: bool):
self._runner = runner
self._enabled = enabled
self._show = prompts_to_stdout
self.o_fw = self.fw = self.model = None
self.remove_layer_prompts = False
self.last_batch_num = -1
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
if self.model is not None and self.o_fw is not None:
self.model.diffusion_model.forward = self.o_fw
self.o_fw = self.fw = self.model = None
self.remove_layer_prompts = False
self.last_batch_num = -1
def setup(self, p: StableDiffusionProcessing, remove_layer_prompts=False):
if not self._enabled:
return
self.remove_layer_prompts = remove_layer_prompts
old = p.sd_model.model.diffusion_model.forward # type: ignore
new = self._create_forward_fn(
p.sd_model.model.diffusion_model, # type: ignore
p
)
self.o_fw = old
self.fw = new
# replace U-Net forward
self.model = p.sd_model.model # type: ignore
self.model.diffusion_model.forward = self.fw # type: ignore
def _create_blocks_cond(
self,
p: StableDiffusionProcessing,
all_prompts: list[str],
all_negative_prompts: list[str]
):
try:
if self.remove_layer_prompts:
gen = LayerPromptEraseGenerator()
else:
gen = LayerPromptGenerator()
uc_list = [gen.generate(uc) for uc in all_negative_prompts]
c_list = [gen.generate(c) for c in all_prompts]
except BadPromptError as pe:
self._runner.notify_error(pe)
print("\033[31m", file=sys.stderr, end="", flush=False)
print(pe.message(), file=sys.stderr, flush=False)
print("\033[0m", file=sys.stderr, end="")
raise ValueError(f"Prompt syntax error at pos={pe.pos}. See stderr for details.") from None
# get conditional
with devices.autocast():
# ignore scheduled cond
uc1 = [ glc(p.sd_model, list(image_uc.values()), p.steps) for image_uc in uc_list ]
c1 = [ glc(p.sd_model, list(image_c.values()), p.steps) for image_c in c_list ]
steps = self._runner.steps_on_batch
def get_cond(conds: list[ScheduledPromptConditioning]):
assert len(conds) != 0
for cc in conds:
if steps <= cc.end_at_step:
return cc
return conds[-1]
# c,uc : img_idx -> layer_idx -> Tensor
# layer_idx -> [0(c), 1(uc)] -> img_idx -> Tensor
layer_conds : defaultdict[int,tuple[list[Tensor],list[Tensor]]]
layer_conds = defaultdict(lambda: ([],[]))
for img_idx, (uc2, c2) in enumerate(zip(uc1, c1)):
assert len(uc2) == len(layerinfo.Names)
assert len(c2) == len(layerinfo.Names)
for layer_idx, (uc, c) in enumerate(zip(uc2, c2)):
layer_conds[layer_idx][0].append(get_cond(c).cond)
layer_conds[layer_idx][1].append(get_cond(uc).cond)
# layer_idx -> [0(c), 1(uc)] -> Tensor
layer_conds2 : list[tuple[Tensor,Tensor]]
layer_conds2 = []
for layer_idx in range(len(layerinfo.Names)):
c, uc = layer_conds[layer_idx]
layer_conds2.append(
(
torch.cat([ cx.unsqueeze(0) for cx in c ]),
torch.cat([ ucx.unsqueeze(0) for ucx in uc ])
)
)
# layer_idx -> Tensor
layer_conds3 : list[Tensor]
layer_conds3 = []
for layer_idx in range(len(layerinfo.Names)):
c, uc = layer_conds2[layer_idx]
layer_conds3.append(torch.cat([c, uc]))
return layer_conds3, c_list, uc_list
def _create_forward_fn(self, _self, p):
def new_forward(x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via crossattn
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
# print("replaced", x.size(), timesteps, context.size() if context is not None else context)
assert (y is not None) == (
_self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, _self.model_channels, repeat_only=False)
emb = _self.time_embed(t_emb)
if _self.num_classes is not None:
assert y is not None
assert y.shape[0] == x.shape[0]
emb = emb + _self.label_emb(y)
n = self._runner.batch_num - 1
assert 0 <= n
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
assert len(prompts) != 0
layer_conds, c_list, uc_list = \
self._create_blocks_cond(p, prompts, negative_prompts)
last_n, self.last_batch_num = self.last_batch_num, n
if self._show and last_n != n:
self.dump_prompts(n, c_list, uc_list)
cond_index = 0
h = x.type(_self.dtype)
for module in _self.input_blocks:
h = module(h, emb, layer_conds[cond_index] if context is not None else None) # context)
cond_index += 1
hs.append(h)
h = _self.middle_block(h, emb, layer_conds[cond_index] if context is not None else None) # context)
cond_index += 1
for module in _self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, layer_conds[cond_index] if context is not None else None) # context)
cond_index += 1
h = h.type(x.dtype)
if _self.predict_codebook_ids:
return _self.id_predictor(h)
else:
return _self.out(h)
return new_forward
def dump_prompts(self, n: int, c_list, uc_list, file=sys.stdout):
def pp(s, file=file): print(s, file=file)
pp("=" * 80)
pp(f"Prompts (batch={n}, step={self._runner.steps_on_batch}")
pp("-" * 80)
for layername in layerinfo.Names:
pp(f"{layername:<5} : {[getattr(c, layername) for c in c_list]}")
pp("=" * 80)
pp("Negative Prompts")
pp("-" * 80)
for layername in layerinfo.Names:
pp(f"{layername:<5} : {[getattr(uc, layername) for uc in uc_list]}")
pp("=" * 80)