multidiffusion-upscaler-for.../tile_methods/multidiffusion.py

263 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import torch
from modules import devices, extra_networks
from modules.shared import state
from tile_utils.typing import *
from tile_utils.utils import *
from tile_methods.abstractdiffusion import TiledDiffusion
class MultiDiffusion(TiledDiffusion):
    """
    Multi-Diffusion Implementation
    https://arxiv.org/abs/2302.08113

    Replaces the sampler's single-step forward function with a tiled version:
    the latent is cut into grid tiles (and optional custom regions), every
    tile is denoised by the original forward function, and the results are
    accumulated into canvas buffers and averaged by ``self.weights``.
    """

    def __init__(self, p: StableDiffusionProcessing, *args, **kwargs):
        """Forward all arguments to the base class and set up the pred_x0 cache.

        Raises:
            AssertionError: if ``p.sampler_name`` is ``'UniPC'``.
        """
        super().__init__(p, *args, **kwargs)
        assert p.sampler_name != 'UniPC', 'MultiDiffusion is not compatible with UniPC!'

        # For ddim sampler we need to cache the pred_x0.
        # Allocated lazily in reset_buffer(); stays None for k-diffusion.
        self.x_pred_buffer = None

    def hook(self):
        """Swap the sampler's forward function for the tiled one.

        The original callable is kept in ``self.sampler_forward`` so the tiled
        implementations can still invoke it per tile.
        """
        if self.is_kdiff:
            # For K-Diffusion sampler with uniform prompt, we hijack into the inner model for simplicity
            # Otherwise, the masked-redraw will break due to the init_latent
            self.sampler: CFGDenoiser
            self.sampler_forward = self.sampler.inner_model.forward
            self.sampler.inner_model.forward = self.kdiff_forward
        else:
            self.sampler: VanillaStableDiffusionSampler
            self.sampler_forward = self.sampler.orig_p_sample_ddim
            self.sampler.orig_p_sample_ddim = self.ddim_forward

    @staticmethod
    def unhook():
        """Intentionally a no-op.

        MultiDiffusion only hooks the sampler, which will be destroyed after
        the painting is done, so there is nothing to restore.
        """
        pass

    def reset_buffer(self, x_in: Tensor):
        """Clear the canvas buffers before a sampling step.

        Besides whatever the base class resets, the ddim path also needs a
        zeroed ``x_pred_buffer`` shaped like ``x_in``.
        """
        super().reset_buffer(x_in)

        # ddim needs to cache pred0
        if self.is_ddim:
            if self.x_pred_buffer is None:
                self.x_pred_buffer = torch.zeros_like(x_in, device=x_in.device)
            else:
                # reuse the existing allocation
                self.x_pred_buffer.zero_()

    @custom_bbox
    def init_custom_bbox(self, *args):
        """Register custom regions and count BACKGROUND regions in the weights.

        A BACKGROUND region is accumulated into ``x_buffer`` exactly like a
        grid tile, so the per-pixel weight used for averaging is bumped by one
        over its area.
        """
        super().init_custom_bbox(*args)

        for bbox in self.custom_bboxes:
            if bbox.blend_mode == BlendMode.BACKGROUND:
                self.weights[bbox.slicer] += 1.0

    # ----- kernel hijacks -----

    @staticmethod
    def _repeat_first_dim(tensor, times: int):
        """Tile ``tensor`` ``times`` times along dim 0, leaving other dims as is."""
        return tensor.repeat((times,) + (1,) * (len(tensor.shape) - 1))

    def _interrupted_result(self, x_in: Tensor):
        """Pass-through result used when sampling is interrupted mid-step.

        The value must have the same arity as the normal return of
        ``sample_one_step``: a tensor for k-diffusion, a 2-tuple otherwise.
        """
        return x_in if self.is_kdiff else (x_in, x_in)

    def repeat_cond_dict(self, cond_input: CondDict, bboxes: List[CustomBBox]) -> CondDict:
        """Build the conditioning dict for a batch of ``len(bboxes)`` tiles.

        - ``c_crossattn[0]`` is repeated once per tile along dim 0.
        - ``c_concat[0]`` is cropped per tile when its dims 2 and 3 equal
          ``(self.h, self.w)``; otherwise it is repeated like the cross-attn
          conditioning.
        """
        tile_count = len(bboxes)

        # repeat the condition on its first dim
        cond = self._repeat_first_dim(cond_input['c_crossattn'][0], tile_count)

        image_cond = cond_input['c_concat'][0]
        if image_cond.shape[2] == self.h and image_cond.shape[3] == self.w:
            # same spatial size as the tiled latent: crop it tile by tile
            image_cond_tile = torch.cat([image_cond[bbox.slicer] for bbox in bboxes], dim=0)
        else:
            image_cond_tile = self._repeat_first_dim(image_cond, tile_count)

        return {"c_crossattn": [cond], "c_concat": [image_cond_tile]}

    @torch.no_grad()
    @keep_signature
    def kdiff_forward(self, x_in: Tensor, sigma_in: Tensor, cond: CondDict):
        '''
        This function hijacks `k_diffusion.external.CompVisDenoiser.forward()`
        So its signature should be the same as the original function, especially the "cond" should be with exactly the same name
        '''
        # Fails early if the hijacked attribute does not exist.
        assert CompVisDenoiser.forward

        # Shape notes from the original author (example values):
        # x_in: [B, C=4, H=64, W=64]
        # sigma_in [1]
        # cond['c_crossattn'][0]: [1, 77, 768]

        def org_func(x):
            # untiled fallback: call the original forward on the whole latent
            return self.sampler_forward(x, sigma_in, cond=cond)

        def repeat_func(x_tile, bboxes):
            # For kdiff sampler, the dim 0 of input x_in is:
            #   = batch_size * (num_AND + 1)   if not an edit model
            #   = batch_size * (num_AND + 2)   otherwise
            sigma_in_tile = sigma_in.repeat(len(bboxes))
            new_cond = self.repeat_cond_dict(cond, bboxes)
            return self.sampler_forward(x_tile, sigma_in_tile, cond=new_cond)

        def custom_func(x, bbox_id, bbox):
            return self.kdiff_custom_forward(x, sigma_in, cond, bbox_id, bbox, self.sampler_forward)

        return self.sample_one_step(x_in, org_func, repeat_func, custom_func)

    @torch.no_grad()
    @keep_signature
    def ddim_forward(self, x_in: Tensor, cond_in: CondDict, ts: Tensor, unconditional_conditioning: Tensor, *args, **kwargs):
        '''
        This function will replace the original p_sample_ddim function in ldm/diffusionmodels/ddim.py
        So its signature should be the same as the original function,
        Particularly, the unconditional_conditioning should be with exactly the same name
        '''
        # Fails early if the hook attribute does not exist.
        assert VanillaStableDiffusionSampler.p_sample_ddim_hook

        def org_func(x):
            # untiled fallback: call the original forward on the whole latent
            return self.sampler_forward(x, cond_in, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)

        def repeat_func(x_tile, bboxes):
            tile_count = len(bboxes)
            ts_tile = ts.repeat(tile_count)
            if isinstance(cond_in, dict):
                cond_tile = self.repeat_cond_dict(cond_in, bboxes)
                ucond_tile = self.repeat_cond_dict(unconditional_conditioning, bboxes)
            else:
                # plain tensor conditioning: repeat along dim 0
                cond_tile = self._repeat_first_dim(cond_in, tile_count)
                ucond_tile = self._repeat_first_dim(unconditional_conditioning, tile_count)
            x_tile_out, x_pred = self.sampler_forward(
                x_tile, cond_tile, ts_tile,
                unconditional_conditioning=ucond_tile,
                *args, **kwargs)
            return x_tile_out, x_pred

        def custom_func(x, bbox_id: int, bbox: CustomBBox):
            # before the final forward, we can set the control tensor
            def forward_func(x, *args, **kwargs):
                self.set_controlnet_tensors(bbox_id, 2*x.shape[0])
                return self.sampler_forward(x, *args, **kwargs)
            return self.ddim_custom_forward(x, cond_in, bbox, ts, forward_func, *args, **kwargs)

        return self.sample_one_step(x_in, org_func, repeat_func, custom_func)

    def sample_one_step(self, x_in: Tensor, org_func: Callable, repeat_func: Callable, custom_func: Callable):
        '''
        this method splits the whole latent and process in tiles
            - x_in: current whole U-Net latent
            - org_func: original forward function, used when the latent size
              differs from the tiled size (self.h, self.w)
            - repeat_func: one step denoiser for a batch of grid tiles,
              called as repeat_func(x_tile, bboxes)
            - custom_func: one step denoiser for a single custom region,
              called as custom_func(x_tile, bbox_id, bbox)

        Returns the blended latent for k-diffusion, or a tuple
        (blended latent, blended pred_x0) otherwise. If the job is interrupted
        mid-step, x_in is passed through with that same arity.
        '''
        N, C, H, W = x_in.shape
        if H != self.h or W != self.w:
            # not the latent we were set up to tile: run the original function
            self.reset_controlnet_tensors()
            return org_func(x_in)

        # the non-kdiff (ddim) denoisers return (x, pred_x0) instead of x
        has_pred = not self.is_kdiff

        # clear buffer canvas
        self.reset_buffer(x_in)

        # Background sampling (grid bbox)
        if self.draw_background:
            for batch_id, bboxes in enumerate(self.batched_bboxes):
                if state.interrupted:
                    return self._interrupted_result(x_in)

                # batching: stack all tiles of this batch along dim 0
                x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0)

                # controlnet tiling
                # FIXME: is_denoise is default to False, however it is set to True in case of MixtureOfDiffusers
                self.switch_controlnet_tensors(batch_id, N, len(bboxes))

                # compute tiles
                if has_pred:
                    x_tile_out, x_tile_pred = repeat_func(x_tile, bboxes)
                else:
                    x_tile_out = repeat_func(x_tile, bboxes)
                    x_tile_pred = None

                # scatter: rows [i*N, (i+1)*N) of the output belong to tile i
                for i, bbox in enumerate(bboxes):
                    self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :]
                    if has_pred:
                        self.x_pred_buffer[bbox.slicer] += x_tile_pred[i*N:(i+1)*N, :, :, :]

                # update progress bar
                self.update_pbar()

        # Custom region sampling (custom bbox)
        # Feather buffers are only allocated once a FOREGROUND region shows up.
        x_feather_buffer = None
        x_feather_mask = None
        x_feather_count = None
        x_feather_pred_buffer = None
        for bbox_id, bbox in enumerate(self.custom_bboxes):
            if state.interrupted:
                return self._interrupted_result(x_in)

            use_extra_networks = not self.p.disable_extra_networks
            if use_extra_networks:
                with devices.autocast():
                    extra_networks.activate(self.p, bbox.extra_network_data)

            try:
                x_tile = x_in[bbox.slicer]

                if has_pred:
                    x_tile_out, x_tile_pred = custom_func(x_tile, bbox_id, bbox)
                else:
                    x_tile_out = custom_func(x_tile, bbox_id, bbox)
                    x_tile_pred = None

                if bbox.blend_mode == BlendMode.BACKGROUND:
                    # averaged together with the grid tiles
                    self.x_buffer[bbox.slicer] += x_tile_out
                    if has_pred:
                        self.x_pred_buffer[bbox.slicer] += x_tile_pred
                elif bbox.blend_mode == BlendMode.FOREGROUND:
                    # kept apart and feather-blended over the background below
                    if x_feather_buffer is None:
                        x_feather_buffer = torch.zeros_like(self.x_buffer)
                        x_feather_mask = torch.zeros_like(self.x_buffer)
                        x_feather_count = torch.zeros_like(self.x_buffer)
                        if has_pred:
                            x_feather_pred_buffer = torch.zeros_like(self.x_pred_buffer)
                    x_feather_buffer[bbox.slicer] += x_tile_out
                    x_feather_mask[bbox.slicer] += bbox.feather_mask
                    x_feather_count[bbox.slicer] += 1
                    if has_pred:
                        x_feather_pred_buffer[bbox.slicer] += x_tile_pred
            finally:
                # always undo the per-region activation, even if denoising raised
                if use_extra_networks:
                    with devices.autocast():
                        extra_networks.deactivate(self.p, bbox.extra_network_data)

            # update progress bar
            self.update_pbar()

        # Averaging background buffer (only where more than one result was summed)
        x_out = torch.where(self.weights > 1, self.x_buffer / self.weights, self.x_buffer)
        if self.is_ddim:
            x_pred_out = torch.where(self.weights > 1, self.x_pred_buffer / self.weights, self.x_pred_buffer)

        # Foreground Feather blending
        if x_feather_buffer is not None:
            # Average overlapping feathered regions
            x_feather_buffer = torch.where(x_feather_count > 1, x_feather_buffer / x_feather_count, x_feather_buffer)
            x_feather_mask = torch.where(x_feather_count > 1, x_feather_mask / x_feather_count, x_feather_mask)
            # Weighted average with original x_buffer
            x_out = torch.where(x_feather_count > 0, x_out * (1 - x_feather_mask) + x_feather_buffer * x_feather_mask, x_out)
            if self.is_ddim:
                x_feather_pred_buffer = torch.where(x_feather_count > 1, x_feather_pred_buffer / x_feather_count, x_feather_pred_buffer)
                x_pred_out = torch.where(x_feather_count > 0, x_pred_out * (1 - x_feather_mask) + x_feather_pred_buffer * x_feather_mask, x_pred_out)

        return x_out if self.is_kdiff else (x_out, x_pred_out)