add replace mode

pull/8/head v1.2
Kahsolt 2022-11-14 23:46:56 +08:00
parent 6be9c36936
commit 081697ff19
3 changed files with 383 additions and 99 deletions

README.md

@ -8,7 +8,7 @@ This is the more human-sensible version of [stable-diffusion-webui-prompt-erosio
now we no longer modify prompts at the text character level, but interpolate linearly on the hidden embedding vectors. 😀
⚠ We have set up a QQ feedback group for this extension: 616795645 (赤狐屿); suggestions, opinions, bug reports, etc. are all welcome (w
⚠ We have a QQ chat group now: 616795645, any suggestion, discussion and bug reports are highly welcome!!
⚠ We have a QQ chat group now: 616795645, any suggestions, discussions and bug reports are highly welcome!!
Not gonna lie, I think it might be possible to use this to make PPTs or fairy-tale picture books <del>or even doujinshi</del>…
Smart usage: first manually blind-search for two good-looking images (differing only in prompt), then try to travel between them :lollipop:
@ -16,7 +16,7 @@ now we do not modify on text char level, but do linear interpolating on the hidd
### Change Log
- 2022/11/14: walk by substituting word embedding ('replace' mode)
- 2022/11/14: walk by substituting token embedding ('replace' mode)
- 2022/11/13: walk by optimizing condition ('grad' mode)
- 2022/11/10: interpolate linearly on condition/uncondition ('linear' mode)
@ -47,17 +47,19 @@ now we do not modify on text char level, but do linear interpolating on the hidd
- mode: (categorical)
- linear: interpolate linearly on condition/uncondition in latent space (see the sketch after this option list)
- replace: walk by gradually substituting token embeddings
- grad: walk by optimizing a certain loss (see [Experimental](#experimental))
- NOTE: `walk` methods might not reach the target stages within the specified steps; in that case, manually tune `grad_alpha` or increase `steps` according to the logged losses...
- grad: walk by optimizing a certain loss
- NOTE: `walk` methods may sometimes fail to reach the target stage within the specified steps, or reach it earlier than expected; in that case, manually tuning `grad_alpha` and `steps` might help a little...
- steps: (int, list of int)
- number of images to interpolate between two successive stages<del>, set `-1` to allow wandering until convergence for `walk` methods (not yet implemented)</del>
- number of images to interpolate between two successive stages
- if int, constant number of travel steps
- if list of int, length should match `len(stages)-1`, separated by commas, e.g.: `12, 24, 36`
- replace_*
- replace_order: (categorical)
- `random`: substitute tokens randomly
- `most_similar`: substitute the most similar tokens first (L1 distance of token embeddings)
- `most_different`: substitute the most different tokens first (L1 distance of token embeddings)
- `grad_min`: substitute the tokens causing the smallest gradient first (gradient settings are the same as in `grad` mode)
- `grad_max`: substitute the tokens causing the largest gradient first
- grad_*
- grad_alpha: (float), step size of a walk pace
- grad_iter: (int), step count of walk paces
@ -68,9 +70,9 @@ now we do not modify on text char level, but do linear interpolating on the hidd
- `sign`: walk at a constant speed (often gets stuck oscillating near the end)
- `tanh`: slows down significantly when approaching the target (it would take infinite time to reach it exactly...)
- grad_w_latent: (float), weight factor of `loss_latent`
- grad_w_match: (float), weight factor of `loss_cond`
- grad_w_cond: (float), weight factor of `loss_cond`
- fps: (float)
- FPS of video, set 0 to disable saving
- FPS of video, set `0` to disable saving
- debug: (bool)
- whether to show verbose debug info in the console
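For intuition, here is a minimal sketch of what `linear` mode computes for each in-between image (a simplified view of the script's `weighted_sum`; `cond_a`/`cond_b` are illustrative names for the conditioning tensors of two successive stages):

```python
import torch

def lerp_cond(cond_a: torch.Tensor, cond_b: torch.Tensor, alpha: float) -> torch.Tensor:
    # alpha runs from 0.0 (stage A) to 1.0 (stage B) across the travel steps
    return (1 - alpha) * cond_a + alpha * cond_b
```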
@ -117,6 +119,13 @@ Grid search results: (`steps=100, grad_alpha=0.01, grad_iter=1, grad_meth='clip'
NOTE: When 'prompt' has only a single line, it will wander just **around** the init stage, dynamically balancing `loss_latent` and `loss_cond`; this lets you discover neighbors of your given prompt 😀
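For intuition, a minimal sketch of one walk pace in `grad` mode (the update rule mirrors the script; the exact `clip`/`sign`/`tanh` transforms are an assumption based on the method names and descriptions above):

```python
import torch

def walk_pace(cond: torch.Tensor, grad: torch.Tensor, alpha: float, meth: str) -> torch.Tensor:
    methods = {
        'clip': lambda g: g.clamp(-1, 1),  # bounded step (assumed clamp range)
        'sign': lambda g: g.sign(),        # constant speed, may oscillate near the target
        'tanh': lambda g: g.tanh(),        # slows down as the gradient shrinks
    }
    # grad is the weighted sum: grad_latent * grad_w_latent + grad_cond * grad_w_cond
    return cond.detach() - methods[meth](grad) * alpha
```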
⚪ 'replace' mode
This mode works at the token-embedding input level, so you can view `log.txt` to see how your input tokens are gradually changed.
⚠ Remember that the comma is a normal, valid token, so you may see many commas there. However, they are distinct when they appear at different positions within the token sequence.
The actual token replacement order might reveal some information about token importance; perhaps the listed '>> grad ascend' or '>> embed L1-distance ascend' lines will give you some ideas for tuning your input prompt (I hope so..)
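As a rough sketch, the non-gradient replace orders rank token positions like this (mirroring the L1-distance logic in `run_replace`; names here are illustrative):

```python
import torch
import torch.nn.functional as F

def l1_replace_order(src_embed: torch.Tensor, tgt_embed: torch.Tensor, most_similar: bool) -> list:
    # src_embed/tgt_embed: [B=1, T=75, D=768] token embeddings of the current/target prompt
    dist = F.l1_loss(src_embed, tgt_embed, reduction='none').squeeze(0).mean(dim=-1)  # [T=75]
    order = dist.argsort().tolist()  # ascending: most similar positions first
    return order if most_similar else order[::-1]
```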
----
by Armit

install.py

@ -1,4 +1,4 @@
import launch
if not launch.is_installed("moviepy"):
launch.run_pip("install moviepy==1.0.3", "requirements for Seed Travel")
launch.run_pip("install moviepy==1.0.3", "requirements for Prompt Travel to generate video")

scripts/prompt_travel.py

@ -1,38 +1,44 @@
import os
import random
from pathlib import Path
from copy import deepcopy
from PIL import Image
from typing import List, Tuple, Union
from traceback import print_exc
import gradio as gr
import torch.nn.functional as F
import numpy as np
try:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
except ImportError:
print(f"moviepy python module not installed. Will not be able to generate video.")
try: from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
except ImportError: print(f"package moviepy not installed, will not be able to generate video")
import modules.scripts as scripts
from modules.processing import Processed, StableDiffusionProcessing
from modules.processing import *
from modules.prompt_parser import ScheduledPromptConditioning, MulticondLearnedConditioning
from modules.shared import state
DEFAULT_MODE = 'linear'
DEFAULT_STEPS = 30
DEFAULT_REPLACE_ORDER = 'random'
DEFAULT_REPLACE_ORDER = 'grad_min'
DEFAULT_GRAD_ALPHA = 0.01
DEFAULT_GRAD_ITERS = 1
DEFAULT_GRAD_ITER = 1
DEFAULT_GRAD_METH = 'clip'
DEFAULT_GRAD_W_LATENT = 1
DEFAULT_GRAD_W_COND = 1
DEFAULT_FPS = 10
DEFAULT_DEBUG = True
CHOICES_MODE = ['linear', 'replace', 'grad']
CHOICES_REPLACE_ORDER = ['random', 'similar', 'different']
CHOICES_REPLACE_ORDER = ['random', 'most_similar', 'most_different', 'grad_min', 'grad_max']
CHOICES_GRAD_METH = ['clip', 'sign', 'tanh']
T_tokens = List[List[int]]
T_weights = List[List[float]]
# ↓↓↓ the following is modified from 'modules/processing.py' ↓↓↓
from modules.processing import apply_overlay, apply_color_correction, create_infotext, decode_first_stage, get_fixed_seed
import torch
import numpy as np
from PIL import Image
@ -46,7 +52,7 @@ import modules.face_restoration
import modules.images as images
import modules.styles
def process_images_inner_half_A(p: StableDiffusionProcessing) -> tuple:
def process_images_prompt_to_cond(p: StableDiffusionProcessing, ret_token_and_weight=False) -> tuple:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert p.prompt is not None
@ -91,12 +97,16 @@ def process_images_inner_half_A(p: StableDiffusionProcessing) -> tuple:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
# 'prompt string' => tensor([T, D])
uc, uc_tokens, uc_weights = get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c, c_tokens, c_weights = get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if ret_token_and_weight:
return c, uc, prompts, seeds, subseeds, (c_tokens, c_weights, uc_tokens, uc_weights)
else:
return c, uc, prompts, seeds, subseeds
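# usage sketch: the split pipeline lets a caller edit the conditioning between the two halves,
# as run_grad/run_replace do below:
#   c, uc, prompts, seeds, subseeds = process_images_prompt_to_cond(p)
#   c = mlc_replace_cond(c, new_cond)  # new_cond: an interpolated/walked cond tensor
#   proc = process_images_cond_to_image(p, c, uc, prompts, seeds, subseeds)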
def process_images_inner_half_B(p: StableDiffusionProcessing, c, uc, prompts, seeds, subseeds):
def process_images_cond_to_image(p: StableDiffusionProcessing, c, uc, prompts, seeds, subseeds) -> Processed:
comments = {}
infotexts = []
output_images = []
@ -179,13 +189,57 @@ def process_images_inner_half_B(p: StableDiffusionProcessing, c, uc, prompts, se
# ↑↑↑ the above is modified from 'modules/processing.py' ↑↑↑
# ↓↓↓ the following is modified from 'modules/prompt_parser.py' ↓↓↓
from modules.prompt_parser import get_learned_conditioning_prompt_schedules, get_multicond_prompt_list
from modules.prompt_parser import ComposableScheduledPromptConditioning
def get_learned_conditioning(model, prompts, steps) -> Tuple[ScheduledPromptConditioning, T_tokens, T_weights]:
res = []
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
cache = {}
for prompt, prompt_schedule in zip(prompts, prompt_schedules): # forced to be length 1
cached = cache.get(prompt, None)
if cached is not None:
res.append(cached)
continue
texts = [x[1] for x in prompt_schedule]
conds, tokens, weights = LatentDiffusion_get_learned_conditioning(model, texts)
cond_schedule = []
for i, (end_at_step, text) in enumerate(prompt_schedule):
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
cache[prompt] = cond_schedule
res.append(cond_schedule)
return res, tokens, weights
def get_multicond_learned_conditioning(model, prompts, steps) -> Tuple[MulticondLearnedConditioning, T_tokens, T_weights]:
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
learned_conditioning, tokens, weights = get_learned_conditioning(model, prompt_flat_list, steps)
res = []
for indexes in res_indexes:
res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res), tokens, weights
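# note: both functions additionally return the raw tokens and multiplier weights,
# which 'replace' mode needs for token-level substitution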
# ↑↑↑ the above is modified from 'modules/prompt_parser.py' ↑↑↑
# ↓↓↓ the following is modified from 'ldm.models.diffusion/ddpm.py' ↓↓↓
def get_latent_loss(sd_model, latent:torch.Tensor, cond:torch.Tensor) -> torch.Tensor:
#from ldm.models.diffusion import LatentDiffusion
# type(sd_model) == LatentDiffusion
from modules.sd_hijack import FrozenCLIPEmbedderWithCustomWords
#from ldm.models.diffusion import LatentDiffusion
#from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
# forward(self, x, c, *args, **kwargs)
def get_latent_loss(sd_model, latent:torch.Tensor, cond:torch.Tensor) -> torch.Tensor:
# forward(self:LatentDiffusion, x, c, *args, **kwargs)
self, x, c = sd_model, latent, cond
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() # [B=1]
@ -197,7 +251,7 @@ def get_latent_loss(sd_model, latent:torch.Tensor, cond:torch.Tensor) -> torch.T
tc = self.cond_ids[t].to(self.device)
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
# p_losses(self, x_start, cond, t, noise=None)
# p_losses(self:LatentDiffusion, x_start, cond, t, noise=None)
x_start, cond, t = x, c, t
noise = torch.randn_like(x_start) # [B=1, C=4, H=64, W=64]
@ -219,9 +273,125 @@ def get_latent_loss(sd_model, latent:torch.Tensor, cond:torch.Tensor) -> torch.T
return loss # [B=1, C=4, H=64, W=64]
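# note: this follows roughly the standard DDPM training objective: noise the latent at a
# random timestep t, predict that noise back from (x_noisy, t, cond), and return the
# elementwise loss, which stays differentiable w.r.t. `cond` for the walks below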
def text_to_token(self:FrozenCLIPEmbedderWithCustomWords, text:List[str]) -> tuple:
# FrozenCLIPEmbedderWithCustomWords.FrozenCLIPEmbedder.CLIPTokenizer
with devices.autocast('cuda'):
if opts.use_old_emphasis_implementation:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def token_to_cond(self:FrozenCLIPEmbedderWithCustomWords, batch_multipliers:T_weights, remade_batch_tokens:T_tokens, used_custom_terms, hijack_comments, hijack_fixes) -> torch.Tensor:
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if opts.use_old_emphasis_implementation:
self.hijack.fixes = hijack_fixes
with torch.no_grad(), devices.autocast():
return self.process_tokens(remade_batch_tokens, batch_multipliers)
# allow length > 75
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
rem_tokens = [x[75:] for x in remade_batch_tokens]
rem_multipliers = [x[75:] for x in batch_multipliers]
self.hijack.fixes = []
for unfiltered in hijack_fixes:
fixes = []
for fix in unfiltered:
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
if len(remade_batch_tokens[j]) > 0:
tokens.append(remade_batch_tokens[j][:75])
multipliers.append(batch_multipliers[j][:75])
else:
tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
multipliers.append([1.0] * 75)
with torch.no_grad(), devices.autocast():
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
return z
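# note: prompts longer than 75 tokens are processed in 75-token chunks and the chunk
# embeddings are concatenated along the token axis (hence the T=77*n comment below)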
def token_to_text(self:FrozenCLIPEmbedderWithCustomWords, tokens:T_tokens) -> str:
id_2_word = { v: k for k, v in self.wrapped.tokenizer.get_vocab().items() }
words = []
for tk in tokens[0]: # force B=1
w = id_2_word.get(tk, '<unk>')
if w in ['<|startoftext|>', '<|endoftext|>']: continue
if w.endswith('</w>'): w = w[:-4]
words.append(w)
return ' '.join(words)
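# illustrative example: the token ids of 'a girl, smiling' decode back to 'a girl , smiling';
# note the comma is its own token, as warned in the README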
def FrozenCLIPEmbedderWithCustomWords_forward(self:FrozenCLIPEmbedderWithCustomWords, text:List[str]) -> Tuple[torch.Tensor, T_tokens, T_weights]:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = text_to_token(self, text)
cond = token_to_cond(self, batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes)
return cond, remade_batch_tokens, batch_multipliers
def LatentDiffusion_get_learned_conditioning(sd_model, cond:List[str]) -> Tuple[torch.Tensor, T_tokens, T_weights]:
self, c = sd_model, cond
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c)
if 'DiagonalGaussianDistribution' in str(type(c)):
c = c.mode()
else:
#c = self.cond_stage_model(c) # => goes this way, [B=1, T=77*n, D=768], [[]], [[]]
c, tokens, weights = FrozenCLIPEmbedderWithCustomWords_forward(self.cond_stage_model, c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c, tokens, weights
# ↑↑↑ the above is modified from 'ldm.models.diffusion/ddpm.py' ↑↑↑
def image_to_latent(model, img: Image) -> torch.Tensor:
#from ldm.models.diffusion import LatentDiffusion
# type(model) == LatentDiffusion
im = np.array(img).astype(np.uint8)
im = (im / 127.5 - 1.0).astype(np.float32)
x = torch.from_numpy(im)
x = torch.moveaxis(x, 2, 0)
x = x.unsqueeze(dim=0) # [B=1, C=3, H=512, W=512]
x = x.to(model.device)
latent = model.get_first_stage_encoding(model.encode_first_stage(x)) # [B=1, C=4, H=64, W=64]
return latent
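# used by the 'grad' and 'replace' (grad_*) walks to build loss_latent against the stage target image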
def mlc_get_cond(c:MulticondLearnedConditioning) -> torch.Tensor:
return c.batch[0][0].schedules[0].cond # [B=1, T=77, D=768]
def mlc_replace_cond(c:MulticondLearnedConditioning, cond: torch.Tensor) -> MulticondLearnedConditioning:
r = deepcopy(c)
spc = r.batch[0][0].schedules[0]
r.batch[0][0].schedules[0] = ScheduledPromptConditioning(spc.end_at_step, cond)
return r
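# these two helpers read/swap only the first positive cond tensor (batch[0][0],
# schedules[0]), which is safe here since Batch Count/Size are forced to 1 below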
class Script(scripts.Script):
def title(self):
@ -239,10 +409,10 @@ class Script(scripts.Script):
replace_order = gr.Dropdown(label='Replace order (replace mode)', value=lambda: DEFAULT_REPLACE_ORDER, choices=CHOICES_REPLACE_ORDER)
grad_alpha = gr.Number (label='Step size (grad mode)', value=lambda: DEFAULT_GRAD_ALPHA)
grad_iter = gr.Number (label='Step count (grad mode)', value=lambda: DEFAULT_GRAD_ITERS, precision=0)
grad_iter = gr.Number (label='Step count (grad mode)', value=lambda: DEFAULT_GRAD_ITER, precision=0)
grad_meth = gr.Dropdown(label='Step method (grad mode)', value=lambda: DEFAULT_GRAD_METH, choices=CHOICES_GRAD_METH)
grad_w_latent = gr.Number (label='Loss for latent match (grad mode)', value=lambda: DEFAULT_GRAD_W_LATENT)
grad_w_cond = gr.Number (label='Loss for cond match (grad mode)', value=lambda: DEFAULT_GRAD_W_COND)
grad_w_latent = gr.Number (label='Weight for latent match (grad/replace-grad mode)', value=lambda: DEFAULT_GRAD_W_LATENT)
grad_w_cond = gr.Number (label='Weight for cond match (grad/replace-grad mode)', value=lambda: DEFAULT_GRAD_W_COND)
video_fps = gr.Number (label='Video FPS', value=lambda: DEFAULT_FPS)
show_debug = gr.Checkbox(label='Show verbose debug info at console', value=lambda: DEFAULT_DEBUG)
@ -253,11 +423,7 @@ class Script(scripts.Script):
video_fps, show_debug]
def get_next_sequence_number(path):
from pathlib import Path
"""
Determines and returns the next sequence number to use when saving an image in the specified directory.
The sequence starts at 0.
"""
""" Determines and returns the next sequence number to use when saving an image in the specified directory. The sequence starts at 0. """
result = -1
dir = Path(path)
for file in dir.iterdir():
@ -305,12 +471,12 @@ class Script(scripts.Script):
travel_path = os.path.join(p.outpath_samples, 'prompt_travel')
os.makedirs(travel_path, exist_ok=True)
travel_number = Script.get_next_sequence_number(travel_path)
travel_path = os.path.join(travel_path, f"{travel_number:05}")
p.outpath_samples = travel_path
os.makedirs(travel_path, exist_ok=True)
self.log_fp = os.path.join(travel_path, 'log.txt')
self.log_dp = os.path.join(travel_path, f"{travel_number:05}")
p.outpath_samples = self.log_dp
os.makedirs(self.log_dp, exist_ok=True)
self.log_fp = os.path.join(self.log_dp, 'log.txt')
# Force Batch Count and Batch Size to 1.
# Force Batch Count and Batch Size to 1
p.n_iter = 1
p.batch_size = 1
@ -326,15 +492,16 @@ class Script(scripts.Script):
# Implementation dispatcher
if mode == 'linear' : images, info = self.run_linear (p, pos_prompts, neg_prompts, steps, show_debug)
elif mode == 'replace': images, info = self.run_replace(p, pos_prompts, neg_prompts, steps, replace_order, show_debug)
elif mode == 'replace': images, info = self.run_replace(p, pos_prompts, neg_prompts, steps, replace_order, grad_w_latent, grad_w_cond, show_debug)
elif mode == 'grad' : images, info = self.run_grad (p, pos_prompts, neg_prompts, steps, grad_alpha, grad_iter, grad_meth, grad_w_latent, grad_w_cond, show_debug)
# Save video
if video_fps > 0 and len(images) > 1:
try:
clip = ImageSequenceClip([np.asarray(t) for t in images], fps=video_fps)
clip.write_videofile(os.path.join(travel_path, f"travel-{travel_number:05}.mp4"), verbose=False, audio=False)
except: pass
clip.write_videofile(os.path.join(self.log_dp, f"travel-{travel_number:05}.mp4"), verbose=False, audio=False)
except NameError: pass
except: print_exc()
return Processed(p, images, p.seed, info)
@ -343,26 +510,26 @@ class Script(scripts.Script):
initial_info = None
images = []
def weighted_sum(A, B, alpha, kind):
''' linear interpolate on latent space '''
def weighted_sum(A, B, alpha:float, kind:str) -> Union[ScheduledPromptConditioning, MulticondLearnedConditioning]:
''' linear interpolate on latent space of condition '''
C = deepcopy(A)
if kind == 'pos':
condA = A.batch[0][0].schedules[0].cond
condB = B.batch[0][0].schedules[0].cond
condC = (1 - alpha) * condA + alpha * condB
condC = (1 - alpha) * condA + (alpha) * condB
end_at_step = C.batch[0][0].schedules[0].end_at_step
C.batch[0][0].schedules[0] = ScheduledPromptConditioning(end_at_step, condC)
if kind == 'neg':
condA = A[0][0].cond
condB = B[0][0].cond
condC = (1 - alpha) * condA + alpha * condB
condC = (1 - alpha) * condA + (alpha) * condB
end_at_step = C[0][0].end_at_step
C[0][0] = ScheduledPromptConditioning(end_at_step, condC)
return C
def draw_by_cond(pos_hidden, neg_hidden, prompts, seeds, subseeds):
nonlocal images, initial_info, p
proc = process_images_inner_half_B(p, pos_hidden, neg_hidden, prompts, seeds, subseeds)
proc = process_images_cond_to_image(p, pos_hidden, neg_hidden, prompts, seeds, subseeds)
if initial_info is None: initial_info = proc.info
images += proc.images
@ -373,7 +540,7 @@ class Script(scripts.Script):
print(f' neg prompts: {neg_prompts[0]}')
p.prompt = pos_prompts[0]
p.negative_prompt = neg_prompts[0]
from_pos_hidden, from_neg_hidden, prompts, seeds, subseeds = process_images_inner_half_A(p)
from_pos_hidden, from_neg_hidden, prompts, seeds, subseeds = process_images_prompt_to_cond(p)
draw_by_cond(from_pos_hidden, from_neg_hidden, prompts, seeds, subseeds)
# travel through stages
@ -387,7 +554,7 @@ class Script(scripts.Script):
print(f' neg prompts: {neg_prompts[i]}')
p.prompt = pos_prompts[i]
p.negative_prompt = neg_prompts[i]
to_pos_hidden, to_neg_hidden, prompts, seeds, subseeds = process_images_inner_half_A(p)
to_pos_hidden, to_neg_hidden, prompts, seeds, subseeds = process_images_prompt_to_cond(p)
# Step 2: draw the interpolated images
n_inter = steps[i] + 1
@ -407,12 +574,144 @@ class Script(scripts.Script):
return images, initial_info
def run_replace(self, p:StableDiffusionProcessing, pos_prompts:List[str], neg_prompts:List[str], steps:List[int], replace_order:str, show_debug:bool):
def run_replace(self, p:StableDiffusionProcessing, pos_prompts:List[str], neg_prompts:List[str], steps:List[int], replace_order:str, grad_w_latent:float, grad_w_cond:float, show_debug:bool):
clip_model = p.sd_model.cond_stage_model
n_stages = len(steps)
initial_info = None
images = []
initial_info = 'hold your horses, this is not implemented yet...'
# Step 1: draw init image
if show_debug: print(f'[stage 1/{n_stages}] prompts: {pos_prompts[0]}')
p.prompt = pos_prompts[0]
c, uc, prompts, seeds, subseeds, (c_tokens, c_weights, uc_tokens, uc_weights) = process_images_prompt_to_cond(p, ret_token_and_weight=True)
proc = process_images_cond_to_image(p, c, uc, prompts, seeds, subseeds)
if initial_info is None: initial_info = proc.info
images += proc.images
# make log
log_fh = open(self.log_fp, 'w', encoding='utf-8')
log_fh.write(f'replace_order = {replace_order}\n')
log_fh.write('\n')
text_rev = token_to_text(clip_model, c_tokens)
log_fh.write(f'tokens: {text_rev}\n')
log_fh.write('\n')
# travel between stages
for i in range(1, n_stages):
if state.interrupted: break
# Step 2: draw the stage target
if show_debug: print(f'[stage {i+1}/{n_stages}] prompts: {pos_prompts[i]}')
p.prompt = pos_prompts[i]
*params, (tgt_c_tokens, tgt_c_weights, tgt_uc_tokens, tgt_uc_weights) = process_images_prompt_to_cond(p, ret_token_and_weight=True)
_, _, used_custom_terms, hijack_comments, hijack_fixes, _ = text_to_token(clip_model, [p.prompt])
proc = process_images_cond_to_image(p, *params)
if initial_info is None: initial_info = proc.info
target_image = proc.images[0] # cache it here to make video sequence order right
with torch.no_grad(), devices.autocast():
if replace_order.startswith('grad'):
target_latent = image_to_latent(p.sd_model, target_image) # [B=1, C=4, H=64, W=64]
target_cond = mlc_get_cond(params[0]).unsqueeze(0) # [B=1, T=77, D=768]
else:
embed_layer = clip_model.wrapped.transformer.get_input_embeddings() # transformers.models.clip.CLIPTextModel
source_embed = embed_layer(torch.LongTensor(c_tokens) .to(p.sd_model.device)) # [B=1, T=75, D=768]
target_embed = embed_layer(torch.LongTensor(tgt_c_tokens).to(p.sd_model.device))
L1_dist = F.l1_loss(source_embed, target_embed, reduction='none').squeeze(dim=0).mean(dim=-1).cpu().numpy() # [T=75]
# Step 3: draw the inter-mediums
for _ in range(steps[i]):
if state.interrupted: break
mask = np.asarray(c_tokens[0]) != np.asarray(tgt_c_tokens[0]) # [T=75]
n_replaces = sum(mask)
if n_replaces == 0: break
cnt = max(n_replaces // steps[i], 1)
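# illustrative numbers: 25 differing tokens over steps[i]=10 gives cnt=2 tokens per travel step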
if show_debug: print(f'need to replace {n_replaces} tokens, {cnt} tokens per travel step')
# token inverse
text_rev = token_to_text(clip_model, c_tokens)
log_fh.write(f'tokens: {text_rev}\n')
tokens = text_rev.split(' ')
n_tokens = len(tokens)
def _replace_tokens(sorted_indexes, c_tokens, tgt_c_tokens):
nonlocal mask, cnt
k, done = 0, 0
while done < cnt and sum(mask) > 0 and k < len(sorted_indexes):
idx = sorted_indexes[k]
if mask[idx]:
c_tokens[0][idx] = tgt_c_tokens[0][idx]
done += 1
k += 1
if replace_order.startswith('grad'):
with devices.autocast():
current_cond = mlc_get_cond(c).unsqueeze(0).clone() # [B=1, T=77, D=768]
current_cond .requires_grad = True
target_cond .requires_grad = True
target_latent.requires_grad = True
loss_latent = get_latent_loss(p.sd_model, target_latent, current_cond) # [B=1, C=4, H=64, W=64]
grad_latent = torch.autograd.grad(loss_latent, current_cond, grad_outputs=loss_latent)[0] # [B=1, T=77, D=768]
loss_cond = F.l1_loss(current_cond, target_cond, reduction='none') # [B=1, T=77, D=768]
grad_cond = torch.autograd.grad(loss_cond, current_cond, grad_outputs=loss_cond)[0] # [B=1, T=77, D=768]
grad = grad_latent * grad_w_latent + grad_cond * grad_w_cond # [B=1, T=77, D=768]
grad_trim = grad.squeeze(dim=0)[1:-1, :] # [T=75, D=768]
grad_token = grad_trim.mean(dim=1) # [T=75]
sorted_indexes_grad_ascending = grad_token.argsort().tolist()
if replace_order == 'grad_min':
sorted_indexes = sorted_indexes_grad_ascending
elif replace_order == 'grad_max':
sorted_indexes = sorted_indexes_grad_ascending[::-1]
_replace_tokens(sorted_indexes, c_tokens, tgt_c_tokens)
else:
sorted_indexes_L1_ascending = L1_dist.argsort()
if replace_order == 'random':
neq_indexes = [i for i, m in enumerate(mask) if m]
random.shuffle(neq_indexes)
_replace_tokens(neq_indexes, c_tokens, tgt_c_tokens)
else:
if replace_order == 'most_similar':
sorted_indexes = sorted_indexes_L1_ascending
elif replace_order == 'most_different':
sorted_indexes = sorted_indexes_L1_ascending[::-1]
_replace_tokens(sorted_indexes, c_tokens, tgt_c_tokens)
# log token importance (?
if replace_order.startswith('grad'):
tokens_grad_asc = [tokens[idx] for idx in sorted_indexes_grad_ascending if idx < len(tokens)]
log_fh.write(f' >> grad ascend: {" ".join(tokens_grad_asc)}\n')
else:
tokens_l1_asc = [tokens[idx] for idx in sorted_indexes_L1_ascending if idx < len(tokens)]
log_fh.write(f' >> embed L1-distance ascend: {" ".join(tokens_l1_asc)}\n')
log_fh.flush()
# move to new 'c' (one travel step!)
# FIXME: we do not walk on 'uc' so far
cond = token_to_cond(clip_model, c_weights, c_tokens, used_custom_terms, hijack_comments, hijack_fixes)
c = mlc_replace_cond(c, cond.detach().squeeze(0))
proc = process_images_cond_to_image(p, c, uc, prompts, seeds, subseeds)
if initial_info is None: initial_info = proc.info
images += proc.images
log_fh.write('\n')
# append the finishing image for current stage
images += [target_image]
# shift: last stage's final info becomes new stage's init info
c, uc, prompts, seeds, subseeds = params
c_tokens, c_weights, uc_tokens, uc_weights = tgt_c_tokens, tgt_c_weights, tgt_uc_tokens, tgt_uc_weights
# save log
log_fh.close()
return images, initial_info
@ -421,34 +720,11 @@ class Script(scripts.Script):
initial_info = None
images = []
def image_to_latent(img: Image) -> torch.Tensor:
nonlocal p
model = p.sd_model
im = np.array(img).astype(np.uint8)
im = (im / 127.5 - 1.0).astype(np.float32)
x = torch.from_numpy(im)
x = torch.moveaxis(x, 2, 0)
x = x.unsqueeze(dim=0) # [B=1, C=3, H=512, W=512]
x = x.to(model.device)
latent = model.get_first_stage_encoding(model.encode_first_stage(x)) # [B=1, C=4, H=64, W=64]
return latent
def mlc_get_cond(c:MulticondLearnedConditioning) -> torch.Tensor:
return c.batch[0][0].schedules[0].cond # [B=1, T=77, D=768]
def mlc_replace_cond(c:MulticondLearnedConditioning, cond: torch.Tensor) -> MulticondLearnedConditioning:
r = deepcopy(c)
spc = r.batch[0][0].schedules[0]
r.batch[0][0].schedules[0] = ScheduledPromptConditioning(spc.end_at_step, cond)
return r
# Step 1: draw init image
if show_debug: print(f'[stage 1/{n_stages}] prompts: {pos_prompts[0]}')
p.prompt = pos_prompts[0]
c, uc, prompts, seeds, subseeds = process_images_inner_half_A(p)
proc = process_images_inner_half_B(p, c, uc, prompts, seeds, subseeds)
c, uc, prompts, seeds, subseeds = process_images_prompt_to_cond(p)
proc = process_images_cond_to_image(p, c, uc, prompts, seeds, subseeds)
if initial_info is None: initial_info = proc.info
images += proc.images
@ -468,14 +744,14 @@ class Script(scripts.Script):
# Step 2: draw the stage target
if show_debug: print(f'[stage {i+1}/{n_stages}] prompts: {pos_prompts[i]}')
p.prompt = pos_prompts[i]
params = process_images_inner_half_A(p)
proc = process_images_inner_half_B(p, *params)
params = process_images_prompt_to_cond(p)
proc = process_images_cond_to_image(p, *params)
if initial_info is None: initial_info = proc.info
target_image = proc.images[0] # cache it here to make video sequence order right
with torch.no_grad(), devices.autocast():
target_latent = image_to_latent(target_image) # [B=1, C=4, H=64, W=64]
target_cond = mlc_get_cond(params[0]).unsqueeze(0) # [B=1, T=77, D=768]
target_latent = image_to_latent(p.sd_model, target_image) # [B=1, C=4, H=64, W=64]
target_cond = mlc_get_cond(params[0]).unsqueeze(0) # [B=1, T=77, D=768]
source_cond = mlc_get_cond(c).unsqueeze(0)
L1_dist = F.l1_loss(source_cond, target_cond).item()
@ -509,28 +785,27 @@ class Script(scripts.Script):
}
current_cond = current_cond.detach() - methods[grad_meth](grad) * grad_alpha
if show_debug:
with torch.no_grad():
l_latent = loss_latent.mean().item()
l_match = loss_cond .mean().item()
l_total = l_latent + l_match
L1_from = F.l1_loss(current_cond, source_cond).item()
L1_to = F.l1_loss(current_cond, target_cond).item() # FIXME: stop early when L1_to < grad_alpha / 2
grad_abs = grad.abs()
info = [
f'loss: {l_total} (l_grad: {l_latent}, l_match: {l_match})',
f' |grad|.avg: {grad_abs.mean()}, |grad|.max: {grad_abs.max()}',
f' L1 from src: {L1_from}, L1 to dst: {L1_to}, L1 total: {L1_dist}',
]
log_fh.write('\n'.join(info))
log_fh.write('\n')
log_fh.flush()
with torch.no_grad():
l_latent = loss_latent.mean().item()
l_match = loss_cond .mean().item()
l_total = l_latent + l_match
L1_from = F.l1_loss(current_cond, source_cond).item()
L1_to = F.l1_loss(current_cond, target_cond).item() # FIXME: stop early when L1_to < grad_alpha / 2
grad_abs = grad.abs()
info = [
f'loss: {l_total} (l_grad: {l_latent}, l_match: {l_match})',
f' |grad|.avg: {grad_abs.mean()}, |grad|.max: {grad_abs.max()}',
f' L1 from src: {L1_from}, L1 to dst: {L1_to}, L1 total: {L1_dist}',
]
log_fh.write('\n'.join(info))
log_fh.write('\n')
log_fh.flush()
# move to new 'c' (one travel step!)
# FIXME: we do not walk on 'uc' so far
c = mlc_replace_cond(c, current_cond.detach().squeeze(0))
proc = process_images_inner_half_B(p, c, uc, prompts, seeds, subseeds)
proc = process_images_cond_to_image(p, c, uc, prompts, seeds, subseeds)
if initial_info is None: initial_info = proc.info
images += proc.images