diff --git a/scripts/tilediffusion.py b/scripts/tilediffusion.py index 40dd689..d23d67c 100644 --- a/scripts/tilediffusion.py +++ b/scripts/tilediffusion.py @@ -86,6 +86,7 @@ class Script(scripts.Script): def __init__(self): self.controlnet_script: ModuleType = None self.delegate: TiledDiffusion = None + self.noise_inverse_cache: NoiseInverseCache = None def title(self): return "Tiled Diffusion" @@ -440,7 +441,9 @@ class Script(scripts.Script): # setup **optional** supports through `init_*`, make everything relatively pluggable!! if flag_noise_inverse: - delegate.init_noise_inverse(noise_inverse_steps, noise_inverse_retouch, noise_inverse_renoise_strength, noise_inverse_renoise_kernel) + get_cache_callback = self.noise_inverse_get_cache + set_cache_callback = lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, noise_inverse_steps, noise_inverse_retouch) + delegate.init_noise_inverse(noise_inverse_steps, noise_inverse_retouch, get_cache_callback, set_cache_callback, noise_inverse_renoise_strength, noise_inverse_renoise_kernel) if not enable_bbox_control or draw_background: delegate.init_grid_bbox(tile_width, tile_height, overlap, tile_batch_size) if enable_bbox_control: @@ -557,6 +560,12 @@ class Script(scripts.Script): data_list.extend(DEFAULT_BBOX_SETTINGS) return [gr_value(v) for v in data_list] + [gr_value(f'Config loaded from {fp}.', visible=True)] + + def noise_inverse_set_cache(self, p: ProcessingImg2Img, x0: Tensor, xt: Tensor, prompts: List[str], steps: int, retouch:float): + self.noise_inverse_cache = NoiseInverseCache(p.sd_model.sd_model_hash, x0, xt, steps, retouch, prompts) + + def noise_inverse_get_cache(self): + return self.noise_inverse_cache def reset(self): ''' unhijack inner APIs ''' @@ -572,6 +581,7 @@ class Script(scripts.Script): def reset_and_gc(self): self.reset() + self.noise_inverse_cache = None import gc; gc.collect() devices.torch_gc() diff --git a/tile_methods/abstractdiffusion.py b/tile_methods/abstractdiffusion.py 
index 9f942d5..c95940b 100644 --- a/tile_methods/abstractdiffusion.py +++ b/tile_methods/abstractdiffusion.py @@ -15,7 +15,6 @@ from modules.processing import opt_f from tile_utils.utils import * from tile_utils.typing import * - class TiledDiffusion: def __init__(self, p:Processing, sampler:Sampler): @@ -64,12 +63,14 @@ class TiledDiffusion: self.causal_layers: bool = None # ext. Noise Inversion (noise inversion) + self.noise_inverse_enabled: bool = False + self.noise_inverse_steps: int = 0 + self.noise_inverse_retouch: float = None + self.noise_inverse_renoise_strength: float = None + self.noise_inverse_renoise_kernel: int = None + self.noise_inverse_get_cache = None + self.noise_inverse_set_cache = None self.sample_img2img_original = None - self.noise_inverse_enabled = None - self.noise_inverse_steps = None - self.noise_inverse_retouch = None - self.noise_inverse_renoise_strength = None - self.noise_inverse_renoise_kernel = None # ext. ControlNet self.enable_controlnet: bool = False @@ -498,7 +499,7 @@ class TiledDiffusion: @noise_inverse - def init_noise_inverse(self, steps:int, retouch:float, renoise_strength:float, renoise_kernel:int): + def init_noise_inverse(self, steps:int, retouch:float, get_cache_callback, set_cache_callback, renoise_strength:float, renoise_kernel:int): self.noise_inverse_enabled = True self.noise_inverse_steps = steps self.noise_inverse_retouch = float(retouch) @@ -507,6 +508,8 @@ class TiledDiffusion: if self.sample_img2img_original is None: self.sample_img2img_original = self.sampler_raw.sample_img2img self.sampler_raw.sample_img2img = MethodType(self.sample_img2img, self.sampler_raw) + self.noise_inverse_set_cache = set_cache_callback + self.noise_inverse_get_cache = get_cache_callback @noise_inverse @keep_signature @@ -527,16 +530,29 @@ class TiledDiffusion: renoise_mask = torch.clamp(renoise_mask, 0, 1) prompts = p.all_prompts[:p.batch_size] - - if hasattr(p, 'noise_inverse_latent'): - # in batch mode, use the same noise latent for 
all images - latent = p.noise_inverse_latent.to(noise.device) - else: + + latent = None + # try to use the cached latent to save a huge amount of time. + cached_latent: NoiseInverseCache = self.noise_inverse_get_cache() + if cached_latent is not None and \ + cached_latent.model_hash == p.sd_model.sd_model_hash and \ + cached_latent.noise_inversion_steps == self.noise_inverse_steps and \ + len(cached_latent.prompts) == len(prompts) and \ + all([cached_latent.prompts[i] == prompts[i] for i in range(len(prompts))]) and \ + abs(cached_latent.retouch - self.noise_inverse_retouch) < 0.01 and \ + cached_latent.x0.shape == p.init_latent.shape and \ + torch.abs(cached_latent.x0.to(p.init_latent.device) - p.init_latent).sum() < 100: # the 100 is an arbitrary threshold copy-pasted from the img2img alt code + # use cached noise + print('[Tiled Diffusion] Your checkpoint, image, prompts, inverse steps, and retouch params are all unchanged.') + print('[Tiled Diffusion] Noise Inversion will use the cached noise from the previous run. To clear the cache, click the Free GPU button.') + latent = cached_latent.xt.to(noise.device) + if latent is None: # run noise inversion shared.state.job_count += 1 latent = self.find_noise_for_image_sigma_adjustment(sampler.model_wrap, self.noise_inverse_steps, prompts) shared.state.nextjob() - p.noise_inverse_latent = latent.to(device=devices.cpu) if cmd_opts.lowvram else latent + self.noise_inverse_set_cache(p.init_latent.clone().cpu(), latent.clone().cpu(), prompts) + # The cache holds only two latents (x0 and xt), both moved to the CPU and each very small (about 16 MB for an 8192 * 8192 image), so we don't need to worry about memory usage. # calculate sampling steps adjusted_steps, _ = sd_samplers_common.setup_img2img_steps(p, steps) diff --git a/tile_utils/utils.py b/tile_utils/utils.py index bef0912..c72df12 100644 --- a/tile_utils/utils.py +++ b/tile_utils/utils.py @@ -29,8 +29,8 @@ class BlendMode(Enum): # i.e. 
LayerType FOREGROUND = 'Foreground' BACKGROUND = 'Background' - BBoxSettings = namedtuple('BBoxSettings', ['enable', 'x', 'y', 'w', 'h', 'prompt', 'neg_prompt', 'blend_mode', 'feather_ratio', 'seed']) +NoiseInverseCache = namedtuple('NoiseInverseCache', ['model_hash', 'x0', 'xt', 'noise_inversion_steps', 'retouch', 'prompts']) DEFAULT_BBOX_SETTINGS = BBoxSettings(False, 0.4, 0.4, 0.2, 0.2, '', '', BlendMode.BACKGROUND.value, 0.2, -1) NUM_BBOX_PARAMS = len(BBoxSettings._fields)