diff --git a/README.md b/README.md index 1f1cf71..450d219 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ The extension enables **large image drawing & upscaling with limited VRAM** via ## Features +- [x] SDXL model support (no region control) - [x] [StableSR support](https://github.com/pkuliyi2015/sd-webui-stablesr) - [x] [Tiled Noise Inversion](#🆕-tiled-noise-inversion) - [x] [Tiled VAE](#🔥-tiled-vae) diff --git a/scripts/tilediffusion.py b/scripts/tilediffusion.py index d7c3aec..657e4f4 100644 --- a/scripts/tilediffusion.py +++ b/scripts/tilediffusion.py @@ -68,7 +68,7 @@ from modules.shared import opts from modules.processing import opt_f, get_fixed_seed from modules.ui import gr_show -from tile_methods.abstractdiffusion import TiledDiffusion +from tile_methods.abstractdiffusion import AbstractDiffusion from tile_methods.multidiffusion import MultiDiffusion from tile_methods.mixtureofdiffusers import MixtureOfDiffusers from tile_utils.utils import * @@ -83,7 +83,7 @@ class Script(scripts.Script): def __init__(self): self.controlnet_script: ModuleType = None self.stablesr_script: ModuleType = None - self.delegate: TiledDiffusion = None + self.delegate: AbstractDiffusion = None self.noise_inverse_cache: NoiseInverseCache = None def title(self): diff --git a/scripts/vae_optimize.py b/scripts/tilevae.py similarity index 100% rename from scripts/vae_optimize.py rename to scripts/tilevae.py diff --git a/tile_methods/abstractdiffusion.py b/tile_methods/abstractdiffusion.py index 4216a89..96d484a 100644 --- a/tile_methods/abstractdiffusion.py +++ b/tile_methods/abstractdiffusion.py @@ -2,20 +2,22 @@ import math import torch from types import MethodType -import inspect import k_diffusion as K import torch.nn.functional as F from tqdm import tqdm from modules import devices, shared, sd_samplers_common -from modules.shared import state, cmd_opts +from modules.shared import state +from modules.shared_state import State +state: State from modules.processing import 
opt_f from tile_utils.utils import * from tile_utils.typing import * -class TiledDiffusion: + +class AbstractDiffusion: def __init__(self, p: Processing, sampler: Sampler): self.method = self.__class__.__name__ @@ -33,14 +35,6 @@ class TiledDiffusion: self.is_edit_model = (shared.sd_model.cond_stage_key == "edit" # "txt" and self.sampler.image_cfg_scale is not None and self.sampler.image_cfg_scale != 1.0) - - # img conditioning for different models (inpaint/unclip model) - if shared.sd_model.model.conditioning_key == "crossattn-adm": - self.make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} - self.get_image_cond = lambda c_in: c_in['c_adm'] - else: - self.make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} - self.get_image_cond = lambda c_in: c_in['c_concat'][0] # cache. final result of current sampling step, [B, C=4, H//8, W//8] # avoiding overhead of creating new tensors and weight summing @@ -50,7 +44,7 @@ class TiledDiffusion: # weights for background & grid bboxes self.weights: Tensor = torch.zeros((1, 1, self.h, self.w), device=devices.device, dtype=torch.float32) - # ME: I'm trying to count the step correctly but it's not working + # FIXME: I'm trying to count the step correctly but it's not working self.step_count = 0 self.inner_loop_count = 0 self.kdiff_step = -1 @@ -137,6 +131,49 @@ class TiledDiffusion: self.pbar = tqdm(total=(self.total_bboxes) * state.sampling_steps, desc=f"{self.method} Sampling: ") + ''' ↓↓↓ cond_dict utils ↓↓↓ ''' + + def _tcond_key(self, cond_dict:CondDict) -> str: + return 'crossattn' if 'crossattn' in cond_dict else 'c_crossattn' + + def get_tcond(self, cond_dict:CondDict) -> Tensor: + tcond = cond_dict[self._tcond_key(cond_dict)] + if isinstance(tcond, list): tcond = tcond[0] + return tcond + + def set_tcond(self, cond_dict:CondDict, tcond:Tensor): + key = self._tcond_key(cond_dict) + if isinstance(cond_dict[key], list): tcond = 
[tcond] + cond_dict[key] = tcond + + def _icond_key(self, cond_dict:CondDict) -> str: + return 'c_adm' if shared.sd_model.model.conditioning_key in ['crossattn-adm', 'adm'] else 'c_concat' + + def get_icond(self, cond_dict:CondDict) -> Tensor: + ''' icond differs for different models (inpaint/unclip model) ''' + key = self._icond_key(cond_dict) + icond = cond_dict[key] + if isinstance(icond, list): icond = icond[0] + return icond + + def set_icond(self, cond_dict:CondDict, icond:Tensor): + key = self._icond_key(cond_dict) + if isinstance(cond_dict[key], list): icond = [icond] + cond_dict[key] = icond + + def _vcond_key(self, cond_dict:CondDict) -> Optional[str]: + return 'vector' if 'vector' in cond_dict else None + + def get_vcond(self, cond_dict:CondDict) -> Optional[Tensor]: + ''' vector for SDXL ''' + key = self._vcond_key(cond_dict) + return cond_dict.get(key) + + def set_vcond(self, cond_dict:CondDict, vcond:Optional[Tensor]): + key = self._vcond_key(cond_dict) + if key is not None: + cond_dict[key] = vcond + ''' ↓↓↓ extensive functionality ↓↓↓ ''' @grid_bbox @@ -505,7 +542,8 @@ class TiledDiffusion: for param_id in range(len(self.control_params)): control_tensor = self.control_tensor_custom[param_id][bbox_id].to(devices.device) self.control_params[param_id].hint_cond = control_tensor.repeat((repeat_size, 1, 1, 1)) - + + @stablesr def init_stablesr(self, stablesr_script:ModuleType): if stablesr_script.stablesr_model is None: return @@ -535,7 +573,6 @@ class TiledDiffusion: if self.stablesr_script.stablesr_model is None: return self.stablesr_script.stablesr_model.latent_image = self.stablesr_tensor - @stablesr def switch_stablesr_tensors(self, batch_id:int): if not self.enable_stablesr: return diff --git a/tile_methods/mixtureofdiffusers.py b/tile_methods/mixtureofdiffusers.py index a494239..8ff2361 100644 --- a/tile_methods/mixtureofdiffusers.py +++ b/tile_methods/mixtureofdiffusers.py @@ -2,13 +2,15 @@ import torch from modules import devices, shared, 
extra_networks from modules.shared import state +from modules.shared_state import State +state: State -from tile_methods.abstractdiffusion import TiledDiffusion +from tile_methods.abstractdiffusion import AbstractDiffusion from tile_utils.utils import * from tile_utils.typing import * -class MixtureOfDiffusers(TiledDiffusion): +class MixtureOfDiffusers(AbstractDiffusion): """ Mixture-of-Diffusers Implementation https://github.com/albarji/mixture-of-diffusers @@ -62,36 +64,21 @@ class MixtureOfDiffusers(TiledDiffusion): ''' ↓↓↓ kernel hijacks ↓↓↓ ''' - def custom_apply_model(self, x_in, t_in, c_in, bbox_id, bbox) -> Tensor: - if self.is_kdiff: - return self.kdiff_custom_forward(x_in, t_in, c_in, bbox_id, bbox, forward_func=shared.sd_model.apply_model_original_md) - else: - def forward_func(x, c, ts, unconditional_conditioning, *args, **kwargs) -> Tensor: - # copy from p_sample_ddim in ddim.py - c_in = dict() - for k in c: - if isinstance(c[k], list): - c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))] - else: - c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) - self.set_custom_controlnet_tensors(bbox_id, x.shape[0]) - self.set_custom_stablesr_tensors(bbox_id) - return shared.sd_model.apply_model_original_md(x, ts, c_in) - return self.ddim_custom_forward(x_in, c_in, bbox, ts=t_in, forward_func=forward_func) - @torch.no_grad() @keep_signature - def apply_model_hijack(self, x_in:Tensor, t_in:Tensor, cond:CondDict, noise_inverse_step=-1): + def apply_model_hijack(self, x_in:Tensor, t_in:Tensor, cond:CondDict, noise_inverse_step:int=-1): assert LatentDiffusion.apply_model - # KDiffusion Compatibility - c_in = cond + # KDiffusion Compatibility for naming + c_in: CondDict = cond + N, C, H, W = x_in.shape - if H != self.h or W != self.w: + if (H, W) != (self.h, self.w): # We don't tile highres, let's just use the original apply_model self.reset_controlnet_tensors() return shared.sd_model.apply_model_original_md(x_in, 
t_in, c_in) + # clear buffer canvas self.reset_buffer(x_in) # Global sampling @@ -102,24 +89,36 @@ class MixtureOfDiffusers(TiledDiffusion): # batching x_tile_list = [] t_tile_list = [] - attn_tile_list = [] - image_cond_list = [] + tcond_tile_list = [] + icond_tile_list = [] + vcond_tile_list = [] for bbox in bboxes: x_tile_list.append(x_in[bbox.slicer]) t_tile_list.append(t_in) - if c_in is not None and isinstance(c_in, dict): - image_cond = self.get_image_cond(c_in) - # dummy for txt2img, latent mask for img2img - if image_cond.shape[2] == self.h and image_cond.shape[3] == self.w: - image_cond = image_cond[bbox.slicer] - image_cond_list.append(image_cond) - attn_tile = c_in['c_crossattn'][0] # cond, [1, 77, 768] - attn_tile_list.append(attn_tile) - x_tile = torch.cat(x_tile_list, dim=0) # differs each - t_tile = torch.cat(t_tile_list, dim=0) # just repeat - attn_tile = torch.cat(attn_tile_list, dim=0) # just repeat - image_cond_tile = torch.cat(image_cond_list, dim=0) # differs each - c_tile = self.make_condition_dict([attn_tile], image_cond_tile) + if isinstance(c_in, dict): + # tcond + tcond_tile = self.get_tcond(c_in) # cond, [1, 77, 768] + tcond_tile_list.append(tcond_tile) + # icond: might be dummy for txt2img, latent mask for img2img + icond = self.get_icond(c_in) + if icond.shape[2:] == (self.h, self.w): + icond = icond[bbox.slicer] + icond_tile_list.append(icond) + # vcond: + vcond = self.get_vcond(c_in) + vcond_tile_list.append(vcond) + else: + print('>> [WARN] not supported, make an issue on github!!') + x_tile = torch.cat(x_tile_list, dim=0) # differs each + t_tile = torch.cat(t_tile_list, dim=0) # just repeat + tcond_tile = torch.cat(tcond_tile_list, dim=0) # just repeat + icond_tile = torch.cat(icond_tile_list, dim=0) # differs each + vcond_tile = torch.cat(vcond_tile_list, dim=0) if None not in vcond_tile_list else None # just repeat + + c_tile: CondDict = c_in.copy() + self.set_tcond(c_tile, tcond_tile) # [1, 77, 768] + self.set_icond(c_tile, 
icond_tile) # [1, 5, 1, 1] + self.set_vcond(c_tile, vcond_tile) # [1, ?] # controlnet self.switch_controlnet_tensors(batch_id, N, len(bboxes), is_denoise=True) @@ -153,13 +152,16 @@ class MixtureOfDiffusers(TiledDiffusion): if noise_inverse_step < 0: x_tile_out = self.custom_apply_model(x_tile, t_in, c_in, bbox_id, bbox) else: - custom_cond = Condition.reconstruct_cond(bbox.cond, noise_inverse_step) - image_cond = self.get_image_cond(c_in) - if image_cond.shape[2:] == (self.h, self.w): - image_cond = image_cond[bbox.slicer] - image_conditioning = image_cond - custom_cond_in = self.make_condition_dict([custom_cond], image_conditioning) - x_tile_out = shared.sd_model.apply_model(x_tile, t_in, cond=custom_cond_in) + c_out: CondDict = c_in.copy() + tcond = Condition.reconstruct_cond(bbox.cond, noise_inverse_step) + self.set_tcond(c_out, tcond) + icond = self.get_icond(c_in) + if icond.shape[2:] == (self.h, self.w): + icond = icond[bbox.slicer] + self.set_icond(c_out, icond) + vcond = self.get_vcond(c_in) + self.set_vcond(c_out, vcond) + x_tile_out = shared.sd_model.apply_model(x_tile, t_in, cond=c_out) if bbox.blend_mode == BlendMode.BACKGROUND: self.x_buffer[bbox.slicer] += x_tile_out * self.custom_weights[bbox_id] @@ -188,9 +190,25 @@ class MixtureOfDiffusers(TiledDiffusion): # For mixture of diffusers, we cannot fill the not denoised area. # So we just leave it as it is. 
- return x_out + def custom_apply_model(self, x_in, t_in, c_in, bbox_id, bbox) -> Tensor: + if self.is_kdiff: + return self.kdiff_custom_forward(x_in, t_in, c_in, bbox_id, bbox, forward_func=shared.sd_model.apply_model_original_md) + else: + def forward_func(x, c, ts, unconditional_conditioning, *args, **kwargs) -> Tensor: + # copy from p_sample_ddim in ddim.py + c_in: CondDict = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))] + else: + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + self.set_custom_controlnet_tensors(bbox_id, x.shape[0]) + self.set_custom_stablesr_tensors(bbox_id) + return shared.sd_model.apply_model_original_md(x, ts, c_in) + return self.ddim_custom_forward(x_in, c_in, bbox, ts=t_in, forward_func=forward_func) + @torch.no_grad() def get_noise(self, x_in:Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor: - return self.apply_model_hijack(x_in, sigma_in, cond=cond_in, noise_inverse_step=step) \ No newline at end of file + return self.apply_model_hijack(x_in, sigma_in, cond=cond_in, noise_inverse_step=step) diff --git a/tile_methods/multidiffusion.py b/tile_methods/multidiffusion.py index 727ed1c..67c390d 100644 --- a/tile_methods/multidiffusion.py +++ b/tile_methods/multidiffusion.py @@ -2,13 +2,15 @@ import torch from modules import devices, extra_networks from modules.shared import state +from modules.shared_state import State +state: State -from tile_methods.abstractdiffusion import TiledDiffusion +from tile_methods.abstractdiffusion import AbstractDiffusion from tile_utils.utils import * from tile_utils.typing import * -class MultiDiffusion(TiledDiffusion): +class MultiDiffusion(AbstractDiffusion): """ Multi-Diffusion Implementation https://arxiv.org/abs/2302.08113 @@ -30,8 +32,12 @@ class MultiDiffusion(TiledDiffusion): self.sampler.inner_model.forward = self.kdiff_forward else: self.sampler: 
VanillaStableDiffusionSampler - self.sampler_forward = self.sampler.orig_p_sample_ddim - self.sampler.orig_p_sample_ddim = self.ddim_forward + if isinstance(self.p, ProcessingImg2Img): + self.sampler_forward = self.sampler.sample_img2img + self.sampler.sample_img2img = self.ddim_forward + else: + self.sampler_forward = self.sampler.sample + self.sampler.sample = self.ddim_forward @staticmethod def unhook(): @@ -59,81 +65,51 @@ class MultiDiffusion(TiledDiffusion): ''' ↓↓↓ kernel hijacks ↓↓↓ ''' - def repeat_cond_dict(self, cond_input:CondDict, bboxes:List[CustomBBox]) -> CondDict: - cond = cond_input['c_crossattn'][0] - # repeat the condition on its first dim - cond_shape = cond.shape - cond = cond.repeat((len(bboxes),) + (1,) * (len(cond_shape) - 1)) - image_cond = self.get_image_cond(cond_input) - if image_cond.shape[2] == self.h and image_cond.shape[3] == self.w: - image_cond_list = [] - for bbox in bboxes: - image_cond_list.append(image_cond[bbox.slicer]) - image_cond_tile = torch.cat(image_cond_list, dim=0) - else: - image_cond_shape = image_cond.shape - image_cond_tile = image_cond.repeat((len(bboxes),) + (1,) * (len(image_cond_shape) - 1)) - return self.make_condition_dict([cond], image_cond_tile) - @torch.no_grad() @keep_signature def kdiff_forward(self, x_in:Tensor, sigma_in:Tensor, cond:CondDict) -> Tensor: - ''' - This function hijacks `k_diffusion.external.CompVisDenoiser.forward()` - So its signature should be the same as the original function, especially the "cond" should be with exactly the same name - ''' - assert CompVisDenoiser.forward - def org_func(x:Tensor): + def org_func(x:Tensor) -> Tensor: return self.sampler_forward(x, sigma_in, cond=cond) - def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]): + def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tensor: # For kdiff sampler, the dim 0 of input x_in is: # = batch_size * (num_AND + 1) if not an edit model # = batch_size * (num_AND + 2) otherwise - sigma_in_tile = 
sigma_in.repeat(len(bboxes)) - new_cond = self.repeat_cond_dict(cond, bboxes) - x_tile_out = self.sampler_forward(x_tile, sigma_in_tile, cond=new_cond) + sigma_tile = self._rep_dim0(sigma_in, len(bboxes)) + cond_tile = self.repeat_cond_dict(cond, bboxes) + x_tile_out = self.sampler_forward(x_tile, sigma_tile, cond=cond_tile) return x_tile_out - def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox): + def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox) -> Tensor: return self.kdiff_custom_forward(x, sigma_in, cond, bbox_id, bbox, self.sampler_forward) return self.sample_one_step(x_in, org_func, repeat_func, custom_func) @torch.no_grad() @keep_signature - def ddim_forward(self, x_in:Tensor, cond_in:Union[CondDict, Tensor], ts:Tensor, unconditional_conditioning:Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]: - ''' - This function will replace the original p_sample_ddim function in ldm/diffusionmodels/ddim.py - So its signature should be the same as the original function, - Particularly, the unconditional_conditioning should be with exactly the same name - ''' + def ddim_forward(self, p:Processing, x_in:Tensor, cond_in:Union[CondDict, Tensor], ts:Tensor, unconditional_conditioning:Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]: + assert VanillaStableDiffusionSampler.sample + assert VanillaStableDiffusionSampler.sample_img2img - assert VanillaStableDiffusionSampler.p_sample_ddim_hook + def org_func(x:Tensor) -> Tensor: + return self.sampler_forward(x, p, cond_in, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) - def org_func(x:Tensor): - return self.sampler_forward(x, cond_in, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) - - def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]): - if isinstance(cond_in, dict): - ts_tile = ts.repeat(len(bboxes)) + def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tuple[Tensor, Tensor]: + n_rep = len(bboxes) + if isinstance(cond_in, dict): # FIXME: when will 
enter this branch? + ts_tile = self._rep_dim0(ts, n_rep) cond_tile = self.repeat_cond_dict(cond_in, bboxes) ucond_tile = self.repeat_cond_dict(unconditional_conditioning, bboxes) else: - ts_tile = ts.repeat(len(bboxes)) - cond_shape = cond_in.shape - cond_tile = cond_in.repeat((len(bboxes),) + (1,) * (len(cond_shape) - 1)) - ucond_shape = unconditional_conditioning.shape - ucond_tile = unconditional_conditioning.repeat((len(bboxes),) + (1,) * (len(ucond_shape) - 1)) - x_tile_out, x_pred = self.sampler_forward( - x_tile, cond_tile, ts_tile, - unconditional_conditioning=ucond_tile, - *args, **kwargs) + ts_tile = self._rep_dim0(ts, n_rep) + cond_tile = self._rep_dim0(cond_in, n_rep) + ucond_tile = self._rep_dim0(unconditional_conditioning, n_rep) + x_tile_out, x_pred = self.sampler_forward(x_tile, cond_tile, ts_tile, unconditional_conditioning=ucond_tile, *args, **kwargs) return x_tile_out, x_pred - def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox): + def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox) -> Tensor: # before the final forward, we can set the control tensor def forward_func(x, *args, **kwargs): self.set_custom_controlnet_tensors(bbox_id, 2*x.shape[0]) @@ -143,17 +119,47 @@ class MultiDiffusion(TiledDiffusion): return self.sample_one_step(x_in, org_func, repeat_func, custom_func) - def sample_one_step(self, x_in:Tensor, org_func: Callable, repeat_func:Callable, custom_func:Callable) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def _rep_dim0(self, x:Tensor, n:int) -> Tensor: + ''' repeat the tensor on its first dim ''' + if n == 1: return x + shape = [n] + [-1] * (len(x.shape) - 1) # [N, 1, ...]
+ return x.expand(shape) # `expand` is much lighter than `tile` + + def repeat_cond_dict(self, cond_in:CondDict, bboxes:List[CustomBBox]) -> CondDict: + ''' repeat cond_dict for a batch of tiles ''' + # n_repeat + n_rep = len(bboxes) + cond_out = cond_in.copy() + # txt cond + tcond = self.get_tcond(cond_in) # [B=1, L, D] => [B*N, L, D] + tcond = self._rep_dim0(tcond, n_rep) + self.set_tcond(cond_out, tcond) + # img cond + icond = self.get_icond(cond_in) + if icond.shape[2:] == (self.h, self.w): # img2img, [B=1, C, H, W] + icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0) + else: # txt2img, [B=1, C=5, H=1, W=1] + icond = self._rep_dim0(icond, n_rep) + self.set_icond(cond_out, icond) + # vec cond (SDXL) + vcond = self.get_vcond(cond_in) # [B=1, D] + if vcond is not None: + vcond = self._rep_dim0(vcond, n_rep) # [B*N, D] + self.set_vcond(cond_out, vcond) + return cond_out + + def sample_one_step(self, x_in:Tensor, org_func:Callable, repeat_func:Callable, custom_func:Callable) -> Union[Tensor, Tuple[Tensor, Tensor]]: ''' this method splits the whole latent and process in tiles - x_in: current whole U-Net latent - org_func: original forward function, when use highres - - denoise_func: one step denoiser for grid tile - - denoise_custom_func: one step denoiser for custom tile + - repeat_func: one step denoiser for grid tile + - custom_func: one step denoiser for custom tile ''' N, C, H, W = x_in.shape - if H != self.h or W != self.w: + if (H, W) != (self.h, self.w): + # We don't tile highres, let's just use the original org_func self.reset_controlnet_tensors() return org_func(x_in) @@ -166,13 +173,10 @@ class MultiDiffusion(TiledDiffusion): if state.interrupted: return x_in # batching - x_tile_list = [] - for bbox in bboxes: - x_tile_list.append(x_in[bbox.slicer]) - x_tile = torch.cat(x_tile_list, dim=0) + x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0) # [TB, C, TH, TW] # controlnet tiling - # FIXME: is_denoise is default 
to False, however it is set to True in case of MixtureOfDiffusers + # FIXME: is_denoise is default to False, however it is set to True in case of MixtureOfDiffusers, why? self.switch_controlnet_tensors(batch_id, N, len(bboxes)) # stablesr tiling @@ -262,30 +266,30 @@ class MultiDiffusion(TiledDiffusion): x_pred_out = torch.where(x_feather_count > 0, x_pred_out * (1 - x_feather_mask) + x_feather_pred_buffer * x_feather_mask, x_pred_out) return x_out if self.is_kdiff else (x_out, x_pred_out) - def get_noise(self, x_in:Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor: # NOTE: The following code is analytically wrong but aesthetically beautiful - local_cond_in = cond_in.copy() + cond_in_original = cond_in.copy() + def org_func(x:Tensor): - return shared.sd_model.apply_model(x, sigma_in, cond=local_cond_in) + return shared.sd_model.apply_model(x, sigma_in, cond=cond_in_original) def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]): sigma_in_tile = sigma_in.repeat(len(bboxes)) - new_cond = self.repeat_cond_dict(local_cond_in, bboxes) - x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=new_cond) + cond_out = self.repeat_cond_dict(cond_in_original, bboxes) + x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=cond_out) return x_tile_out def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox): # The negative prompt in custom bbox should not be used for noise inversion # otherwise the result will be astonishingly bad. 
- cond = Condition.reconstruct_cond(bbox.cond, step) - image_cond = self.get_image_cond(local_cond_in) - if image_cond.shape[2:] == (self.h, self.w): - image_cond = image_cond[bbox.slicer] - image_conditioning = image_cond - self.make_condition_dict([cond], image_conditioning) - return shared.sd_model.apply_model(x, sigma_in, cond=cond_in) + cond_out: CondDict = cond_in.copy() + tcond = Condition.reconstruct_cond(bbox.cond, step).unsqueeze_(0) + self.set_tcond(cond_out, tcond) + icond = self.get_icond(cond_in_original) + if icond.shape[2:] == (self.h, self.w): + icond = icond[bbox.slicer] + self.set_icond(cond_out, icond) + return shared.sd_model.apply_model(x, sigma_in, cond=cond_out) return self.sample_one_step(x_in, org_func, repeat_func, custom_func) - diff --git a/tile_utils/typing.py b/tile_utils/typing.py index ff63477..6ea75ca 100644 --- a/tile_utils/typing.py +++ b/tile_utils/typing.py @@ -11,7 +11,7 @@ from modules.processing import StableDiffusionProcessing as Processing, StableDi from modules.prompt_parser import MulticondLearnedConditioning, ScheduledPromptConditioning from modules.extra_networks import ExtraNetworkParams from modules.sd_samplers_kdiffusion import KDiffusionSampler, CFGDenoiser -from modules.sd_samplers_compvis import VanillaStableDiffusionSampler +from modules.sd_samplers_timesteps import VanillaStableDiffusionSampler ModuleType = type(sys) @@ -20,7 +20,9 @@ Cond = MulticondLearnedConditioning Uncond = List[List[ScheduledPromptConditioning]] ExtraNetworkData = DefaultDict[str, List[ExtraNetworkParams]] -# 'c_crossattn': Tensor # prompt cond -# 'c_concat': Tensor # latent mask -# 'c_adm': Tensor # unclip +# 'c_crossattn' List[Tensor[B, L=77, D=768]] prompt cond (tcond) +# 'c_concat' List[Tensor[B, C=5, H, W]] latent mask (icond) +# 'c_adm' Tensor[?] unclip (icond) +# 'crossattn' Tensor[B, L=77, D=2048] sdxl (tcond) +# 'vector' Tensor[B, D] sdxl (vcond) CondDict = Dict[str, Tensor]