parent
6be9c36936
commit
081697ff19
25
README.md
|
|
@ -8,7 +8,7 @@ This is the more human-sensible version of [stable-diffusion-webui-prompt-erosio
|
|||
Now we do not modify the prompt at the text character level, but linearly interpolate the hidden embedding vectors. 😀
|
||||
|
||||
⚠ We have set up a QQ group for plugin feedback: 616795645 (赤狐屿); suggestions, opinions, bug reports and so on are welcome (w
|
||||
⚠ We have a QQ chat group now: 616795645, any suggestion, discussion and bug reports are highly welcome!!
|
||||
⚠ We have a QQ chat group now: 616795645, any suggestions, discussions and bug reports are highly welcome!!
|
||||
|
||||
ℹ Half-jokingly, I think it might be possible to use this to make PPT-style fairy-tale picture books<del> or even doujinshi</del>...
|
||||
ℹ A smart way to use it: first manually search for two good-looking images (differing only in the prompt), then try to travel between them :lolipop:
|
||||
|
|
@ -16,7 +16,7 @@ now we do not modify on text char level, but do linear interpolating on the hidd
|
|||
|
||||
### Change Log
|
||||
|
||||
- 2022/11/14: walk by substituting word embedding ('replace' mode)
|
||||
- 2022/11/14: walk by substituting token embedding ('replace' mode)
|
||||
- 2022/11/13: walk by optimizing condition ('grad' mode)
|
||||
- 2022/11/10: interpolate linearly on condition/uncondition ('linear' mode)
|
||||
|
||||
|
|
@ -47,17 +47,19 @@ now we do not modify on text char level, but do linear interpolating on the hidd
|
|||
- mode: (categorical)
|
||||
- linear: interpolate linearly on condition/uncondition in latent space
|
||||
- replace: walk by gradually substituting word embeddings
|
||||
- grad: walk by optimizing certain loss (see [Experimental](#experimental))
|
||||
- NOTE: `walk` methods might not reach the target stage within the specified steps; in that case, manually tune `grad_alpha` or increase `steps` according to the logged losses...
|
||||
- grad: walk by optimizing certain loss
|
||||
- NOTE: `walk` methods might sometimes not reach the target stage within the specified steps, or might reach it earlier than expected; in that case, manually tuning `grad_alpha` and `steps` might help a little...
|
||||
- steps: (int, list of int)
|
||||
- number of images to interpolate between two successive stages<del>, set `-1` to allow wandering until convergence for `walk` methods (not yet implemented)</del>
|
||||
- number of images to interpolate between two successive stages
|
||||
- if int, constant number of travel steps
|
||||
- if list of int, the length should match `len(stages)-1`, separated by commas, e.g.: `12, 24, 36`
|
||||
- replace_*
|
||||
- replace_order: (categorical)
|
||||
- `random`: substitute tokens randomly
|
||||
- `similar`: substitute the most similar tokens first (L1 distance of token embeddings)
|
||||
- `different`: substitute the most different tokens first (L1 distance of token embeddings)
|
||||
- `different`: substitute the most different tokens first
|
||||
- `grad_min`: substitute the tokens that cause the smallest gradient first (gradient settings are the same as in `grad` mode)
|
||||
- `grad_max`: substitute the tokens that cause the largest gradient first
|
||||
- grad_*
|
||||
- grad_alpha: (float), step size of a walk pace
|
||||
- grad_iter: (int), step count of walk paces
|
||||
|
|
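As a rough illustration of the `linear` mode and the `steps` option described in this hunk, the sketch below expands `steps` to one count per pair of successive stages and interpolates between stage vectors. It uses plain Python lists instead of the real condition tensors, and the names `normalize_steps`, `lerp` and `travel_linear` as well as the even spacing `t / (n + 1)` are assumptions made for this example, not code from the extension.

```python
from typing import List, Sequence, Union

def normalize_steps(steps: Union[int, Sequence[int]], n_stages: int) -> List[int]:
    """Expand `steps` to one entry per pair of successive stages."""
    n_gaps = n_stages - 1
    if isinstance(steps, int):
        return [steps] * n_gaps          # constant number of travel steps
    steps = list(steps)
    if len(steps) != n_gaps:             # README: length should match len(stages)-1
        raise ValueError(f'expected {n_gaps} step counts, got {len(steps)}')
    return steps

def lerp(a: Sequence[float], b: Sequence[float], alpha: float) -> List[float]:
    """Element-wise (1 - alpha) * a + alpha * b."""
    return [(1 - alpha) * x + alpha * y for x, y in zip(a, b)]

def travel_linear(stages: Sequence[Sequence[float]],
                  steps: Union[int, Sequence[int]]) -> List[List[float]]:
    """Return every stage plus the interpolated vectors between successive stages."""
    per_gap = normalize_steps(steps, len(stages))
    out = [list(stages[0])]
    for i, n in enumerate(per_gap):
        for t in range(1, n + 1):
            # assumed spacing: n evenly spaced points strictly between the two stages
            out.append(lerp(stages[i], stages[i + 1], t / (n + 1)))
        out.append(list(stages[i + 1]))
    return out
```

With two stages and `steps=3`, this yields the two endpoints and three in-between vectors; with four stages, a list such as `12, 24, 36` gives each gap its own count.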
@ -68,9 +70,9 @@ now we do not modify on text char level, but do linear interpolating on the hidd
|
|||
- `sign`: walk at a constant speed (often gets stuck oscillating at the end)
|
||||
- `tanh`: slows down significantly when approaching the target (it takes infinite time to reach it exactly...)
|
||||
- grad_w_latent: (float), weight factor of `loss_latent`
|
||||
- grad_w_match: (float), weight factor of `loss_cond`
|
||||
- grad_w_cond: (float), weight factor of `loss_cond`
|
||||
- fps: (float)
|
||||
- FPS of video, set 0 to disable saving
|
||||
- FPS of video, set `0` to disable saving
|
||||
- debug: (bool)
|
||||
- whether to show verbose debug info in the console
|
||||
|
||||
|
|
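The step methods listed above can be pictured with a small sketch of one walk pace. The update rule mirrors the line `current_cond.detach() - methods[grad_meth](grad) * grad_alpha` from the script in this commit, but the `methods` table itself is not visible in the diff, so the three functions below (in particular clipping to `[-1, 1]` for `clip`) are assumptions, and plain floats stand in for tensors.

```python
import math
from typing import Callable, Dict, List, Sequence

# Hypothetical stand-ins for the script's `methods` table (not shown in the diff).
STEP_METHODS: Dict[str, Callable[[float], float]] = {
    'clip': lambda g: max(-1.0, min(1.0, g)),      # assumed clipping range [-1, 1]
    'sign': lambda g: float((g > 0) - (g < 0)),    # constant-size step, direction only
    'tanh': math.tanh,                             # step shrinks as the gradient shrinks
}

def walk_step(current: Sequence[float], grad: Sequence[float],
              grad_alpha: float, grad_meth: str) -> List[float]:
    """One walk pace: current - method(grad) * grad_alpha, element-wise."""
    f = STEP_METHODS[grad_meth]
    return [c - f(g) * grad_alpha for c, g in zip(current, grad)]
```

Under these assumptions, `sign` always moves by exactly `grad_alpha` per element, which is consistent with the oscillation noted in the README, while `tanh` takes ever smaller steps as the gradient approaches zero.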
@ -117,6 +119,13 @@ Grid search results: (`steps=100, grad_alpha=0.01, grad_iter=1, grad_meth='clip'
|
|||
|
||||
ℹ NOTE: When 'prompt' has only a single line, it will wander just **around** the init stage, dynamically balancing `loss_latent` and `loss_cond`; this allows you to discover neighbors of your given prompt 😀
|
||||
|
||||
⚪ 'replace' mode
|
||||
|
||||
This mode works at the token embedding input level, hence you can view `log.txt` to see how your input tokens are gradually changed.
|
||||
⚠ Remember that a comma is a normal valid token, so you might see many commas there. However, they are different when appearing at different positions within the token sequence.
|
||||
|
||||
The actual token replacement order might reveal some information about token importance; the listed '>> grad ascend' or '>> embed L1-distance ascend' entries may give you some ideas for tuning your input prompt (I hope so..)
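To make the replacement order more concrete, here is a minimal sketch of one travel step in 'replace' mode, loosely following the `_replace_tokens` helper and the L1-distance ordering in this commit. Token ids are plain integers and embeddings are short float lists; the function names `order_by_l1` and `replace_step` are hypothetical and chosen only for this example.

```python
from typing import List, Sequence

def order_by_l1(src_embed: Sequence[Sequence[float]],
                tgt_embed: Sequence[Sequence[float]],
                most_different: bool = False) -> List[int]:
    """Sort token positions by mean L1 distance between source and target embeddings."""
    dist = [sum(abs(a - b) for a, b in zip(s, t)) / len(s)
            for s, t in zip(src_embed, tgt_embed)]
    order = sorted(range(len(dist)), key=dist.__getitem__)   # ascending distance
    return order[::-1] if most_different else order

def replace_step(src: Sequence[int], tgt: Sequence[int],
                 order: Sequence[int], cnt: int) -> List[int]:
    """Overwrite up to `cnt` differing positions of `src` with `tgt`, following `order`."""
    out = list(src)
    done = 0
    for idx in order:
        if done >= cnt:
            break
        if out[idx] != tgt[idx]:     # positions that already match are skipped
            out[idx] = tgt[idx]
            done += 1
    return out
```

In the script, the per-step count is computed as `max(n_replaces // steps[i], 1)`, so at least one differing token is swapped in each travel step until the source and target token sequences match.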
|
||||
|
||||
----
|
||||
|
||||
by Armit
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import launch
|
||||
|
||||
if not launch.is_installed("moviepy"):
|
||||
launch.run_pip("install moviepy==1.0.3", "requirements for Seed Travel")
|
||||
launch.run_pip("install moviepy==1.0.3", "requirements for Prompt Travel to generate video")
|
||||
|
|
|
|||
|
|
@ -1,38 +1,44 @@
|
|||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from copy import deepcopy
|
||||
from PIL import Image
|
||||
from typing import List, Tuple, Union
|
||||
from traceback import print_exc
|
||||
|
||||
import gradio as gr
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
try:
|
||||
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
|
||||
except ImportError:
|
||||
print(f"moviepy python module not installed. Will not be able to generate video.")
|
||||
try: from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
|
||||
except ImportError: print(f"package moviepy not installed, will not be able to generate video")
|
||||
|
||||
import modules.scripts as scripts
|
||||
from modules.processing import Processed, StableDiffusionProcessing
|
||||
from modules.processing import *
|
||||
from modules.prompt_parser import ScheduledPromptConditioning, MulticondLearnedConditioning
|
||||
from modules.shared import state
|
||||
|
||||
DEFAULT_MODE = 'linear'
|
||||
DEFAULT_STEPS = 30
|
||||
DEFAULT_REPLACE_ORDER = 'random'
|
||||
DEFAULT_REPLACE_ORDER = 'grad_min'
|
||||
DEFAULT_GRAD_ALPHA = 0.01
|
||||
DEFAULT_GRAD_ITERS = 1
|
||||
DEFAULT_GRAD_ITER = 1
|
||||
DEFAULT_GRAD_METH = 'clip'
|
||||
DEFAULT_GRAD_W_LATENT = 1
|
||||
DEFAULT_GRAD_W_COND = 1
|
||||
DEFAULT_GRAD_W_LATENT = 1
|
||||
DEFAULT_GRAD_W_COND = 1
|
||||
DEFAULT_FPS = 10
|
||||
DEFAULT_DEBUG = True
|
||||
|
||||
CHOICES_MODE = ['linear', 'replace', 'grad']
|
||||
CHOICES_REPLACE_ORDER = ['random', 'similar', 'different']
|
||||
CHOICES_REPLACE_ORDER = ['random', 'most_similar', 'most_different', 'grad_min', 'grad_max']
|
||||
CHOICES_GRAD_METH = ['clip', 'sign', 'tanh']
|
||||
|
||||
T_tokens = List[List[float]]
|
||||
T_weights = List[List[int]]
|
||||
|
||||
# ↓↓↓ the following is modified from 'modules/processing.py' ↓↓↓
|
||||
|
||||
from modules.processing import apply_overlay, apply_color_correction, create_infotext, decode_first_stage, get_fixed_seed
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
|
@ -46,7 +52,7 @@ import modules.face_restoration
|
|||
import modules.images as images
|
||||
import modules.styles
|
||||
|
||||
def process_images_inner_half_A(p: StableDiffusionProcessing) -> tuple:
|
||||
def process_images_prompt_to_cond(p: StableDiffusionProcessing, ret_token_and_weight=False) -> tuple:
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
|
||||
assert p.prompt is not None
|
||||
|
|
@ -91,12 +97,16 @@ def process_images_inner_half_A(p: StableDiffusionProcessing) -> tuple:
|
|||
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
|
||||
|
||||
with devices.autocast():
|
||||
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
|
||||
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
|
||||
# 'prompt string' => tensor([T, D])
|
||||
uc, uc_tokens, uc_weights = get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
|
||||
c, c_tokens, c_weights = get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
|
||||
|
||||
if ret_token_and_weight:
|
||||
return c, uc, prompts, seeds, subseeds, (c_tokens, c_weights, uc_tokens, uc_weights)
|
||||
else:
|
||||
return c, uc, prompts, seeds, subseeds
|
||||
|
||||
def process_images_inner_half_B(p: StableDiffusionProcessing, c, uc, prompts, seeds, subseeds):
|
||||
def process_images_cond_to_image(p: StableDiffusionProcessing, c, uc, prompts, seeds, subseeds) -> Processed:
|
||||
comments = {}
|
||||
infotexts = []
|
||||
output_images = []
|
||||
|
|
@ -179,13 +189,57 @@ def process_images_inner_half_B(p: StableDiffusionProcessing, c, uc, prompts, se
|
|||
# ↑↑↑ the above is modified from 'modules/processing.py' ↑↑↑
|
||||
|
||||
|
||||
# ↓↓↓ the following is modified from 'modules/prompt_parser.py' ↓↓↓
|
||||
|
||||
from modules.prompt_parser import get_learned_conditioning_prompt_schedules, get_multicond_prompt_list
|
||||
from modules.prompt_parser import ComposableScheduledPromptConditioning
|
||||
|
||||
def get_learned_conditioning(model, prompts, steps) -> Tuple[ScheduledPromptConditioning, T_tokens, T_weights]:
|
||||
res = []
|
||||
|
||||
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
|
||||
cache = {}
|
||||
|
||||
for prompt, prompt_schedule in zip(prompts, prompt_schedules): # forced to be lengthed 1
|
||||
cached = cache.get(prompt, None)
|
||||
if cached is not None:
|
||||
res.append(cached)
|
||||
continue
|
||||
|
||||
texts = [x[1] for x in prompt_schedule]
|
||||
conds, tokens, weights = LatentDiffusion_get_learned_conditioning(model, texts)
|
||||
|
||||
cond_schedule = []
|
||||
for i, (end_at_step, text) in enumerate(prompt_schedule):
|
||||
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
||||
|
||||
cache[prompt] = cond_schedule
|
||||
res.append(cond_schedule)
|
||||
|
||||
return res, tokens, weights
|
||||
|
||||
def get_multicond_learned_conditioning(model, prompts, steps) -> Tuple[MulticondLearnedConditioning, T_tokens, T_weights]:
|
||||
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
||||
|
||||
learned_conditioning, tokens, weights = get_learned_conditioning(model, prompt_flat_list, steps)
|
||||
|
||||
res = []
|
||||
for indexes in res_indexes:
|
||||
res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
|
||||
|
||||
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res), tokens, weights
|
||||
|
||||
# ↑↑↑ the above is modified from 'modules/prompt_parser.py' ↑↑↑
|
||||
|
||||
|
||||
# ↓↓↓ the following is modified from 'ldm.models.diffusion/ddpm.py' ↓↓↓
|
||||
|
||||
def get_latent_loss(sd_model, latent:torch.Tensor, cond:torch.Tensor) -> torch.Tensor:
|
||||
#from ldm.models.diffusion import LatentDiffusion
|
||||
# type(sd_model) == LatentDiffusion
|
||||
from modules.sd_hijack import FrozenCLIPEmbedderWithCustomWords
|
||||
#from ldm.models.diffusion import LatentDiffusion
|
||||
#from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
# forward(self, x, c, *args, **kwargs)
|
||||
def get_latent_loss(sd_model, latent:torch.Tensor, cond:torch.Tensor) -> torch.Tensor:
|
||||
# forward(self:LatentDiffusion, x, c, *args, **kwargs)
|
||||
self, x, c = sd_model, latent, cond
|
||||
|
||||
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() # [B=1]
|
||||
|
|
@ -197,7 +251,7 @@ def get_latent_loss(sd_model, latent:torch.Tensor, cond:torch.Tensor) -> torch.T
|
|||
tc = self.cond_ids[t].to(self.device)
|
||||
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
||||
|
||||
# p_losses(self, x_start, cond, t, noise=None)
|
||||
# p_losses(self:LatentDiffusion, x_start, cond, t, noise=None)
|
||||
x_start, cond, t = x, c, t
|
||||
|
||||
noise = torch.randn_like(x_start) # [B=1, C=4, H=64, W=64]
|
||||
|
|
@ -219,9 +273,125 @@ def get_latent_loss(sd_model, latent:torch.Tensor, cond:torch.Tensor) -> torch.T
|
|||
|
||||
return loss # [B=1, C=4, H=64, W=64]
|
||||
|
||||
def text_to_token(self:FrozenCLIPEmbedderWithCustomWords, text:List[str]) -> tuple:
|
||||
# FrozenCLIPEmbedderWithCustomWords.FrozenCLIPEmbedder.CLIPTokenizer
|
||||
|
||||
with devices.autocast('cuda'):
|
||||
if opts.use_old_emphasis_implementation:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||
else:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def token_to_cond(self:FrozenCLIPEmbedderWithCustomWords, batch_multipliers:T_weights, remade_batch_tokens:T_tokens, used_custom_terms, hijack_comments, hijack_fixes) -> torch.Tensor:
|
||||
self.hijack.comments += hijack_comments
|
||||
|
||||
if len(used_custom_terms) > 0:
|
||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||
|
||||
if opts.use_old_emphasis_implementation:
|
||||
self.hijack.fixes = hijack_fixes
|
||||
with torch.no_grad(), devices.autocast():
|
||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||
|
||||
# allow length > 75
|
||||
z = None
|
||||
i = 0
|
||||
while max(map(len, remade_batch_tokens)) != 0:
|
||||
rem_tokens = [x[75:] for x in remade_batch_tokens]
|
||||
rem_multipliers = [x[75:] for x in batch_multipliers]
|
||||
|
||||
self.hijack.fixes = []
|
||||
for unfiltered in hijack_fixes:
|
||||
fixes = []
|
||||
for fix in unfiltered:
|
||||
if fix[0] == i:
|
||||
fixes.append(fix[1])
|
||||
self.hijack.fixes.append(fixes)
|
||||
|
||||
tokens = []
|
||||
multipliers = []
|
||||
for j in range(len(remade_batch_tokens)):
|
||||
if len(remade_batch_tokens[j]) > 0:
|
||||
tokens.append(remade_batch_tokens[j][:75])
|
||||
multipliers.append(batch_multipliers[j][:75])
|
||||
else:
|
||||
tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
|
||||
multipliers.append([1.0] * 75)
|
||||
|
||||
with torch.no_grad(), devices.autocast():
|
||||
z1 = self.process_tokens(tokens, multipliers)
|
||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||
|
||||
remade_batch_tokens = rem_tokens
|
||||
batch_multipliers = rem_multipliers
|
||||
i += 1
|
||||
|
||||
return z
|
||||
|
||||
def token_to_text(self:FrozenCLIPEmbedderWithCustomWords, tokens:T_tokens) -> str:
|
||||
id_2_word = { v: k for k, v in self.wrapped.tokenizer.get_vocab().items() }
|
||||
|
||||
words = []
|
||||
for tk in tokens[0]: # force B=1
|
||||
w = id_2_word.get(tk, '<unk>')
|
||||
if w in ['<|startoftext|>', '<|endoftext|>']: continue
|
||||
if w.endswith('</w>'): w = w[:-4]
|
||||
words.append(w)
|
||||
|
||||
return ' '.join(words)
|
||||
|
||||
def FrozenCLIPEmbedderWithCustomWords_forward(self:FrozenCLIPEmbedderWithCustomWords, text:List[str]) -> Tuple[torch.Tensor, T_tokens, T_weights]:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = text_to_token(self, text)
|
||||
cond = token_to_cond(self, batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes)
|
||||
|
||||
return cond, remade_batch_tokens, batch_multipliers
|
||||
|
||||
def LatentDiffusion_get_learned_conditioning(sd_model, cond:List[str]) -> Tuple[torch.Tensor, T_tokens, T_weights]:
|
||||
self, c = sd_model, cond
|
||||
|
||||
if self.cond_stage_forward is None:
|
||||
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
|
||||
c = self.cond_stage_model.encode(c)
|
||||
if 'DiagonalGaussianDistribution' in str(type(c)):
|
||||
c = c.mode()
|
||||
else:
|
||||
#c = self.cond_stage_model(c) # => goes this way, [B=1, T=77*n, D=768], [[]], [[]]
|
||||
c, tokens, weights = FrozenCLIPEmbedderWithCustomWords_forward(self.cond_stage_model, c)
|
||||
else:
|
||||
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
||||
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
||||
|
||||
return c, tokens, weights
|
||||
|
||||
# ↑↑↑ the above is modified from 'ldm.models.diffusion/ddpm.py' ↑↑↑
|
||||
|
||||
|
||||
def image_to_latent(model, img: Image) -> torch.Tensor:
|
||||
#from ldm.models.diffusion import LatentDiffusion
|
||||
# type(model) == LatentDiffusion
|
||||
|
||||
im = np.array(img).astype(np.uint8)
|
||||
im = (im / 127.5 - 1.0).astype(np.float32)
|
||||
x = torch.from_numpy(im)
|
||||
x = torch.moveaxis(x, 2, 0)
|
||||
x = x.unsqueeze(dim=0) # [B=1, C=3, H=512, W=512]
|
||||
x = x.to(model.device)
|
||||
|
||||
latent = model.get_first_stage_encoding(model.encode_first_stage(x)) # [B=1, C=4, H=64, W=64]
|
||||
return latent
|
||||
|
||||
def mlc_get_cond(c:MulticondLearnedConditioning) -> torch.Tensor:
|
||||
return c.batch[0][0].schedules[0].cond # [B=1, T=77, D=768]
|
||||
|
||||
def mlc_replace_cond(c:MulticondLearnedConditioning, cond: torch.Tensor) -> MulticondLearnedConditioning:
|
||||
r = deepcopy(c)
|
||||
spc = r.batch[0][0].schedules[0]
|
||||
r.batch[0][0].schedules[0] = ScheduledPromptConditioning(spc.end_at_step, cond)
|
||||
return r
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
|
||||
def title(self):
|
||||
|
|
@ -239,10 +409,10 @@ class Script(scripts.Script):
|
|||
|
||||
replace_order = gr.Dropdown(label='Replace order (replace mode)', value=lambda: DEFAULT_REPLACE_ORDER, choices=CHOICES_REPLACE_ORDER)
|
||||
grad_alpha = gr.Number (label='Step size (grad mode)', value=lambda: DEFAULT_GRAD_ALPHA)
|
||||
grad_iter = gr.Number (label='Step count (grad mode)', value=lambda: DEFAULT_GRAD_ITERS, precision=0)
|
||||
grad_iter = gr.Number (label='Step count (grad mode)', value=lambda: DEFAULT_GRAD_ITER, precision=0)
|
||||
grad_meth = gr.Dropdown(label='Step method (grad mode)', value=lambda: DEFAULT_GRAD_METH, choices=CHOICES_GRAD_METH)
|
||||
grad_w_latent = gr.Number (label='Loss for latent match (grad mode)', value=lambda: DEFAULT_GRAD_W_LATENT)
|
||||
grad_w_cond = gr.Number (label='Loss for cond match (grad mode)', value=lambda: DEFAULT_GRAD_W_COND)
|
||||
grad_w_latent = gr.Number (label='Weight for latent match (grad/replace-grad mode)', value=lambda: DEFAULT_GRAD_W_LATENT)
|
||||
grad_w_cond = gr.Number (label='Weight for cond match (grad/replace-grad mode)', value=lambda: DEFAULT_GRAD_W_COND)
|
||||
|
||||
video_fps = gr.Number (label='Video FPS', value=lambda: DEFAULT_FPS)
|
||||
show_debug = gr.Checkbox(label='Show verbose debug info at console', value=lambda: DEFAULT_DEBUG)
|
||||
|
|
@ -253,11 +423,7 @@ class Script(scripts.Script):
|
|||
video_fps, show_debug]
|
||||
|
||||
def get_next_sequence_number(path):
|
||||
from pathlib import Path
|
||||
"""
|
||||
Determines and returns the next sequence number to use when saving an image in the specified directory.
|
||||
The sequence starts at 0.
|
||||
"""
|
||||
""" Determines and returns the next sequence number to use when saving an image in the specified directory. The sequence starts at 0. """
|
||||
result = -1
|
||||
dir = Path(path)
|
||||
for file in dir.iterdir():
|
||||
|
|
@ -305,12 +471,12 @@ class Script(scripts.Script):
|
|||
travel_path = os.path.join(p.outpath_samples, 'prompt_travel')
|
||||
os.makedirs(travel_path, exist_ok=True)
|
||||
travel_number = Script.get_next_sequence_number(travel_path)
|
||||
travel_path = os.path.join(travel_path, f"{travel_number:05}")
|
||||
p.outpath_samples = travel_path
|
||||
os.makedirs(travel_path, exist_ok=True)
|
||||
self.log_fp = os.path.join(travel_path, 'log.txt')
|
||||
self.log_dp = os.path.join(travel_path, f"{travel_number:05}")
|
||||
p.outpath_samples = self.log_dp
|
||||
os.makedirs(self.log_dp, exist_ok=True)
|
||||
self.log_fp = os.path.join(self.log_dp, 'log.txt')
|
||||
|
||||
# Force Batch Count and Batch Size to 1.
|
||||
# Force Batch Count and Batch Size to 1
|
||||
p.n_iter = 1
|
||||
p.batch_size = 1
|
||||
|
||||
|
|
@ -326,15 +492,16 @@ class Script(scripts.Script):
|
|||
|
||||
# Implementation dispatcher
|
||||
if mode == 'linear' : images, info = self.run_linear (p, pos_prompts, neg_prompts, steps, show_debug)
|
||||
elif mode == 'replace': images, info = self.run_replace(p, pos_prompts, neg_prompts, steps, replace_order, show_debug)
|
||||
elif mode == 'replace': images, info = self.run_replace(p, pos_prompts, neg_prompts, steps, replace_order, grad_w_latent, grad_w_cond, show_debug)
|
||||
elif mode == 'grad' : images, info = self.run_grad (p, pos_prompts, neg_prompts, steps, grad_alpha, grad_iter, grad_meth, grad_w_latent, grad_w_cond, show_debug)
|
||||
|
||||
# Save video
|
||||
if video_fps > 0 and len(images) > 1:
|
||||
try:
|
||||
clip = ImageSequenceClip([np.asarray(t) for t in images], fps=video_fps)
|
||||
clip.write_videofile(os.path.join(travel_path, f"travel-{travel_number:05}.mp4"), verbose=False, audio=False)
|
||||
except: pass
|
||||
clip.write_videofile(os.path.join(self.log_dp, f"travel-{travel_number:05}.mp4"), verbose=False, audio=False)
|
||||
except NameError: pass
|
||||
except: print_exc()
|
||||
|
||||
return Processed(p, images, p.seed, info)
|
||||
|
||||
|
|
@ -343,26 +510,26 @@ class Script(scripts.Script):
|
|||
initial_info = None
|
||||
images = []
|
||||
|
||||
def weighted_sum(A, B, alpha, kind):
|
||||
''' linear interpolate on latent space '''
|
||||
def weighted_sum(A, B, alpha:float, kind:str) -> Union[ScheduledPromptConditioning, MulticondLearnedConditioning]:
|
||||
''' linear interpolate on latent space of condition '''
|
||||
C = deepcopy(A)
|
||||
if kind == 'pos':
|
||||
condA = A.batch[0][0].schedules[0].cond
|
||||
condB = B.batch[0][0].schedules[0].cond
|
||||
condC = (1 - alpha) * condA + alpha * condB
|
||||
condC = (1 - alpha) * condA + (alpha) * condB
|
||||
end_at_step = C.batch[0][0].schedules[0].end_at_step
|
||||
C.batch[0][0].schedules[0] = ScheduledPromptConditioning(end_at_step, condC)
|
||||
if kind == 'neg':
|
||||
condA = A[0][0].cond
|
||||
condB = B[0][0].cond
|
||||
condC = (1 - alpha) * condA + alpha * condB
|
||||
condC = (1 - alpha) * condA + (alpha) * condB
|
||||
end_at_step = C[0][0].end_at_step
|
||||
C[0][0] = ScheduledPromptConditioning(end_at_step, condC)
|
||||
return C
|
||||
|
||||
def draw_by_cond(pos_hidden, neg_hidden, prompts, seeds, subseeds):
|
||||
nonlocal images, initial_info, p
|
||||
proc = process_images_inner_half_B(p, pos_hidden, neg_hidden, prompts, seeds, subseeds)
|
||||
proc = process_images_cond_to_image(p, pos_hidden, neg_hidden, prompts, seeds, subseeds)
|
||||
if initial_info is None: initial_info = proc.info
|
||||
images += proc.images
|
||||
|
||||
|
|
@ -373,7 +540,7 @@ class Script(scripts.Script):
|
|||
print(f' neg prompts: {neg_prompts[0]}')
|
||||
p.prompt = pos_prompts[0]
|
||||
p.negative_prompt = neg_prompts[0]
|
||||
from_pos_hidden, from_neg_hidden, prompts, seeds, subseeds = process_images_inner_half_A(p)
|
||||
from_pos_hidden, from_neg_hidden, prompts, seeds, subseeds = process_images_prompt_to_cond(p)
|
||||
draw_by_cond(from_pos_hidden, from_neg_hidden, prompts, seeds, subseeds)
|
||||
|
||||
# travel through stages
|
||||
|
|
@ -387,7 +554,7 @@ class Script(scripts.Script):
|
|||
print(f' neg prompts: {neg_prompts[i]}')
|
||||
p.prompt = pos_prompts[i]
|
||||
p.negative_prompt = neg_prompts[i]
|
||||
to_pos_hidden, to_neg_hidden, prompts, seeds, subseeds = process_images_inner_half_A(p)
|
||||
to_pos_hidden, to_neg_hidden, prompts, seeds, subseeds = process_images_prompt_to_cond(p)
|
||||
|
||||
# Step 2: draw the interpolated images
|
||||
n_inter = steps[i] + 1
|
||||
|
|
@ -407,12 +574,144 @@ class Script(scripts.Script):
|
|||
|
||||
return images, initial_info
|
||||
|
||||
def run_replace(self, p:StableDiffusionProcessing, pos_prompts:List[str], neg_prompts:List[str], steps:List[int], replace_order:str, show_debug:bool):
|
||||
def run_replace(self, p:StableDiffusionProcessing, pos_prompts:List[str], neg_prompts:List[str], steps:List[int], replace_order:str, grad_w_latent:float, grad_w_cond:float, show_debug:bool):
|
||||
clip_model = p.sd_model.cond_stage_model
|
||||
n_stages = len(steps)
|
||||
initial_info = None
|
||||
images = []
|
||||
|
||||
initial_info = '你先别急,这个还没有实现……'  # "Hold on, this is not implemented yet..."
|
||||
# Step 1: draw init image
|
||||
if show_debug: print(f'[stage 1/{n_stages}] prompts: {pos_prompts[0]}')
|
||||
p.prompt = pos_prompts[0]
|
||||
c, uc, prompts, seeds, subseeds, (c_tokens, c_weights, uc_tokens, uc_weights) = process_images_prompt_to_cond(p, ret_token_and_weight=True)
|
||||
proc = process_images_cond_to_image(p, c, uc, prompts, seeds, subseeds)
|
||||
if initial_info is None: initial_info = proc.info
|
||||
images += proc.images
|
||||
|
||||
# make log
|
||||
log_fh = open(self.log_fp, 'w', encoding='utf-8')
|
||||
log_fh.write(f'replace_order = {replace_order}\n')
|
||||
log_fh.write('\n')
|
||||
|
||||
text_rev = token_to_text(clip_model, c_tokens)
|
||||
log_fh.write(f'tokens: {text_rev}\n')
|
||||
log_fh.write('\n')
|
||||
|
||||
# travel between stages
|
||||
for i in range(1, n_stages):
|
||||
if state.interrupted: break
|
||||
|
||||
# Step 2: draw the stage target
|
||||
if show_debug: print(f'[stage {i+1}/{n_stages}] prompts: {pos_prompts[i]}')
|
||||
p.prompt = pos_prompts[i]
|
||||
*params, (tgt_c_tokens, tgt_c_weights, tgt_uc_tokens, tgt_uc_weights) = process_images_prompt_to_cond(p, ret_token_and_weight=True)
|
||||
_, _, used_custom_terms, hijack_comments, hijack_fixes, _ = text_to_token(clip_model, [p.prompt])
|
||||
proc = process_images_cond_to_image(p, *params)
|
||||
if initial_info is None: initial_info = proc.info
|
||||
target_image = proc.images[0] # cache it here to make video sequence order right
|
||||
|
||||
with torch.no_grad(), devices.autocast():
|
||||
if replace_order.startswith('grad'):
|
||||
target_latent = image_to_latent(p.sd_model, target_image) # [B=1, C=4, H=64, W=64]
|
||||
target_cond = mlc_get_cond(params[0]).unsqueeze(0) # [B=1, T=77, D=768]
|
||||
else:
|
||||
embed_layer = clip_model.wrapped.transformer.get_input_embeddings() # transformers.models.clip.CLIPTextModel
|
||||
source_embed = embed_layer(torch.LongTensor(c_tokens) .to(p.sd_model.device)) # [B=1, T=75, D=768]
|
||||
target_embed = embed_layer(torch.LongTensor(tgt_c_tokens).to(p.sd_model.device))
|
||||
L1_dist = F.l1_loss(source_embed, target_embed, reduction='none').squeeze(dim=0).mean(dim=-1).cpu().numpy() # [T=75]
|
||||
|
||||
# Step 3: draw the inter-mediums
|
||||
for _ in range(steps[i]):
|
||||
if state.interrupted: break
|
||||
|
||||
mask = np.asarray(c_tokens[0]) != np.asarray(tgt_c_tokens[0]) # [T=75]
|
||||
n_replaces = sum(mask)
|
||||
if n_replaces == 0: break
|
||||
cnt = max(n_replaces // steps[i], 1)
|
||||
if show_debug: print(f'need to replace {n_replaces} tokens, {cnt} tokens per travel step')
|
||||
|
||||
# token inverse
|
||||
text_rev = token_to_text(clip_model, c_tokens)
|
||||
log_fh.write(f'tokens: {text_rev}\n')
|
||||
tokens = text_rev.split(' ')
|
||||
n_tokens = len(tokens)
|
||||
|
||||
def _replace_tokens(sorted_indexes, c_tokens, tgt_c_tokens):
|
||||
nonlocal mask, cnt
|
||||
k, done = 0, 0
|
||||
while done < cnt and sum(mask) > 0 and k < len(sorted_indexes):
|
||||
idx = sorted_indexes[k]
|
||||
if mask[idx]:
|
||||
c_tokens[0][idx] = tgt_c_tokens[0][idx]
|
||||
done += 1
|
||||
k += 1
|
||||
|
||||
if replace_order.startswith('grad'):
|
||||
with devices.autocast():
|
||||
current_cond = mlc_get_cond(c).unsqueeze(0).clone() # [B=1, T=77, D=768]
|
||||
|
||||
current_cond .requires_grad = True
|
||||
target_cond .requires_grad = True
|
||||
target_latent.requires_grad = True
|
||||
|
||||
loss_latent = get_latent_loss(p.sd_model, target_latent, current_cond) # [B=1, C=4, H=64, W=64]
|
||||
grad_latent = torch.autograd.grad(loss_latent, current_cond, grad_outputs=loss_latent)[0] # [B=1, T=77, D=768]
|
||||
loss_cond = F.l1_loss(current_cond, target_cond, reduction='none') # [B=1, T=77, D=768]
|
||||
grad_cond = torch.autograd.grad(loss_cond, current_cond, grad_outputs=loss_cond)[0] # [B=1, T=77, D=768]
|
||||
grad = grad_latent * grad_w_latent + grad_cond * grad_w_cond # [B=1, T=77, D=768]
|
||||
|
||||
grad_trim = grad.squeeze(dim=0)[1:-1, :] # [T=75, D=768]
|
||||
grad_token = grad_trim.mean(dim=1) # [T=75]
|
||||
|
||||
sorted_indexes_grad_ascending = grad_token.argsort().tolist()
|
||||
if replace_order == 'grad_min':
|
||||
sorted_indexes = sorted_indexes_grad_ascending
|
||||
elif replace_order == 'grad_max':
|
||||
sorted_indexes = sorted_indexes_grad_ascending[::-1]
|
||||
_replace_tokens(sorted_indexes, c_tokens, tgt_c_tokens)
|
||||
|
||||
else:
|
||||
sorted_indexes_L1_ascending = L1_dist.argsort()
|
||||
if replace_order == 'random':
|
||||
neq_indexes = [i for i, m in enumerate(mask) if m]
|
||||
random.shuffle(neq_indexes)
|
||||
_replace_tokens(neq_indexes, c_tokens, tgt_c_tokens)
|
||||
else:
|
||||
if replace_order == 'most_similar':
|
||||
sorted_indexes = sorted_indexes_L1_ascending
|
||||
elif replace_order == 'most_different':
|
||||
sorted_indexes = sorted_indexes_L1_ascending[::-1]
|
||||
_replace_tokens(sorted_indexes, c_tokens, tgt_c_tokens)
|
||||
|
||||
# log token importance (?
|
||||
if replace_order.startswith('grad'):
|
||||
tokens_grad_asc = [tokens[idx] for idx in sorted_indexes_grad_ascending if idx < len(tokens)]
|
||||
log_fh.write(f' >> grad ascend: {" ".join(tokens_grad_asc)}\n')
|
||||
else:
|
||||
tokens_l1_asc = [tokens[idx] for idx in sorted_indexes_L1_ascending if idx < len(tokens)]
|
||||
log_fh.write(f' >> embed L1-distance ascend: {" ".join(tokens_l1_asc)}\n')
|
||||
log_fh.flush()
|
||||
|
||||
# move to new 'c' (one travel step!)
|
||||
# FIXME: we do not walk on 'uc' so far
|
||||
cond = token_to_cond(clip_model, c_weights, c_tokens, used_custom_terms, hijack_comments, hijack_fixes)
|
||||
c = mlc_replace_cond(c, cond.detach().squeeze(0))
|
||||
|
||||
proc = process_images_cond_to_image(p, c, uc, prompts, seeds, subseeds)
|
||||
if initial_info is None: initial_info = proc.info
|
||||
images += proc.images
|
||||
|
||||
log_fh.write(f'\n')
|
||||
|
||||
# append the finishing image for current stage
|
||||
images += [target_image]
|
||||
|
||||
# shift: last stage's final info becomes new stage's init info
|
||||
c, uc, prompts, seeds, subseeds = params
|
||||
c_tokens, c_weights, uc_tokens, uc_weights = tgt_c_tokens, tgt_c_weights, tgt_uc_tokens, tgt_uc_weights
|
||||
|
||||
# save log
|
||||
log_fh.close()
|
||||
|
||||
return images, initial_info
|
||||
|
||||
|
|
@ -421,34 +720,11 @@ class Script(scripts.Script):
|
|||
initial_info = None
|
||||
images = []
|
||||
|
||||
def image_to_latent(img: Image) -> torch.Tensor:
|
||||
nonlocal p
|
||||
model = p.sd_model
|
||||
|
||||
im = np.array(img).astype(np.uint8)
|
||||
im = (im / 127.5 - 1.0).astype(np.float32)
|
||||
x = torch.from_numpy(im)
|
||||
x = torch.moveaxis(x, 2, 0)
|
||||
x = x.unsqueeze(dim=0) # [B=1, C=3, H=512, W=512]
|
||||
x = x.to(model.device)
|
||||
|
||||
latent = model.get_first_stage_encoding(model.encode_first_stage(x)) # [B=1, C=4, H=64, W=64]
|
||||
return latent
|
||||
|
||||
def mlc_get_cond(c:MulticondLearnedConditioning) -> torch.Tensor:
|
||||
return c.batch[0][0].schedules[0].cond # [B=1, T=77, D=768]
|
||||
|
||||
def mlc_replace_cond(c:MulticondLearnedConditioning, cond: torch.Tensor) -> MulticondLearnedConditioning:
|
||||
r = deepcopy(c)
|
||||
spc = r.batch[0][0].schedules[0]
|
||||
r.batch[0][0].schedules[0] = ScheduledPromptConditioning(spc.end_at_step, cond)
|
||||
return r
|
||||
|
||||
# Step 1: draw init image
|
||||
if show_debug: print(f'[stage 1/{n_stages}] prompts: {pos_prompts[0]}')
|
||||
p.prompt = pos_prompts[0]
|
||||
c, uc, prompts, seeds, subseeds = process_images_inner_half_A(p)
|
||||
proc = process_images_inner_half_B(p, c, uc, prompts, seeds, subseeds)
|
||||
c, uc, prompts, seeds, subseeds = process_images_prompt_to_cond(p)
|
||||
proc = process_images_cond_to_image(p, c, uc, prompts, seeds, subseeds)
|
||||
if initial_info is None: initial_info = proc.info
|
||||
images += proc.images
|
||||
|
||||
|
|
@ -468,14 +744,14 @@ class Script(scripts.Script):
|
|||
# Step 2: draw the stage target
|
||||
if show_debug: print(f'[stage {i+1}/{n_stages}] prompts: {pos_prompts[i]}')
|
||||
p.prompt = pos_prompts[i]
|
||||
params = process_images_inner_half_A(p)
|
||||
proc = process_images_inner_half_B(p, *params)
|
||||
params = process_images_prompt_to_cond(p)
|
||||
proc = process_images_cond_to_image(p, *params)
|
||||
if initial_info is None: initial_info = proc.info
|
||||
target_image = proc.images[0] # cache it here to make video sequence order right
|
||||
|
||||
with torch.no_grad(), devices.autocast():
|
||||
target_latent = image_to_latent(target_image) # [B=1, C=4, H=64, W=64]
|
||||
target_cond = mlc_get_cond(params[0]).unsqueeze(0) # [B=1, T=77, D=768]
|
||||
target_latent = image_to_latent(p.sd_model, target_image) # [B=1, C=4, H=64, W=64]
|
||||
target_cond = mlc_get_cond(params[0]).unsqueeze(0) # [B=1, T=77, D=768]
|
||||
source_cond = mlc_get_cond(c).unsqueeze(0)
|
||||
L1_dist = F.l1_loss(source_cond, target_cond).item()
|
||||
|
||||
|
|
@ -509,28 +785,27 @@ class Script(scripts.Script):
|
|||
}
|
||||
current_cond = current_cond.detach() - methods[grad_meth](grad) * grad_alpha
|
||||
|
||||
if show_debug:
|
||||
with torch.no_grad():
|
||||
l_latent = loss_latent.mean().item()
|
||||
l_match = loss_cond .mean().item()
|
||||
l_total = l_latent + l_match
|
||||
L1_from = F.l1_loss(current_cond, source_cond).item()
|
||||
L1_to = F.l1_loss(current_cond, target_cond).item() # FIXME: stop early when L1_to < grad_alpha / 2
|
||||
grad_abs = grad.abs()
|
||||
info = [
|
||||
f'loss: {l_total} (l_grad: {l_latent}, l_match: {l_match})',
|
||||
f' |grad|.avg: {grad_abs.mean()}, |grad|.max: {grad_abs.max()}',
|
||||
f' L1 from src: {L1_from}, L1 to dst: {L1_to}, L1 total: {L1_dist}',
|
||||
]
|
||||
log_fh.write('\n'.join(info))
|
||||
log_fh.write('\n')
|
||||
log_fh.flush()
|
||||
with torch.no_grad():
|
||||
l_latent = loss_latent.mean().item()
|
||||
l_match = loss_cond .mean().item()
|
||||
l_total = l_latent + l_match
|
||||
L1_from = F.l1_loss(current_cond, source_cond).item()
|
||||
L1_to = F.l1_loss(current_cond, target_cond).item() # FIXME: stop early when L1_to < grad_alpha / 2
|
||||
grad_abs = grad.abs()
|
||||
info = [
|
||||
f'loss: {l_total} (l_grad: {l_latent}, l_match: {l_match})',
|
||||
f' |grad|.avg: {grad_abs.mean()}, |grad|.max: {grad_abs.max()}',
|
||||
f' L1 from src: {L1_from}, L1 to dst: {L1_to}, L1 total: {L1_dist}',
|
||||
]
|
||||
log_fh.write('\n'.join(info))
|
||||
log_fh.write('\n')
|
||||
log_fh.flush()
|
||||
|
||||
# move to new 'c' (one travel step!)
|
||||
# FIXME: we do not walk on 'uc' so far
|
||||
c = mlc_replace_cond(c, current_cond.detach().squeeze(0))
|
||||
|
||||
proc = process_images_inner_half_B(p, c, uc, prompts, seeds, subseeds)
|
||||
proc = process_images_cond_to_image(p, c, uc, prompts, seeds, subseeds)
|
||||
if initial_info is None: initial_info = proc.info
|
||||
images += proc.images
|
||||
|
||||
|
|
|
|||