parent
77e800fe2a
commit
ef40371b23
|
|
@ -72,7 +72,6 @@ from tile_methods.abstractdiffusion import AbstractDiffusion
|
|||
from tile_methods.multidiffusion import MultiDiffusion
|
||||
from tile_methods.mixtureofdiffusers import MixtureOfDiffusers
|
||||
from tile_utils.utils import *
|
||||
from tile_utils.typing import *
|
||||
|
||||
CFG_PATH = os.path.join(scripts.basedir(), 'region_configs')
|
||||
BBOX_MAX_NUM = min(getattr(shared.cmd_opts, 'md_max_regions', 8), 16)
|
||||
|
|
|
|||
|
|
@ -8,13 +8,9 @@ import torch.nn.functional as F
|
|||
from tqdm import tqdm
|
||||
|
||||
from modules import devices, shared, sd_samplers_common
|
||||
from modules.shared import state
|
||||
from modules.shared_state import State
|
||||
state: State
|
||||
from modules.processing import opt_f
|
||||
|
||||
from tile_utils.utils import *
|
||||
from tile_utils.typing import *
|
||||
|
||||
|
||||
class AbstractDiffusion:
|
||||
|
|
@ -27,8 +23,7 @@ class AbstractDiffusion:
|
|||
# sampler
|
||||
self.sampler_name = p.sampler_name
|
||||
self.sampler_raw = sampler
|
||||
if self.is_kdiff: self.sampler: CFGDenoiser = sampler.model_wrap_cfg
|
||||
else: self.sampler: VanillaStableDiffusionSampler = sampler
|
||||
self.sampler = sampler
|
||||
|
||||
# fix. Kdiff 'AND' support and image editing model support
|
||||
if self.is_kdiff and not hasattr(self, 'is_edit_model'):
|
||||
|
|
@ -97,7 +92,7 @@ class AbstractDiffusion:
|
|||
|
||||
@property
|
||||
def is_ddim(self):
|
||||
return isinstance(self.sampler_raw, VanillaStableDiffusionSampler)
|
||||
return isinstance(self.sampler_raw, CompVisSampler)
|
||||
|
||||
def update_pbar(self):
|
||||
if self.pbar.n >= self.pbar.total:
|
||||
|
|
@ -176,8 +171,9 @@ class AbstractDiffusion:
|
|||
if key is not None:
|
||||
cond_dict[key] = vcond
|
||||
|
||||
def make_cond_dict(self, cond_dict:CondDict, tcond:Tensor, icond:Tensor, vcond:Tensor=None) -> CondDict:
|
||||
cond_out = cond_dict.copy()
|
||||
def make_cond_dict(self, cond_in:CondDict, tcond:Tensor, icond:Tensor, vcond:Tensor=None) -> CondDict:
|
||||
''' copy & replace the content, returns a new object '''
|
||||
cond_out = cond_in.copy()
|
||||
self.set_tcond(cond_out, tcond)
|
||||
self.set_icond(cond_out, icond)
|
||||
self.set_vcond(cond_out, vcond)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,9 @@
|
|||
import torch
|
||||
|
||||
from modules import devices, shared, extra_networks
|
||||
from modules.shared import state
|
||||
from modules.shared_state import State
|
||||
state: State
|
||||
|
||||
from tile_methods.abstractdiffusion import AbstractDiffusion
|
||||
from tile_utils.utils import *
|
||||
from tile_utils.typing import *
|
||||
|
||||
|
||||
class MixtureOfDiffusers(AbstractDiffusion):
|
||||
|
|
|
|||
|
|
@ -1,13 +1,9 @@
|
|||
import torch
|
||||
|
||||
from modules import devices, extra_networks
|
||||
from modules.shared import state
|
||||
from modules.shared_state import State
|
||||
state: State
|
||||
|
||||
from tile_methods.abstractdiffusion import AbstractDiffusion
|
||||
from tile_utils.utils import *
|
||||
from tile_utils.typing import *
|
||||
|
||||
|
||||
class MultiDiffusion(AbstractDiffusion):
|
||||
|
|
@ -20,20 +16,17 @@ class MultiDiffusion(AbstractDiffusion):
|
|||
super().__init__(p, *args, **kwargs)
|
||||
assert p.sampler_name != 'UniPC', 'MultiDiffusion is not compatible with UniPC!'
|
||||
|
||||
# For ddim sampler we need to cache the pred_x0
|
||||
self.x_pred_buffer = None
|
||||
|
||||
def hook(self):
|
||||
if self.is_kdiff:
|
||||
# For K-Diffusion sampler with uniform prompt, we hijack into the inner model for simplicity
|
||||
# Otherwise, the masked-redraw will break due to the init_latent
|
||||
self.sampler: CFGDenoiser
|
||||
self.sampler_forward = self.sampler.inner_model.forward
|
||||
self.sampler.inner_model.forward = self.kdiff_forward
|
||||
self.sampler: KDiffusionSampler
|
||||
self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward
|
||||
self.sampler.model_wrap_cfg.inner_model.forward = self.kdiff_forward
|
||||
else:
|
||||
self.sampler: VanillaStableDiffusionSampler
|
||||
self.sampler_forward = self.sampler.orig_p_sample_ddim # FIXME: this is boken due to sd-webui's update
|
||||
self.sampler.orig_p_sample_ddim = self.ddim_forward
|
||||
self.sampler: CompVisSampler
|
||||
self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward
|
||||
self.sampler.model_wrap_cfg.inner_model.forward = self.ddim_forward
|
||||
|
||||
@staticmethod
|
||||
def unhook():
|
||||
|
|
@ -44,13 +37,6 @@ class MultiDiffusion(AbstractDiffusion):
|
|||
def reset_buffer(self, x_in:Tensor):
|
||||
super().reset_buffer(x_in)
|
||||
|
||||
# ddim needs to cache pred0
|
||||
if self.is_ddim:
|
||||
if self.x_pred_buffer is None:
|
||||
self.x_pred_buffer = torch.zeros_like(x_in, device=x_in.device)
|
||||
else:
|
||||
self.x_pred_buffer.zero_()
|
||||
|
||||
@custom_bbox
|
||||
def init_custom_bbox(self, *args):
|
||||
super().init_custom_bbox(*args)
|
||||
|
|
@ -65,6 +51,7 @@ class MultiDiffusion(AbstractDiffusion):
|
|||
@keep_signature
|
||||
def kdiff_forward(self, x_in:Tensor, sigma_in:Tensor, cond:CondDict) -> Tensor:
|
||||
assert CompVisDenoiser.forward
|
||||
assert CompVisVDenoiser.forward
|
||||
|
||||
def org_func(x:Tensor) -> Tensor:
|
||||
return self.sampler_forward(x, sigma_in, cond=cond)
|
||||
|
|
@ -73,10 +60,9 @@ class MultiDiffusion(AbstractDiffusion):
|
|||
# For kdiff sampler, the dim 0 of input x_in is:
|
||||
# = batch_size * (num_AND + 1) if not an edit model
|
||||
# = batch_size * (num_AND + 2) otherwise
|
||||
sigma_tile = self._rep_dim0(sigma_in, len(bboxes))
|
||||
sigma_tile = self.repeat_tensor(sigma_in, len(bboxes))
|
||||
cond_tile = self.repeat_cond_dict(cond, bboxes)
|
||||
x_tile_out = self.sampler_forward(x_tile, sigma_tile, cond=cond_tile)
|
||||
return x_tile_out
|
||||
return self.sampler_forward(x_tile, sigma_tile, cond=cond_tile)
|
||||
|
||||
def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox) -> Tensor:
|
||||
return self.kdiff_custom_forward(x, sigma_in, cond, bbox_id, bbox, self.sampler_forward)
|
||||
|
|
@ -85,25 +71,21 @@ class MultiDiffusion(AbstractDiffusion):
|
|||
|
||||
@torch.no_grad()
|
||||
@keep_signature
|
||||
def ddim_forward(self, p:Processing, x_in:Tensor, cond_in:Union[CondDict, Tensor], ts:Tensor, unconditional_conditioning:Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]:
|
||||
assert VanillaStableDiffusionSampler.sample
|
||||
assert VanillaStableDiffusionSampler.sample_img2img
|
||||
def ddim_forward(self, x_in:Tensor, ts_in:Tensor, cond:Union[CondDict, Tensor]) -> Tensor:
|
||||
assert CompVisTimestepsDenoiser.forward
|
||||
assert CompVisTimestepsVDenoiser.forward
|
||||
|
||||
def org_func(x:Tensor) -> Tensor:
|
||||
return self.sampler_forward(x, p, cond_in, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||
return self.sampler_forward(x, ts_in, cond=cond)
|
||||
|
||||
def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tuple[Tensor, Tensor]:
|
||||
n_rep = len(bboxes)
|
||||
if isinstance(cond_in, dict): # FIXME: when will enter this branch?
|
||||
ts_tile = self._rep_dim0(ts, n_rep)
|
||||
cond_tile = self.repeat_cond_dict(cond_in, bboxes)
|
||||
ucond_tile = self.repeat_cond_dict(unconditional_conditioning, bboxes)
|
||||
ts_tile = self.repeat_tensor(ts_in, n_rep)
|
||||
if isinstance(cond, dict): # FIXME: when will enter this branch?
|
||||
cond_tile = self.repeat_cond_dict(cond, bboxes)
|
||||
else:
|
||||
ts_tile = self._rep_dim0(ts, n_rep)
|
||||
cond_tile = self._rep_dim0(cond_in, n_rep)
|
||||
ucond_tile = self._rep_dim0(unconditional_conditioning, n_rep)
|
||||
x_tile_out, x_pred = self.sampler_forward(x_tile, cond_tile, ts_tile, unconditional_conditioning=ucond_tile, *args, **kwargs)
|
||||
return x_tile_out, x_pred
|
||||
cond_tile = self.repeat_tensor(cond, n_rep)
|
||||
return self.sampler_forward(x_tile, ts_tile, cond=cond_tile)
|
||||
|
||||
def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox) -> Tensor:
|
||||
# before the final forward, we can set the control tensor
|
||||
|
|
@ -111,36 +93,36 @@ class MultiDiffusion(AbstractDiffusion):
|
|||
self.set_custom_controlnet_tensors(bbox_id, 2*x.shape[0])
|
||||
self.set_custom_stablesr_tensors(bbox_id)
|
||||
return self.sampler_forward(x, *args, **kwargs)
|
||||
return self.ddim_custom_forward(x, cond_in, bbox, ts, forward_func, *args, **kwargs)
|
||||
return self.ddim_custom_forward(x, cond, bbox, ts_in, forward_func)
|
||||
|
||||
return self.sample_one_step(x_in, org_func, repeat_func, custom_func)
|
||||
|
||||
def _rep_dim0(self, x:Tensor, n:int) -> Tensor:
|
||||
def repeat_tensor(self, x:Tensor, n:int) -> Tensor:
    ''' repeat the tensor `n` times along its first dim

    - x: tensor of shape [B, ...]
    - n: number of repetitions
    - returns: `x` itself when n == 1, otherwise a tensor of shape [n*B, ...]
      laid out as n consecutive copies of `x`
    '''
    if n == 1: return x
    r_dims = len(x.shape) - 1
    B = x.shape[0] if x.dim() > 0 else 1
    if B == 1:
        # `expand` only creates a broadcast view, which is much lighter than
        # `repeat`, but it can only stretch a singleton dim
        return x.expand([n] + [-1] * r_dims)    # [N, -1, ...]
    # B > 1: `expand` would raise on the non-singleton dim, so make a real copy
    return x.repeat([n] + [1] * r_dims)         # [N, 1, ...]
|
||||
|
||||
def repeat_cond_dict(self, cond_in:CondDict, bboxes:List[CustomBBox]) -> CondDict:
|
||||
''' repeat cond_dict for a batch of tiles '''
|
||||
''' repeat all tensors in cond_dict on its first dim (for a batch of tiles), returns a new object '''
|
||||
# n_repeat
|
||||
n_rep = len(bboxes)
|
||||
# txt cond
|
||||
tcond = self.get_tcond(cond_in) # [B=1, L, D] => [B*N, L, D]
|
||||
tcond = self._rep_dim0(tcond, n_rep)
|
||||
tcond = self.repeat_tensor(tcond, n_rep)
|
||||
# img cond
|
||||
icond = self.get_icond(cond_in)
|
||||
if icond.shape[2:] == (self.h, self.w): # img2img, [B=1, C, H, W]
|
||||
icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0)
|
||||
else: # txt2img, [B=1, C=5, H=1, W=1]
|
||||
icond = self._rep_dim0(icond, n_rep)
|
||||
icond = self.repeat_tensor(icond, n_rep)
|
||||
# vec cond (SDXL)
|
||||
vcond = self.get_vcond(cond_in) # [B=1, D]
|
||||
if vcond is not None:
|
||||
vcond = self._rep_dim0(vcond, n_rep) # [B*N, D]
|
||||
vcond = self.repeat_tensor(vcond, n_rep) # [B*N, D]
|
||||
return self.make_cond_dict(cond_in, tcond, icond, vcond)
|
||||
|
||||
def sample_one_step(self, x_in:Tensor, org_func:Callable, repeat_func:Callable, custom_func:Callable) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
||||
def sample_one_step(self, x_in:Tensor, org_func:Callable, repeat_func:Callable, custom_func:Callable) -> Tensor:
|
||||
'''
|
||||
this method splits the whole latent and process in tiles
|
||||
- x_in: current whole U-Net latent
|
||||
|
|
@ -174,24 +156,17 @@ class MultiDiffusion(AbstractDiffusion):
|
|||
self.switch_stablesr_tensors(batch_id)
|
||||
|
||||
# compute tiles
|
||||
if self.is_kdiff:
|
||||
x_tile_out = repeat_func(x_tile, bboxes)
|
||||
for i, bbox in enumerate(bboxes):
|
||||
self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :]
|
||||
else:
|
||||
x_tile_out, x_tile_pred = repeat_func(x_tile, bboxes)
|
||||
for i, bbox in enumerate(bboxes):
|
||||
self.x_buffer [bbox.slicer] += x_tile_out [i*N:(i+1)*N, :, :, :]
|
||||
self.x_pred_buffer[bbox.slicer] += x_tile_pred[i*N:(i+1)*N, :, :, :]
|
||||
x_tile_out = repeat_func(x_tile, bboxes)
|
||||
for i, bbox in enumerate(bboxes):
|
||||
self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :]
|
||||
|
||||
# update progress bar
|
||||
self.update_pbar()
|
||||
|
||||
# Custom region sampling (custom bbox)
|
||||
x_feather_buffer = None
|
||||
x_feather_mask = None
|
||||
x_feather_count = None
|
||||
x_feather_pred_buffer = None
|
||||
x_feather_buffer = None
|
||||
x_feather_mask = None
|
||||
x_feather_count = None
|
||||
if len(self.custom_bboxes) > 0:
|
||||
for bbox_id, bbox in enumerate(self.custom_bboxes):
|
||||
if state.interrupted: return x_in
|
||||
|
|
@ -202,36 +177,19 @@ class MultiDiffusion(AbstractDiffusion):
|
|||
|
||||
x_tile = x_in[bbox.slicer]
|
||||
|
||||
if self.is_kdiff:
|
||||
# retrieve original x_in from constructed input
|
||||
x_tile_out = custom_func(x_tile, bbox_id, bbox)
|
||||
# retrieve original x_in from constructed input
|
||||
x_tile_out = custom_func(x_tile, bbox_id, bbox)
|
||||
|
||||
if bbox.blend_mode == BlendMode.BACKGROUND:
|
||||
self.x_buffer[bbox.slicer] += x_tile_out
|
||||
elif bbox.blend_mode == BlendMode.FOREGROUND:
|
||||
if x_feather_buffer is None:
|
||||
x_feather_buffer = torch.zeros_like(self.x_buffer)
|
||||
x_feather_mask = torch.zeros((1, 1, H, W), device=x_in.device)
|
||||
x_feather_count = torch.zeros((1, 1, H, W), device=x_in.device)
|
||||
x_feather_buffer[bbox.slicer] += x_tile_out
|
||||
x_feather_mask [bbox.slicer] += bbox.feather_mask
|
||||
x_feather_count [bbox.slicer] += 1
|
||||
else:
|
||||
x_tile_out, x_tile_pred = custom_func(x_tile, bbox_id, bbox)
|
||||
|
||||
if bbox.blend_mode == BlendMode.BACKGROUND:
|
||||
self.x_buffer [bbox.slicer] += x_tile_out
|
||||
self.x_pred_buffer[bbox.slicer] += x_tile_pred
|
||||
elif bbox.blend_mode == BlendMode.FOREGROUND:
|
||||
if x_feather_buffer is None:
|
||||
x_feather_buffer = torch.zeros_like(self.x_buffer)
|
||||
x_feather_pred_buffer = torch.zeros_like(self.x_pred_buffer)
|
||||
x_feather_mask = torch.zeros((1, 1, H, W), device=x_in.device)
|
||||
x_feather_count = torch.zeros((1, 1, H, W), device=x_in.device)
|
||||
x_feather_buffer [bbox.slicer] += x_tile_out
|
||||
x_feather_pred_buffer[bbox.slicer] += x_tile_pred
|
||||
x_feather_mask [bbox.slicer] += bbox.feather_mask
|
||||
x_feather_count [bbox.slicer] += 1
|
||||
if bbox.blend_mode == BlendMode.BACKGROUND:
|
||||
self.x_buffer[bbox.slicer] += x_tile_out
|
||||
elif bbox.blend_mode == BlendMode.FOREGROUND:
|
||||
if x_feather_buffer is None:
|
||||
x_feather_buffer = torch.zeros_like(self.x_buffer)
|
||||
x_feather_mask = torch.zeros((1, 1, H, W), device=x_in.device)
|
||||
x_feather_count = torch.zeros((1, 1, H, W), device=x_in.device)
|
||||
x_feather_buffer[bbox.slicer] += x_tile_out
|
||||
x_feather_mask [bbox.slicer] += bbox.feather_mask
|
||||
x_feather_count [bbox.slicer] += 1
|
||||
|
||||
if not self.p.disable_extra_networks:
|
||||
with devices.autocast():
|
||||
|
|
@ -242,9 +200,7 @@ class MultiDiffusion(AbstractDiffusion):
|
|||
|
||||
# Averaging background buffer
|
||||
x_out = torch.where(self.weights > 1, self.x_buffer / self.weights, self.x_buffer)
|
||||
if self.is_ddim:
|
||||
x_pred_out = torch.where(self.weights > 1, self.x_pred_buffer / self.weights, self.x_pred_buffer)
|
||||
|
||||
|
||||
# Foreground Feather blending
|
||||
if x_feather_buffer is not None:
|
||||
# Average overlapping feathered regions
|
||||
|
|
@ -252,11 +208,8 @@ class MultiDiffusion(AbstractDiffusion):
|
|||
x_feather_mask = torch.where(x_feather_count > 1, x_feather_mask / x_feather_count, x_feather_mask)
|
||||
# Weighted average with original x_buffer
|
||||
x_out = torch.where(x_feather_count > 0, x_out * (1 - x_feather_mask) + x_feather_buffer * x_feather_mask, x_out)
|
||||
if self.is_ddim:
|
||||
x_feather_pred_buffer = torch.where(x_feather_count > 1, x_feather_pred_buffer / x_feather_count, x_feather_pred_buffer)
|
||||
x_pred_out = torch.where(x_feather_count > 0, x_pred_out * (1 - x_feather_mask) + x_feather_pred_buffer * x_feather_mask, x_pred_out)
|
||||
|
||||
return x_out if self.is_kdiff else (x_out, x_pred_out)
|
||||
return x_out
|
||||
|
||||
def get_noise(self, x_in:Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor:
|
||||
# NOTE: The following code is analytically wrong but aesthetically beautiful
|
||||
|
|
|
|||
|
|
@ -4,18 +4,18 @@ from torch import Tensor
|
|||
|
||||
from gradio.components import Component
|
||||
|
||||
from k_diffusion.external import CompVisDenoiser
|
||||
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
|
||||
from modules.processing import StableDiffusionProcessing as Processing, StableDiffusionProcessingImg2Img as ProcessingImg2Img, Processed
|
||||
from modules.prompt_parser import MulticondLearnedConditioning, ScheduledPromptConditioning
|
||||
from modules.extra_networks import ExtraNetworkParams
|
||||
from modules.sd_samplers_kdiffusion import KDiffusionSampler, CFGDenoiser
|
||||
from modules.sd_samplers_timesteps import VanillaStableDiffusionSampler
|
||||
from modules.sd_samplers_timesteps import CompVisSampler, CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser
|
||||
|
||||
ModuleType = type(sys)
|
||||
|
||||
Sampler = Union[KDiffusionSampler, VanillaStableDiffusionSampler]
|
||||
Sampler = Union[KDiffusionSampler, CompVisSampler]
|
||||
Cond = MulticondLearnedConditioning
|
||||
Uncond = List[List[ScheduledPromptConditioning]]
|
||||
ExtraNetworkData = DefaultDict[str, List[ExtraNetworkParams]]
|
||||
|
|
|
|||
|
|
@ -8,6 +8,9 @@ import numpy as np
|
|||
|
||||
from modules import devices, shared, prompt_parser, extra_networks
|
||||
from modules.processing import opt_f
|
||||
from modules.shared import state
|
||||
from modules.shared_state import State
|
||||
state: State
|
||||
|
||||
from tile_utils.typing import *
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue