add noise inverse cache
parent
5f22b8fc27
commit
23f3a14432
|
|
@ -86,6 +86,7 @@ class Script(scripts.Script):
|
|||
def __init__(self):
|
||||
self.controlnet_script: ModuleType = None
|
||||
self.delegate: TiledDiffusion = None
|
||||
self.noise_inverse_cache: NoiseInverseCache = None
|
||||
|
||||
def title(self):
    """Return the title string of this script."""
    script_title = "Tiled Diffusion"
    return script_title
|
||||
|
|
@ -440,7 +441,9 @@ class Script(scripts.Script):
|
|||
|
||||
# setup **optional** supports through `init_*`, make everything relatively pluggable!!
|
||||
if flag_noise_inverse:
|
||||
delegate.init_noise_inverse(noise_inverse_steps, noise_inverse_retouch, noise_inverse_renoise_strength, noise_inverse_renoise_kernel)
|
||||
get_cache_callback = self.noise_inverse_get_cache
|
||||
set_cache_callback = lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, noise_inverse_steps, noise_inverse_retouch)
|
||||
delegate.init_noise_inverse(noise_inverse_steps, noise_inverse_retouch, get_cache_callback, set_cache_callback, noise_inverse_renoise_strength, noise_inverse_renoise_kernel)
|
||||
if not enable_bbox_control or draw_background:
|
||||
delegate.init_grid_bbox(tile_width, tile_height, overlap, tile_batch_size)
|
||||
if enable_bbox_control:
|
||||
|
|
@ -557,6 +560,12 @@ class Script(scripts.Script):
|
|||
data_list.extend(DEFAULT_BBOX_SETTINGS)
|
||||
|
||||
return [gr_value(v) for v in data_list] + [gr_value(f'Config loaded from {fp}.', visible=True)]
|
||||
|
||||
def noise_inverse_set_cache(self, p: ProcessingImg2Img, x0: Tensor, xt: Tensor, prompts: List[str], steps: int, retouch:float):
    """Remember the result of a noise-inversion run on this script instance.

    Builds a NoiseInverseCache record and stores it in
    ``self.noise_inverse_cache``, replacing any previous entry.

    p:       processing object; only ``p.sd_model.sd_model_hash`` is read.
    x0:      value stored in the record's ``x0`` field.
    xt:      value stored in the record's ``xt`` field.
    prompts: prompts stored in the record.
    steps:   stored as ``noise_inversion_steps``.
    retouch: stored as ``retouch``.
    """
    self.noise_inverse_cache = NoiseInverseCache(
        model_hash=p.sd_model.sd_model_hash,
        x0=x0,
        xt=xt,
        noise_inversion_steps=steps,
        retouch=retouch,
        prompts=prompts,
    )
|
||||
|
||||
def noise_inverse_get_cache(self):
    """Return the stored noise-inversion cache entry.

    The value is whatever ``self.noise_inverse_cache`` currently holds,
    which is ``None`` when nothing has been cached.
    """
    cached_entry = self.noise_inverse_cache
    return cached_entry
|
||||
|
||||
def reset(self):
|
||||
''' unhijack inner APIs '''
|
||||
|
|
@ -572,6 +581,7 @@ class Script(scripts.Script):
|
|||
|
||||
def reset_and_gc(self):
|
||||
self.reset()
|
||||
self.noise_inverse_cache = None
|
||||
|
||||
import gc; gc.collect()
|
||||
devices.torch_gc()
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from modules.processing import opt_f
|
|||
from tile_utils.utils import *
|
||||
from tile_utils.typing import *
|
||||
|
||||
|
||||
class TiledDiffusion:
|
||||
|
||||
def __init__(self, p:Processing, sampler:Sampler):
|
||||
|
|
@ -64,12 +63,14 @@ class TiledDiffusion:
|
|||
self.causal_layers: bool = None
|
||||
|
||||
# ext. Noise Inversion (noise inversion)
|
||||
self.noise_inverse_enabled: bool = False
|
||||
self.noise_inverse_steps: int = 0
|
||||
self.noise_inverse_retouch: float = None
|
||||
self.noise_inverse_renoise_strength: float = None
|
||||
self.noise_inverse_renoise_kernel: int = None
|
||||
self.noise_inverse_get_cache = None
|
||||
self.noise_inverse_set_cache = None
|
||||
self.sample_img2img_original = None
|
||||
self.noise_inverse_enabled = None
|
||||
self.noise_inverse_steps = None
|
||||
self.noise_inverse_retouch = None
|
||||
self.noise_inverse_renoise_strength = None
|
||||
self.noise_inverse_renoise_kernel = None
|
||||
|
||||
# ext. ControlNet
|
||||
self.enable_controlnet: bool = False
|
||||
|
|
@ -498,7 +499,7 @@ class TiledDiffusion:
|
|||
|
||||
|
||||
@noise_inverse
|
||||
def init_noise_inverse(self, steps:int, retouch:float, renoise_strength:float, renoise_kernel:int):
|
||||
def init_noise_inverse(self, steps:int, retouch:float, get_cache_callback, set_cache_callback, renoise_strength:float, renoise_kernel:int):
|
||||
self.noise_inverse_enabled = True
|
||||
self.noise_inverse_steps = steps
|
||||
self.noise_inverse_retouch = float(retouch)
|
||||
|
|
@ -507,6 +508,8 @@ class TiledDiffusion:
|
|||
if self.sample_img2img_original is None:
|
||||
self.sample_img2img_original = self.sampler_raw.sample_img2img
|
||||
self.sampler_raw.sample_img2img = MethodType(self.sample_img2img, self.sampler_raw)
|
||||
self.noise_inverse_set_cache = set_cache_callback
|
||||
self.noise_inverse_get_cache = get_cache_callback
|
||||
|
||||
@noise_inverse
|
||||
@keep_signature
|
||||
|
|
@ -527,16 +530,29 @@ class TiledDiffusion:
|
|||
renoise_mask = torch.clamp(renoise_mask, 0, 1)
|
||||
|
||||
prompts = p.all_prompts[:p.batch_size]
|
||||
|
||||
if hasattr(p, 'noise_inverse_latent'):
|
||||
# in batch mode, use the same noise latent for all images
|
||||
latent = p.noise_inverse_latent.to(noise.device)
|
||||
else:
|
||||
|
||||
latent = None
|
||||
# try to use cached latent to save huge amount of time.
|
||||
cached_latent: NoiseInverseCache = self.noise_inverse_get_cache()
|
||||
if cached_latent is not None and \
|
||||
cached_latent.model_hash == p.sd_model.sd_model_hash and \
|
||||
cached_latent.noise_inversion_steps == self.noise_inverse_steps and \
|
||||
len(cached_latent.prompts) == len(prompts) and \
|
||||
all([cached_latent.prompts[i] == prompts[i] for i in range(len(prompts))]) and \
|
||||
abs(cached_latent.retouch - self.noise_inverse_retouch) < 0.01 and \
|
||||
cached_latent.x0.shape == p.init_latent.shape and \
|
||||
torch.abs(cached_latent.x0.to(p.init_latent.device) - p.init_latent).sum() < 100: # the 100 is an arbitrary threshold copy-pasted from the img2img alt code
|
||||
# use cached noise
|
||||
print('[Tiled Diffusion] Your checkpoint, image, prompts, inverse steps, and retouch params are all unchanged.')
|
||||
print('[Tiled Diffusion] Noise Inversion will use the cached noise from the previous run. To clear the cache, click the Free GPU button.')
|
||||
latent = cached_latent.xt.to(noise.device)
|
||||
if latent is None:
|
||||
# run noise inversion
|
||||
shared.state.job_count += 1
|
||||
latent = self.find_noise_for_image_sigma_adjustment(sampler.model_wrap, self.noise_inverse_steps, prompts)
|
||||
shared.state.nextjob()
|
||||
p.noise_inverse_latent = latent.to(device=devices.cpu) if cmd_opts.lowvram else latent
|
||||
self.noise_inverse_set_cache(p.init_latent.clone().cpu(), latent.clone().cpu(), prompts)
|
||||
# The cache is only 1 latent image and is very small (16 MB for 8192 * 8192 image), so we don't need to worry about memory leakage.
|
||||
|
||||
# calculate sampling steps
|
||||
adjusted_steps, _ = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ class BlendMode(Enum): # i.e. LayerType
|
|||
FOREGROUND = 'Foreground'
|
||||
BACKGROUND = 'Background'
|
||||
|
||||
|
||||
# Settings record for a single bounding-box region: enable flag, position (x, y),
# size (w, h), positive/negative prompt, blend mode, feather ratio and seed.
BBoxSettings = namedtuple('BBoxSettings', ['enable', 'x', 'y', 'w', 'h', 'prompt', 'neg_prompt', 'blend_mode', 'feather_ratio', 'seed'])
|
||||
# Record of one noise-inversion run: the checkpoint hash, the source latent (x0),
# the resulting noise latent (xt), and the steps / retouch / prompts that produced it.
# The typename matches the bound name so that repr() shows the importable name and
# pickle can resolve the class by attribute lookup on this module.
NoiseInverseCache = namedtuple('NoiseInverseCache', ['model_hash', 'x0', 'xt', 'noise_inversion_steps', 'retouch', 'prompts'])
|
||||
# Default values for one region: disabled, x = y = 0.4, w = h = 0.2, empty prompts,
# background blend mode, feather ratio 0.2, seed -1.
DEFAULT_BBOX_SETTINGS = BBoxSettings(False, 0.4, 0.4, 0.2, 0.2, '', '', BlendMode.BACKGROUND.value, 0.2, -1)
|
||||
# Number of fields in a BBoxSettings record.
NUM_BBOX_PARAMS = len(BBoxSettings._fields)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue