add SDXL (limited) support, notify #304, #283

pull/356/head
Kahsolt 2023-10-17 13:53:53 +08:00
parent 4a6e3312d5
commit eaf93da6c3
7 changed files with 205 additions and 143 deletions

View File

@ -25,6 +25,7 @@ The extension enables **large image drawing & upscaling with limited VRAM** via
## Features
- [x] SDXL model support (no region control)
- [x] [StableSR support](https://github.com/pkuliyi2015/sd-webui-stablesr)
- [x] [Tiled Noise Inversion](#🆕-tiled-noise-inversion)
- [x] [Tiled VAE](#🔥-tiled-vae)

View File

@ -68,7 +68,7 @@ from modules.shared import opts
from modules.processing import opt_f, get_fixed_seed
from modules.ui import gr_show
from tile_methods.abstractdiffusion import TiledDiffusion
from tile_methods.abstractdiffusion import AbstractDiffusion
from tile_methods.multidiffusion import MultiDiffusion
from tile_methods.mixtureofdiffusers import MixtureOfDiffusers
from tile_utils.utils import *
@ -83,7 +83,7 @@ class Script(scripts.Script):
def __init__(self):
self.controlnet_script: ModuleType = None
self.stablesr_script: ModuleType = None
self.delegate: TiledDiffusion = None
self.delegate: AbstractDiffusion = None
self.noise_inverse_cache: NoiseInverseCache = None
def title(self):

View File

@ -2,20 +2,22 @@ import math
import torch
from types import MethodType
import inspect
import k_diffusion as K
import torch.nn.functional as F
from tqdm import tqdm
from modules import devices, shared, sd_samplers_common
from modules.shared import state, cmd_opts
from modules.shared import state
from modules.shared_state import State
state: State
from modules.processing import opt_f
from tile_utils.utils import *
from tile_utils.typing import *
class TiledDiffusion:
class AbstractDiffusion:
def __init__(self, p: Processing, sampler: Sampler):
self.method = self.__class__.__name__
@ -33,14 +35,6 @@ class TiledDiffusion:
self.is_edit_model = (shared.sd_model.cond_stage_key == "edit" # "txt"
and self.sampler.image_cfg_scale is not None
and self.sampler.image_cfg_scale != 1.0)
# img conditioning for different models (inpaint/unclip model)
if shared.sd_model.model.conditioning_key == "crossattn-adm":
self.make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
self.get_image_cond = lambda c_in: c_in['c_adm']
else:
self.make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
self.get_image_cond = lambda c_in: c_in['c_concat'][0]
# cache. final result of current sampling step, [B, C=4, H//8, W//8]
# avoiding overhead of creating new tensors and weight summing
@ -50,7 +44,7 @@ class TiledDiffusion:
# weights for background & grid bboxes
self.weights: Tensor = torch.zeros((1, 1, self.h, self.w), device=devices.device, dtype=torch.float32)
# ME: I'm trying to count the step correctly but it's not working
# FIXME: I'm trying to count the step correctly but it's not working
self.step_count = 0
self.inner_loop_count = 0
self.kdiff_step = -1
@ -137,6 +131,49 @@ class TiledDiffusion:
self.pbar = tqdm(total=(self.total_bboxes) * state.sampling_steps, desc=f"{self.method} Sampling: ")
''' ↓↓↓ cond_dict utils ↓↓↓ '''
def _tcond_key(self, cond_dict:CondDict) -> str:
return 'crossattn' if 'crossattn' in cond_dict else 'c_crossattn'
def get_tcond(self, cond_dict:CondDict) -> Tensor:
tcond = cond_dict[self._tcond_key(cond_dict)]
if isinstance(tcond, list): tcond = tcond[0]
return tcond
def set_tcond(self, cond_dict:CondDict, tcond:Tensor):
key = self._tcond_key(cond_dict)
if isinstance(cond_dict[key], list): tcond = [tcond]
cond_dict[key] = tcond
def _icond_key(self, cond_dict:CondDict) -> str:
return 'c_adm' if shared.sd_model.model.conditioning_key in ['crossattn-adm', 'adm'] else 'c_concat'
def get_icond(self, cond_dict:CondDict) -> Tensor:
''' icond differs for different models (inpaint/unclip model) '''
key = self._icond_key(cond_dict)
icond = cond_dict[key]
if isinstance(icond, list): icond = icond[0]
return icond
def set_icond(self, cond_dict:CondDict, icond:Tensor):
key = self._icond_key(cond_dict)
if isinstance(cond_dict[key], list): icond = [icond]
cond_dict[key] = icond
def _vcond_key(self, cond_dict:CondDict) -> Optional[str]:
return 'vector' if 'vector' in cond_dict else None
def get_vcond(self, cond_dict:CondDict) -> Optional[Tensor]:
''' vector for SDXL '''
key = self._vcond_key(cond_dict)
return cond_dict.get(key)
def set_vcond(self, cond_dict:CondDict, vcond:Optional[Tensor]):
key = self._vcond_key(cond_dict)
if key is not None:
cond_dict[key] = vcond
''' ↓↓↓ extensive functionality ↓↓↓ '''
@grid_bbox
@ -505,7 +542,8 @@ class TiledDiffusion:
for param_id in range(len(self.control_params)):
control_tensor = self.control_tensor_custom[param_id][bbox_id].to(devices.device)
self.control_params[param_id].hint_cond = control_tensor.repeat((repeat_size, 1, 1, 1))
@stablesr
def init_stablesr(self, stablesr_script:ModuleType):
if stablesr_script.stablesr_model is None: return
@ -535,7 +573,6 @@ class TiledDiffusion:
if self.stablesr_script.stablesr_model is None: return
self.stablesr_script.stablesr_model.latent_image = self.stablesr_tensor
@stablesr
def switch_stablesr_tensors(self, batch_id:int):
if not self.enable_stablesr: return

View File

@ -2,13 +2,15 @@ import torch
from modules import devices, shared, extra_networks
from modules.shared import state
from modules.shared_state import State
state: State
from tile_methods.abstractdiffusion import TiledDiffusion
from tile_methods.abstractdiffusion import AbstractDiffusion
from tile_utils.utils import *
from tile_utils.typing import *
class MixtureOfDiffusers(TiledDiffusion):
class MixtureOfDiffusers(AbstractDiffusion):
"""
Mixture-of-Diffusers Implementation
https://github.com/albarji/mixture-of-diffusers
@ -62,36 +64,21 @@ class MixtureOfDiffusers(TiledDiffusion):
''' ↓↓↓ kernel hijacks ↓↓↓ '''
def custom_apply_model(self, x_in, t_in, c_in, bbox_id, bbox) -> Tensor:
if self.is_kdiff:
return self.kdiff_custom_forward(x_in, t_in, c_in, bbox_id, bbox, forward_func=shared.sd_model.apply_model_original_md)
else:
def forward_func(x, c, ts, unconditional_conditioning, *args, **kwargs) -> Tensor:
# copy from p_sample_ddim in ddim.py
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
self.set_custom_controlnet_tensors(bbox_id, x.shape[0])
self.set_custom_stablesr_tensors(bbox_id)
return shared.sd_model.apply_model_original_md(x, ts, c_in)
return self.ddim_custom_forward(x_in, c_in, bbox, ts=t_in, forward_func=forward_func)
@torch.no_grad()
@keep_signature
def apply_model_hijack(self, x_in:Tensor, t_in:Tensor, cond:CondDict, noise_inverse_step=-1):
def apply_model_hijack(self, x_in:Tensor, t_in:Tensor, cond:CondDict, noise_inverse_step:int=-1):
assert LatentDiffusion.apply_model
# KDiffusion Compatibility
c_in = cond
# KDiffusion Compatibility for naming
c_in: CondDict = cond
N, C, H, W = x_in.shape
if H != self.h or W != self.w:
if (H, W) != (self.h, self.w):
# We don't tile highres, let's just use the original apply_model
self.reset_controlnet_tensors()
return shared.sd_model.apply_model_original_md(x_in, t_in, c_in)
# clear buffer canvas
self.reset_buffer(x_in)
# Global sampling
@ -102,24 +89,36 @@ class MixtureOfDiffusers(TiledDiffusion):
# batching
x_tile_list = []
t_tile_list = []
attn_tile_list = []
image_cond_list = []
tcond_tile_list = []
icond_tile_list = []
vcond_tile_list = []
for bbox in bboxes:
x_tile_list.append(x_in[bbox.slicer])
t_tile_list.append(t_in)
if c_in is not None and isinstance(c_in, dict):
image_cond = self.get_image_cond(c_in)
# dummy for txt2img, latent mask for img2img
if image_cond.shape[2] == self.h and image_cond.shape[3] == self.w:
image_cond = image_cond[bbox.slicer]
image_cond_list.append(image_cond)
attn_tile = c_in['c_crossattn'][0] # cond, [1, 77, 768]
attn_tile_list.append(attn_tile)
x_tile = torch.cat(x_tile_list, dim=0) # differs each
t_tile = torch.cat(t_tile_list, dim=0) # just repeat
attn_tile = torch.cat(attn_tile_list, dim=0) # just repeat
image_cond_tile = torch.cat(image_cond_list, dim=0) # differs each
c_tile = self.make_condition_dict([attn_tile], image_cond_tile)
if isinstance(c_in, dict):
# tcond
tcond_tile = self.get_tcond(c_in) # cond, [1, 77, 768]
tcond_tile_list.append(tcond_tile)
# icond: might be dummy for txt2img, latent mask for img2img
icond = self.get_icond(c_in)
if icond.shape[2:] == (self.h, self.w):
icond = icond[bbox.slicer]
icond_tile_list.append(icond)
# vcond:
vcond = self.get_vcond(c_in)
vcond_tile_list.append(vcond)
else:
print('>> [WARN] not supported, make an issue on github!!')
x_tile = torch.cat(x_tile_list, dim=0) # differs each
t_tile = torch.cat(t_tile_list, dim=0) # just repeat
tcond_tile = torch.cat(tcond_tile_list, dim=0) # just repeat
icond_tile = torch.cat(icond_tile_list, dim=0) # differs each
vcond_tile = torch.cat(vcond_tile_list, dim=0) if None not in vcond_tile_list else None # just repeat
c_tile: CondDict = c_in.copy()
self.set_tcond(c_tile, tcond_tile) # [1, 77, 768]
self.set_icond(c_tile, icond_tile) # [1, 5, 1, 1]
self.set_vcond(c_tile, vcond_tile) # [1, ?]
# controlnet
self.switch_controlnet_tensors(batch_id, N, len(bboxes), is_denoise=True)
@ -153,13 +152,16 @@ class MixtureOfDiffusers(TiledDiffusion):
if noise_inverse_step < 0:
x_tile_out = self.custom_apply_model(x_tile, t_in, c_in, bbox_id, bbox)
else:
custom_cond = Condition.reconstruct_cond(bbox.cond, noise_inverse_step)
image_cond = self.get_image_cond(c_in)
if image_cond.shape[2:] == (self.h, self.w):
image_cond = image_cond[bbox.slicer]
image_conditioning = image_cond
custom_cond_in = self.make_condition_dict([custom_cond], image_conditioning)
x_tile_out = shared.sd_model.apply_model(x_tile, t_in, cond=custom_cond_in)
c_out: CondDict = c_in.copy()
tcond = Condition.reconstruct_cond(bbox.cond, noise_inverse_step)
self.set_tcond(c_out, tcond)
icond = self.get_icond(c_in)
if icond.shape[2:] == (self.h, self.w):
icond = icond[bbox.slicer]
self.set_icond(c_out, icond)
vcond = self.get_vcond(c_in)
self.set_vcond(c_out, vcond)
x_tile_out = shared.sd_model.apply_model(x_tile, t_in, cond=c_out)
if bbox.blend_mode == BlendMode.BACKGROUND:
self.x_buffer[bbox.slicer] += x_tile_out * self.custom_weights[bbox_id]
@ -188,9 +190,25 @@ class MixtureOfDiffusers(TiledDiffusion):
# For mixture of diffusers, we cannot fill the not denoised area.
# So we just leave it as it is.
return x_out
def custom_apply_model(self, x_in, t_in, c_in, bbox_id, bbox) -> Tensor:
if self.is_kdiff:
return self.kdiff_custom_forward(x_in, t_in, c_in, bbox_id, bbox, forward_func=shared.sd_model.apply_model_original_md)
else:
def forward_func(x, c, ts, unconditional_conditioning, *args, **kwargs) -> Tensor:
# copy from p_sample_ddim in ddim.py
c_in: CondDict = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
self.set_custom_controlnet_tensors(bbox_id, x.shape[0])
self.set_custom_stablesr_tensors(bbox_id)
return shared.sd_model.apply_model_original_md(x, ts, c_in)
return self.ddim_custom_forward(x_in, c_in, bbox, ts=t_in, forward_func=forward_func)
@torch.no_grad()
def get_noise(self, x_in:Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor:
return self.apply_model_hijack(x_in, sigma_in, cond=cond_in, noise_inverse_step=step)
return self.apply_model_hijack(x_in, sigma_in, cond=cond_in, noise_inverse_step=step)

View File

@ -2,13 +2,15 @@ import torch
from modules import devices, extra_networks
from modules.shared import state
from modules.shared_state import State
state: State
from tile_methods.abstractdiffusion import TiledDiffusion
from tile_methods.abstractdiffusion import AbstractDiffusion
from tile_utils.utils import *
from tile_utils.typing import *
class MultiDiffusion(TiledDiffusion):
class MultiDiffusion(AbstractDiffusion):
"""
Multi-Diffusion Implementation
https://arxiv.org/abs/2302.08113
@ -30,8 +32,12 @@ class MultiDiffusion(TiledDiffusion):
self.sampler.inner_model.forward = self.kdiff_forward
else:
self.sampler: VanillaStableDiffusionSampler
self.sampler_forward = self.sampler.orig_p_sample_ddim
self.sampler.orig_p_sample_ddim = self.ddim_forward
if isinstance(self.p, ProcessingImg2Img):
self.sampler_forward = self.sampler.sample_img2img
self.sampler.sample_img2img = self.ddim_forward
else:
self.sampler_forward = self.sampler.sample
self.sampler.sample = self.ddim_forward
@staticmethod
def unhook():
@ -59,81 +65,51 @@ class MultiDiffusion(TiledDiffusion):
''' ↓↓↓ kernel hijacks ↓↓↓ '''
def repeat_cond_dict(self, cond_input:CondDict, bboxes:List[CustomBBox]) -> CondDict:
cond = cond_input['c_crossattn'][0]
# repeat the condition on its first dim
cond_shape = cond.shape
cond = cond.repeat((len(bboxes),) + (1,) * (len(cond_shape) - 1))
image_cond = self.get_image_cond(cond_input)
if image_cond.shape[2] == self.h and image_cond.shape[3] == self.w:
image_cond_list = []
for bbox in bboxes:
image_cond_list.append(image_cond[bbox.slicer])
image_cond_tile = torch.cat(image_cond_list, dim=0)
else:
image_cond_shape = image_cond.shape
image_cond_tile = image_cond.repeat((len(bboxes),) + (1,) * (len(image_cond_shape) - 1))
return self.make_condition_dict([cond], image_cond_tile)
@torch.no_grad()
@keep_signature
def kdiff_forward(self, x_in:Tensor, sigma_in:Tensor, cond:CondDict) -> Tensor:
'''
This function hijacks `k_diffusion.external.CompVisDenoiser.forward()`
So its signature should be the same as the original function, especially the "cond" should be with exactly the same name
'''
assert CompVisDenoiser.forward
def org_func(x:Tensor):
def org_func(x:Tensor) -> Tensor:
return self.sampler_forward(x, sigma_in, cond=cond)
def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]):
def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tensor:
# For kdiff sampler, the dim 0 of input x_in is:
# = batch_size * (num_AND + 1) if not an edit model
# = batch_size * (num_AND + 2) otherwise
sigma_in_tile = sigma_in.repeat(len(bboxes))
new_cond = self.repeat_cond_dict(cond, bboxes)
x_tile_out = self.sampler_forward(x_tile, sigma_in_tile, cond=new_cond)
sigma_tile = self._rep_dim0(sigma_in, len(bboxes))
cond_tile = self.repeat_cond_dict(cond, bboxes)
x_tile_out = self.sampler_forward(x_tile, sigma_tile, cond=cond_tile)
return x_tile_out
def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox):
def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox) -> Tensor:
return self.kdiff_custom_forward(x, sigma_in, cond, bbox_id, bbox, self.sampler_forward)
return self.sample_one_step(x_in, org_func, repeat_func, custom_func)
@torch.no_grad()
@keep_signature
def ddim_forward(self, x_in:Tensor, cond_in:Union[CondDict, Tensor], ts:Tensor, unconditional_conditioning:Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]:
'''
This function will replace the original p_sample_ddim function in ldm/diffusionmodels/ddim.py
So its signature should be the same as the original function,
Particularly, the unconditional_conditioning should be with exactly the same name
'''
def ddim_forward(self, p:Processing, x_in:Tensor, cond_in:Union[CondDict, Tensor], ts:Tensor, unconditional_conditioning:Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]:
assert VanillaStableDiffusionSampler.sample
assert VanillaStableDiffusionSampler.sample_img2img
assert VanillaStableDiffusionSampler.p_sample_ddim_hook
def org_func(x:Tensor) -> Tensor:
return self.sampler_forward(x, p, cond_in, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
def org_func(x:Tensor):
return self.sampler_forward(x, cond_in, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]):
if isinstance(cond_in, dict):
ts_tile = ts.repeat(len(bboxes))
def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tuple[Tensor, Tensor]:
n_rep = len(bboxes)
if isinstance(cond_in, dict): # FIXME: when will enter this branch?
ts_tile = self._rep_dim0(ts, n_rep)
cond_tile = self.repeat_cond_dict(cond_in, bboxes)
ucond_tile = self.repeat_cond_dict(unconditional_conditioning, bboxes)
else:
ts_tile = ts.repeat(len(bboxes))
cond_shape = cond_in.shape
cond_tile = cond_in.repeat((len(bboxes),) + (1,) * (len(cond_shape) - 1))
ucond_shape = unconditional_conditioning.shape
ucond_tile = unconditional_conditioning.repeat((len(bboxes),) + (1,) * (len(ucond_shape) - 1))
x_tile_out, x_pred = self.sampler_forward(
x_tile, cond_tile, ts_tile,
unconditional_conditioning=ucond_tile,
*args, **kwargs)
ts_tile = self._rep_dim0(ts, n_rep)
cond_tile = self._rep_dim0(cond_in, n_rep)
ucond_tile = self._rep_dim0(unconditional_conditioning, n_rep)
x_tile_out, x_pred = self.sampler_forward(x_tile, cond_tile, ts_tile, unconditional_conditioning=ucond_tile, *args, **kwargs)
return x_tile_out, x_pred
def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox):
def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox) -> Tensor:
# before the final forward, we can set the control tensor
def forward_func(x, *args, **kwargs):
self.set_custom_controlnet_tensors(bbox_id, 2*x.shape[0])
@ -143,17 +119,48 @@ class MultiDiffusion(TiledDiffusion):
return self.sample_one_step(x_in, org_func, repeat_func, custom_func)
def sample_one_step(self, x_in:Tensor, org_func: Callable, repeat_func:Callable, custom_func:Callable) -> Union[Tensor, Tuple[Tensor, Tensor]]:
def _rep_dim0(self, x:Tensor, n:int) -> Tensor:
''' repeat the tensor on its first dim '''
if n == 1: return x
shape = [n] + [-1] * (len(x.shape) - 1) # [N, 1, ...]
return x.expand(shape) # `expand` is much lighter than `tile`
def repeat_cond_dict(self, cond_in:CondDict, bboxes:List[CustomBBox]) -> CondDict:
''' repeat cond_dict for a batch of tiles '''
# n_repeat
breakpoint()
n_rep = len(bboxes)
cond_out = cond_in.copy()
# txt cond
tcond = self.get_tcond(cond_in) # [B=1, L, D] => [B*N, L, D]
tcond = self._rep_dim0(tcond, n_rep)
self.set_tcond(cond_out, tcond)
# img cond
icond = self.get_icond(cond_in)
if icond.shape[2:] == (self.h, self.w): # img2img, [B=1, C, H, W]
icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0)
else: # txt2img, [B=1, C=5, H=1, W=1]
icond = self._rep_dim0(icond, n_rep)
self.set_icond(cond_out, icond)
# vec cond (SDXL)
vcond = self.get_vcond(cond_in) # [B=1, D]
if vcond is not None:
vcond = self._rep_dim0(vcond, n_rep) # [B*N, D]
self.set_vcond(cond_out, vcond)
return cond_out
def sample_one_step(self, x_in:Tensor, org_func:Callable, repeat_func:Callable, custom_func:Callable) -> Union[Tensor, Tuple[Tensor, Tensor]]:
'''
this method splits the whole latent and process in tiles
- x_in: current whole U-Net latent
- org_func: original forward function, when use highres
- denoise_func: one step denoiser for grid tile
- denoise_custom_func: one step denoiser for custom tile
- repeat_func: one step denoiser for grid tile
- custom_func: one step denoiser for custom tile
'''
N, C, H, W = x_in.shape
if H != self.h or W != self.w:
if (H, W) != (self.h, self.w):
# We don't tile highres, let's just use the original org_func
self.reset_controlnet_tensors()
return org_func(x_in)
@ -166,13 +173,10 @@ class MultiDiffusion(TiledDiffusion):
if state.interrupted: return x_in
# batching
x_tile_list = []
for bbox in bboxes:
x_tile_list.append(x_in[bbox.slicer])
x_tile = torch.cat(x_tile_list, dim=0)
x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0) # [TB, C, TH, TW]
# controlnet tiling
# FIXME: is_denoise is default to False, however it is set to True in case of MixtureOfDiffusers
# FIXME: is_denoise is default to False, however it is set to True in case of MixtureOfDiffusers, why?
self.switch_controlnet_tensors(batch_id, N, len(bboxes))
# stablesr tiling
@ -262,30 +266,30 @@ class MultiDiffusion(TiledDiffusion):
x_pred_out = torch.where(x_feather_count > 0, x_pred_out * (1 - x_feather_mask) + x_feather_pred_buffer * x_feather_mask, x_pred_out)
return x_out if self.is_kdiff else (x_out, x_pred_out)
def get_noise(self, x_in:Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor:
# NOTE: The following code is analytically wrong but aesthetically beautiful
local_cond_in = cond_in.copy()
cond_in_original = cond_in.copy()
def org_func(x:Tensor):
return shared.sd_model.apply_model(x, sigma_in, cond=local_cond_in)
return shared.sd_model.apply_model(x, sigma_in, cond=cond_in_original)
def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]):
sigma_in_tile = sigma_in.repeat(len(bboxes))
new_cond = self.repeat_cond_dict(local_cond_in, bboxes)
x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=new_cond)
cond_out = self.repeat_cond_dict(cond_in_original, bboxes)
x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=cond_out)
return x_tile_out
def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox):
# The negative prompt in custom bbox should not be used for noise inversion
# otherwise the result will be astonishingly bad.
cond = Condition.reconstruct_cond(bbox.cond, step)
image_cond = self.get_image_cond(local_cond_in)
if image_cond.shape[2:] == (self.h, self.w):
image_cond = image_cond[bbox.slicer]
image_conditioning = image_cond
self.make_condition_dict([cond], image_conditioning)
return shared.sd_model.apply_model(x, sigma_in, cond=cond_in)
cond_out: CondDict = cond_in.copy()
tcond = Condition.reconstruct_cond(bbox.cond, step).unsqueeze_(0)
self.set_tcond(cond_out, tcond)
icond = self.get_icond(cond_in_original)
if icond.shape[2:] == (self.h, self.w):
icond = icond[bbox.slicer]
self.set_icond(cond_out, icond)
return shared.sd_model.apply_model(x, sigma_in, cond=cond_out)
return self.sample_one_step(x_in, org_func, repeat_func, custom_func)

View File

@ -11,7 +11,7 @@ from modules.processing import StableDiffusionProcessing as Processing, StableDi
from modules.prompt_parser import MulticondLearnedConditioning, ScheduledPromptConditioning
from modules.extra_networks import ExtraNetworkParams
from modules.sd_samplers_kdiffusion import KDiffusionSampler, CFGDenoiser
from modules.sd_samplers_compvis import VanillaStableDiffusionSampler
from modules.sd_samplers_timesteps import VanillaStableDiffusionSampler
ModuleType = type(sys)
@ -20,7 +20,9 @@ Cond = MulticondLearnedConditioning
Uncond = List[List[ScheduledPromptConditioning]]
ExtraNetworkData = DefaultDict[str, List[ExtraNetworkParams]]
# 'c_crossattn': Tensor # prompt cond
# 'c_concat': Tensor # latent mask
# 'c_adm': Tensor # unclip
# 'c_crossattn' List[Tensor[B, L=77, D=768]] prompt cond (tcond)
# 'c_concat' List[Tensor[B, C=5, H, W]] latent mask (icond)
# 'c_adm' Tensor[?] unclip (icond)
# 'crossattn' Tensor[B, L=77, D=2048] sdxl (tcond)
# 'vector'       Tensor[B, D]                   sdxl (vcond)
CondDict = Dict[str, Tensor]