add noise inverse cache

pull/221/head
pkuliyi2015 2023-05-09 13:29:40 +00:00
parent 5f22b8fc27
commit 23f3a14432
3 changed files with 41 additions and 15 deletions

View File

@ -86,6 +86,7 @@ class Script(scripts.Script):
def __init__(self):
    # Every attribute starts out unset (None).
    self.controlnet_script: ModuleType = None
    # NOTE(review): presumably the active TiledDiffusion instance — confirm against where it is assigned.
    self.delegate: TiledDiffusion = None
    # Written by noise_inverse_set_cache(), read by noise_inverse_get_cache(),
    # and cleared again by reset_and_gc().
    self.noise_inverse_cache: NoiseInverseCache = None
def title(self):
    """Return the display name of this script."""
    script_title = "Tiled Diffusion"
    return script_title
@ -440,7 +441,9 @@ class Script(scripts.Script):
# setup **optional** supports through `init_*`, make everything relatively pluggable!!
if flag_noise_inverse:
delegate.init_noise_inverse(noise_inverse_steps, noise_inverse_retouch, noise_inverse_renoise_strength, noise_inverse_renoise_kernel)
get_cache_callback = self.noise_inverse_get_cache
set_cache_callback = lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, noise_inverse_steps, noise_inverse_retouch)
delegate.init_noise_inverse(noise_inverse_steps, noise_inverse_retouch, get_cache_callback, set_cache_callback, noise_inverse_renoise_strength, noise_inverse_renoise_kernel)
if not enable_bbox_control or draw_background:
delegate.init_grid_bbox(tile_width, tile_height, overlap, tile_batch_size)
if enable_bbox_control:
@ -557,6 +560,12 @@ class Script(scripts.Script):
data_list.extend(DEFAULT_BBOX_SETTINGS)
return [gr_value(v) for v in data_list] + [gr_value(f'Config loaded from {fp}.', visible=True)]
def noise_inverse_set_cache(self, p: ProcessingImg2Img, x0: Tensor, xt: Tensor, prompts: List[str], steps: int, retouch:float):
    """Replace the stored noise-inversion cache with a new entry.

    The entry records the checkpoint hash read from ``p.sd_model`` together
    with the two latents, the inversion step count, the retouch value and the
    prompts, exactly as passed in.
    """
    checkpoint_hash = p.sd_model.sd_model_hash
    entry = NoiseInverseCache(
        model_hash=checkpoint_hash,
        x0=x0,
        xt=xt,
        noise_inversion_steps=steps,
        retouch=retouch,
        prompts=prompts,
    )
    self.noise_inverse_cache = entry
def noise_inverse_get_cache(self):
    """Return whatever is currently stored in ``self.noise_inverse_cache``."""
    cached_entry = self.noise_inverse_cache
    return cached_entry
def reset(self):
''' unhijack inner APIs '''
@ -572,6 +581,7 @@ class Script(scripts.Script):
def reset_and_gc(self):
    """Reset the script, drop the noise-inversion cache, then collect garbage."""
    import gc

    self.reset()
    # Drop the reference first so the collection below can reclaim the cached entry.
    self.noise_inverse_cache = None
    gc.collect()
    devices.torch_gc()

View File

@ -15,7 +15,6 @@ from modules.processing import opt_f
from tile_utils.utils import *
from tile_utils.typing import *
class TiledDiffusion:
def __init__(self, p:Processing, sampler:Sampler):
@ -64,12 +63,14 @@ class TiledDiffusion:
self.causal_layers: bool = None
# ext. Noise Inversion (noise inversion)
self.noise_inverse_enabled: bool = False
self.noise_inverse_steps: int = 0
self.noise_inverse_retouch: float = None
self.noise_inverse_renoise_strength: float = None
self.noise_inverse_renoise_kernel: int = None
self.noise_inverse_get_cache = None
self.noise_inverse_set_cache = None
self.sample_img2img_original = None
self.noise_inverse_enabled = None
self.noise_inverse_steps = None
self.noise_inverse_retouch = None
self.noise_inverse_renoise_strength = None
self.noise_inverse_renoise_kernel = None
# ext. ControlNet
self.enable_controlnet: bool = False
@ -498,7 +499,7 @@ class TiledDiffusion:
@noise_inverse
def init_noise_inverse(self, steps:int, retouch:float, renoise_strength:float, renoise_kernel:int):
def init_noise_inverse(self, steps:int, retouch:float, get_cache_callback, set_cache_callback, renoise_strength:float, renoise_kernel:int):
self.noise_inverse_enabled = True
self.noise_inverse_steps = steps
self.noise_inverse_retouch = float(retouch)
@ -507,6 +508,8 @@ class TiledDiffusion:
if self.sample_img2img_original is None:
self.sample_img2img_original = self.sampler_raw.sample_img2img
self.sampler_raw.sample_img2img = MethodType(self.sample_img2img, self.sampler_raw)
self.noise_inverse_set_cache = set_cache_callback
self.noise_inverse_get_cache = get_cache_callback
@noise_inverse
@keep_signature
@ -527,16 +530,29 @@ class TiledDiffusion:
renoise_mask = torch.clamp(renoise_mask, 0, 1)
prompts = p.all_prompts[:p.batch_size]
if hasattr(p, 'noise_inverse_latent'):
# in batch mode, use the same noise latent for all images
latent = p.noise_inverse_latent.to(noise.device)
else:
latent = None
# try to use cached latent to save huge amount of time.
cached_latent: NoiseInverseCache = self.noise_inverse_get_cache()
if cached_latent is not None and \
cached_latent.model_hash == p.sd_model.sd_model_hash and \
cached_latent.noise_inversion_steps == self.noise_inverse_steps and \
len(cached_latent.prompts) == len(prompts) and \
all([cached_latent.prompts[i] == prompts[i] for i in range(len(prompts))]) and \
abs(cached_latent.retouch - self.noise_inverse_retouch) < 0.01 and \
cached_latent.x0.shape == p.init_latent.shape and \
torch.abs(cached_latent.x0.to(p.init_latent.device) - p.init_latent).sum() < 100: # the 100 is an arbitrary threshold copy-pasted from the img2img alt code
# use cached noise
print('[Tiled Diffusion] Your checkpoint, image, prompts, inverse steps, and retouch params are all unchanged.')
print('[Tiled Diffusion] Noise Inversion will use the cached noise from the previous run. To clear the cache, click the Free GPU button.')
latent = cached_latent.xt.to(noise.device)
if latent is None:
# run noise inversion
shared.state.job_count += 1
latent = self.find_noise_for_image_sigma_adjustment(sampler.model_wrap, self.noise_inverse_steps, prompts)
shared.state.nextjob()
p.noise_inverse_latent = latent.to(device=devices.cpu) if cmd_opts.lowvram else latent
self.noise_inverse_set_cache(p.init_latent.clone().cpu(), latent.clone().cpu(), prompts)
# The cache is only 1 latent image and is very small (16 MB for 8192 * 8192 image), so we don't need to worry about memory leakage.
# calculate sampling steps
adjusted_steps, _ = sd_samplers_common.setup_img2img_steps(p, steps)

View File

@ -29,8 +29,8 @@ class BlendMode(Enum): # i.e. LayerType
FOREGROUND = 'Foreground'
BACKGROUND = 'Background'
# One row of per-region (bounding box) settings.
BBoxSettings = namedtuple('BBoxSettings', ['enable', 'x', 'y', 'w', 'h', 'prompt', 'neg_prompt', 'blend_mode', 'feather_ratio', 'seed'])
# Result of one noise-inversion run plus the parameters it was computed with.
# The typename must match the name it is bound to: otherwise repr() shows a
# different name and pickle cannot resolve the class by its qualified name.
NoiseInverseCache = namedtuple('NoiseInverseCache', ['model_hash', 'x0', 'xt', 'noise_inversion_steps', 'retouch', 'prompts'])
# Default values for one BBoxSettings row, spelled out field by field.
DEFAULT_BBOX_SETTINGS = BBoxSettings(
    enable=False,
    x=0.4,
    y=0.4,
    w=0.2,
    h=0.2,
    prompt='',
    neg_prompt='',
    blend_mode=BlendMode.BACKGROUND.value,
    feather_ratio=0.2,
    seed=-1,
)
# Number of fields in one BBoxSettings row.
NUM_BBOX_PARAMS = len(BBoxSettings._fields)