diff --git a/scripts/tilediffusion.py b/scripts/tilediffusion.py index cb51f69..5a4f1c2 100644 --- a/scripts/tilediffusion.py +++ b/scripts/tilediffusion.py @@ -72,7 +72,6 @@ from tile_methods.abstractdiffusion import AbstractDiffusion from tile_methods.multidiffusion import MultiDiffusion from tile_methods.mixtureofdiffusers import MixtureOfDiffusers from tile_utils.utils import * -from tile_utils.typing import * CFG_PATH = os.path.join(scripts.basedir(), 'region_configs') BBOX_MAX_NUM = min(getattr(shared.cmd_opts, 'md_max_regions', 8), 16) diff --git a/tile_methods/abstractdiffusion.py b/tile_methods/abstractdiffusion.py index 0c2626f..61cec3e 100644 --- a/tile_methods/abstractdiffusion.py +++ b/tile_methods/abstractdiffusion.py @@ -8,13 +8,9 @@ import torch.nn.functional as F from tqdm import tqdm from modules import devices, shared, sd_samplers_common -from modules.shared import state -from modules.shared_state import State -state: State from modules.processing import opt_f from tile_utils.utils import * -from tile_utils.typing import * class AbstractDiffusion: @@ -27,8 +23,7 @@ class AbstractDiffusion: # sampler self.sampler_name = p.sampler_name self.sampler_raw = sampler - if self.is_kdiff: self.sampler: CFGDenoiser = sampler.model_wrap_cfg - else: self.sampler: VanillaStableDiffusionSampler = sampler + self.sampler = sampler # fix. Kdiff 'AND' support and image editing model support if self.is_kdiff and not hasattr(self, 'is_edit_model'): @@ -97,7 +92,7 @@ class AbstractDiffusion: @property def is_ddim(self): - return isinstance(self.sampler_raw, VanillaStableDiffusionSampler) + return isinstance(self.sampler_raw, CompVisSampler) def update_pbar(self): if self.pbar.n >= self.pbar.total: @@ -176,8 +171,9 @@ class AbstractDiffusion: if key is not None: cond_dict[key] = vcond - def make_cond_dict(self, cond_dict:CondDict, tcond:Tensor, icond:Tensor, vcond:Tensor=None) -> CondDict: - cond_out = cond_dict.copy() + def make_cond_dict(self, cond_in:CondDict, tcond:Tensor, icond:Tensor, vcond:Tensor=None) -> CondDict: + ''' copy & replace the content, returns a new object ''' + cond_out = cond_in.copy() self.set_tcond(cond_out, tcond) self.set_icond(cond_out, icond) self.set_vcond(cond_out, vcond) diff --git a/tile_methods/mixtureofdiffusers.py b/tile_methods/mixtureofdiffusers.py index 31c0c79..bedb59a 100644 --- a/tile_methods/mixtureofdiffusers.py +++ b/tile_methods/mixtureofdiffusers.py @@ -1,13 +1,9 @@ import torch from modules import devices, shared, extra_networks -from modules.shared import state -from modules.shared_state import State -state: State from tile_methods.abstractdiffusion import AbstractDiffusion from tile_utils.utils import * -from tile_utils.typing import * class MixtureOfDiffusers(AbstractDiffusion): diff --git a/tile_methods/multidiffusion.py b/tile_methods/multidiffusion.py index 154af13..17201c8 100644 --- a/tile_methods/multidiffusion.py +++ b/tile_methods/multidiffusion.py @@ -1,13 +1,9 @@ import torch from modules import devices, extra_networks -from modules.shared import state -from modules.shared_state import State -state: State from tile_methods.abstractdiffusion import AbstractDiffusion from tile_utils.utils import * -from tile_utils.typing import * class MultiDiffusion(AbstractDiffusion): @@ -20,20 +16,17 @@ class MultiDiffusion(AbstractDiffusion): super().__init__(p, *args, **kwargs) assert p.sampler_name != 'UniPC', 'MultiDiffusion is not compatible with UniPC!' - # For ddim sampler we need to cache the pred_x0 - self.x_pred_buffer = None - def hook(self): if self.is_kdiff: # For K-Diffusion sampler with uniform prompt, we hijack into the inner model for simplicity # Otherwise, the masked-redraw will break due to the init_latent - self.sampler: CFGDenoiser - self.sampler_forward = self.sampler.inner_model.forward - self.sampler.inner_model.forward = self.kdiff_forward + self.sampler: KDiffusionSampler + self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward + self.sampler.model_wrap_cfg.inner_model.forward = self.kdiff_forward else: - self.sampler: VanillaStableDiffusionSampler - self.sampler_forward = self.sampler.orig_p_sample_ddim # FIXME: this is boken due to sd-webui's update - self.sampler.orig_p_sample_ddim = self.ddim_forward + self.sampler: CompVisSampler + self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward + self.sampler.model_wrap_cfg.inner_model.forward = self.ddim_forward @staticmethod def unhook(): @@ -44,13 +37,6 @@ class MultiDiffusion(AbstractDiffusion): def reset_buffer(self, x_in:Tensor): super().reset_buffer(x_in) - # ddim needs to cache pred0 - if self.is_ddim: - if self.x_pred_buffer is None: - self.x_pred_buffer = torch.zeros_like(x_in, device=x_in.device) - else: - self.x_pred_buffer.zero_() - @custom_bbox def init_custom_bbox(self, *args): super().init_custom_bbox(*args) @@ -65,6 +51,7 @@ class MultiDiffusion(AbstractDiffusion): @keep_signature def kdiff_forward(self, x_in:Tensor, sigma_in:Tensor, cond:CondDict) -> Tensor: assert CompVisDenoiser.forward + assert CompVisVDenoiser.forward def org_func(x:Tensor) -> Tensor: return self.sampler_forward(x, sigma_in, cond=cond) @@ -73,10 +60,9 @@ class MultiDiffusion(AbstractDiffusion): # For kdiff sampler, the dim 0 of input x_in is: # = batch_size * (num_AND + 1) if not an edit model # = batch_size * (num_AND + 2) otherwise - sigma_tile = self._rep_dim0(sigma_in, len(bboxes)) + sigma_tile = self.repeat_tensor(sigma_in, len(bboxes)) cond_tile = self.repeat_cond_dict(cond, bboxes) - x_tile_out = self.sampler_forward(x_tile, sigma_tile, cond=cond_tile) - return x_tile_out + return self.sampler_forward(x_tile, sigma_tile, cond=cond_tile) def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox) -> Tensor: return self.kdiff_custom_forward(x, sigma_in, cond, bbox_id, bbox, self.sampler_forward) @@ -85,25 +71,21 @@ class MultiDiffusion(AbstractDiffusion): @torch.no_grad() @keep_signature - def ddim_forward(self, p:Processing, x_in:Tensor, cond_in:Union[CondDict, Tensor], ts:Tensor, unconditional_conditioning:Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]: - assert VanillaStableDiffusionSampler.sample - assert VanillaStableDiffusionSampler.sample_img2img + def ddim_forward(self, x_in:Tensor, ts_in:Tensor, cond:Union[CondDict, Tensor]) -> Tensor: + assert CompVisTimestepsDenoiser.forward + assert CompVisTimestepsVDenoiser.forward def org_func(x:Tensor) -> Tensor: - return self.sampler_forward(x, p, cond_in, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) + return self.sampler_forward(x, ts_in, cond=cond) def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tuple[Tensor, Tensor]: n_rep = len(bboxes) - if isinstance(cond_in, dict): # FIXME: when will enter this branch? - ts_tile = self._rep_dim0(ts, n_rep) - cond_tile = self.repeat_cond_dict(cond_in, bboxes) - ucond_tile = self.repeat_cond_dict(unconditional_conditioning, bboxes) + ts_tile = self.repeat_tensor(ts_in, n_rep) + if isinstance(cond, dict): # FIXME: when will enter this branch? + cond_tile = self.repeat_cond_dict(cond, bboxes) else: - ts_tile = self._rep_dim0(ts, n_rep) - cond_tile = self._rep_dim0(cond_in, n_rep) - ucond_tile = self._rep_dim0(unconditional_conditioning, n_rep) - x_tile_out, x_pred = self.sampler_forward(x_tile, cond_tile, ts_tile, unconditional_conditioning=ucond_tile, *args, **kwargs) - return x_tile_out, x_pred + cond_tile = self.repeat_tensor(cond, n_rep) + return self.sampler_forward(x_tile, ts_tile, cond=cond_tile) def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox) -> Tensor: # before the final forward, we can set the control tensor @@ -111,36 +93,36 @@ class MultiDiffusion(AbstractDiffusion): self.set_custom_controlnet_tensors(bbox_id, 2*x.shape[0]) self.set_custom_stablesr_tensors(bbox_id) return self.sampler_forward(x, *args, **kwargs) - return self.ddim_custom_forward(x, cond_in, bbox, ts, forward_func, *args, **kwargs) + return self.ddim_custom_forward(x, cond, bbox, ts_in, forward_func) return self.sample_one_step(x_in, org_func, repeat_func, custom_func) - def _rep_dim0(self, x:Tensor, n:int) -> Tensor: + def repeat_tensor(self, x:Tensor, n:int) -> Tensor: ''' repeat the tensor on it's first dim ''' if n == 1: return x shape = [n] + [-1] * (len(x.shape) - 1) # [N, 1, ...] return x.expand(shape) # `expand` is much lighter than `tile` def repeat_cond_dict(self, cond_in:CondDict, bboxes:List[CustomBBox]) -> CondDict: - ''' repeat cond_dict for a batch of tiles ''' + ''' repeat all tensors in cond_dict on it's first dim (for a batch of tiles), returns a new object ''' # n_repeat n_rep = len(bboxes) # txt cond tcond = self.get_tcond(cond_in) # [B=1, L, D] => [B*N, L, D] - tcond = self._rep_dim0(tcond, n_rep) + tcond = self.repeat_tensor(tcond, n_rep) # img cond icond = self.get_icond(cond_in) if icond.shape[2:] == (self.h, self.w): # img2img, [B=1, C, H, W] icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0) else: # txt2img, [B=1, C=5, H=1, W=1] - icond = self._rep_dim0(icond, n_rep) + icond = self.repeat_tensor(icond, n_rep) # vec cond (SDXL) vcond = self.get_vcond(cond_in) # [B=1, D] if vcond is not None: - vcond = self._rep_dim0(vcond, n_rep) # [B*N, D] + vcond = self.repeat_tensor(vcond, n_rep) # [B*N, D] return self.make_cond_dict(cond_in, tcond, icond, vcond) - def sample_one_step(self, x_in:Tensor, org_func:Callable, repeat_func:Callable, custom_func:Callable) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def sample_one_step(self, x_in:Tensor, org_func:Callable, repeat_func:Callable, custom_func:Callable) -> Tensor: ''' this method splits the whole latent and process in tiles - x_in: current whole U-Net latent @@ -174,24 +156,17 @@ class MultiDiffusion(AbstractDiffusion): self.switch_stablesr_tensors(batch_id) # compute tiles - if self.is_kdiff: - x_tile_out = repeat_func(x_tile, bboxes) - for i, bbox in enumerate(bboxes): - self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :] - else: - x_tile_out, x_tile_pred = repeat_func(x_tile, bboxes) - for i, bbox in enumerate(bboxes): - self.x_buffer [bbox.slicer] += x_tile_out [i*N:(i+1)*N, :, :, :] - self.x_pred_buffer[bbox.slicer] += x_tile_pred[i*N:(i+1)*N, :, :, :] + x_tile_out = repeat_func(x_tile, bboxes) + for i, bbox in enumerate(bboxes): + self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :] # update progress bar self.update_pbar() # Custom region sampling (custom bbox) - x_feather_buffer = None - x_feather_mask = None - x_feather_count = None - x_feather_pred_buffer = None + x_feather_buffer = None + x_feather_mask = None + x_feather_count = None if len(self.custom_bboxes) > 0: for bbox_id, bbox in enumerate(self.custom_bboxes): if state.interrupted: return x_in @@ -202,36 +177,19 @@ class MultiDiffusion(AbstractDiffusion): x_tile = x_in[bbox.slicer] - if self.is_kdiff: - # retrieve original x_in from construncted input - x_tile_out = custom_func(x_tile, bbox_id, bbox) + # retrieve original x_in from construncted input + x_tile_out = custom_func(x_tile, bbox_id, bbox) - if bbox.blend_mode == BlendMode.BACKGROUND: - self.x_buffer[bbox.slicer] += x_tile_out - elif bbox.blend_mode == BlendMode.FOREGROUND: - if x_feather_buffer is None: - x_feather_buffer = torch.zeros_like(self.x_buffer) - x_feather_mask = torch.zeros((1, 1, H, W), device=x_in.device) - x_feather_count = torch.zeros((1, 1, H, W), device=x_in.device) - x_feather_buffer[bbox.slicer] += x_tile_out - x_feather_mask [bbox.slicer] += bbox.feather_mask - x_feather_count [bbox.slicer] += 1 - else: - x_tile_out, x_tile_pred = custom_func(x_tile, bbox_id, bbox) - - if bbox.blend_mode == BlendMode.BACKGROUND: - self.x_buffer [bbox.slicer] += x_tile_out - self.x_pred_buffer[bbox.slicer] += x_tile_pred - elif bbox.blend_mode == BlendMode.FOREGROUND: - if x_feather_buffer is None: - x_feather_buffer = torch.zeros_like(self.x_buffer) - x_feather_pred_buffer = torch.zeros_like(self.x_pred_buffer) - x_feather_mask = torch.zeros((1, 1, H, W), device=x_in.device) - x_feather_count = torch.zeros((1, 1, H, W), device=x_in.device) - x_feather_buffer [bbox.slicer] += x_tile_out - x_feather_pred_buffer[bbox.slicer] += x_tile_pred - x_feather_mask [bbox.slicer] += bbox.feather_mask - x_feather_count [bbox.slicer] += 1 + if bbox.blend_mode == BlendMode.BACKGROUND: + self.x_buffer[bbox.slicer] += x_tile_out + elif bbox.blend_mode == BlendMode.FOREGROUND: + if x_feather_buffer is None: + x_feather_buffer = torch.zeros_like(self.x_buffer) + x_feather_mask = torch.zeros((1, 1, H, W), device=x_in.device) + x_feather_count = torch.zeros((1, 1, H, W), device=x_in.device) + x_feather_buffer[bbox.slicer] += x_tile_out + x_feather_mask [bbox.slicer] += bbox.feather_mask + x_feather_count [bbox.slicer] += 1 if not self.p.disable_extra_networks: with devices.autocast(): @@ -242,9 +200,7 @@ class MultiDiffusion(AbstractDiffusion): # Averaging background buffer x_out = torch.where(self.weights > 1, self.x_buffer / self.weights, self.x_buffer) - if self.is_ddim: - x_pred_out = torch.where(self.weights > 1, self.x_pred_buffer / self.weights, self.x_pred_buffer) - + # Foreground Feather blending if x_feather_buffer is not None: # Average overlapping feathered regions @@ -252,11 +208,8 @@ class MultiDiffusion(AbstractDiffusion): x_feather_mask = torch.where(x_feather_count > 1, x_feather_mask / x_feather_count, x_feather_mask) # Weighted average with original x_buffer x_out = torch.where(x_feather_count > 0, x_out * (1 - x_feather_mask) + x_feather_buffer * x_feather_mask, x_out) - if self.is_ddim: - x_feather_pred_buffer = torch.where(x_feather_count > 1, x_feather_pred_buffer / x_feather_count, x_feather_pred_buffer) - x_pred_out = torch.where(x_feather_count > 0, x_pred_out * (1 - x_feather_mask) + x_feather_pred_buffer * x_feather_mask, x_pred_out) - return x_out if self.is_kdiff else (x_out, x_pred_out) + return x_out def get_noise(self, x_in:Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor: # NOTE: The following code is analytically wrong but aesthetically beautiful diff --git a/tile_utils/typing.py b/tile_utils/typing.py index 6ea75ca..2d6b639 100644 --- a/tile_utils/typing.py +++ b/tile_utils/typing.py @@ -4,18 +4,18 @@ from torch import Tensor from gradio.components import Component -from k_diffusion.external import CompVisDenoiser +from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser from ldm.models.diffusion.ddpm import LatentDiffusion from modules.processing import StableDiffusionProcessing as Processing, StableDiffusionProcessingImg2Img as ProcessingImg2Img, Processed from modules.prompt_parser import MulticondLearnedConditioning, ScheduledPromptConditioning from modules.extra_networks import ExtraNetworkParams from modules.sd_samplers_kdiffusion import KDiffusionSampler, CFGDenoiser -from modules.sd_samplers_timesteps import VanillaStableDiffusionSampler +from modules.sd_samplers_timesteps import CompVisSampler, CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser ModuleType = type(sys) -Sampler = Union[KDiffusionSampler, VanillaStableDiffusionSampler] +Sampler = Union[KDiffusionSampler, CompVisSampler] Cond = MulticondLearnedConditioning Uncond = List[List[ScheduledPromptConditioning]] ExtraNetworkData = DefaultDict[str, List[ExtraNetworkParams]] diff --git a/tile_utils/utils.py b/tile_utils/utils.py index d5b52e5..f573005 100644 --- a/tile_utils/utils.py +++ b/tile_utils/utils.py @@ -8,6 +8,9 @@ import numpy as np from modules import devices, shared, prompt_parser, extra_networks from modules.processing import opt_f +from modules.shared import state +from modules.shared_state import State +state: State from tile_utils.typing import *