diff --git a/README.md b/README.md index 450d219..e4db6af 100644 --- a/README.md +++ b/README.md @@ -11,14 +11,14 @@ Please be aware that the License of this repo has changed to prevent some web sh **自许可证修改之日(AOE 2023.3.28)起,之后的版本禁止用于商业贩售 (不可贩售本仓库代码,但衍生的艺术创作内容物不受此限制)。** If you like the project, please give me a star! ⭐ - + [](https://ko-fi.com/pkuliyi2015) **** The extension enables **large image drawing & upscaling with limited VRAM** via the following techniques: -1. Two SOTA diffusion tiling algorithms: [Mixture of Diffusers](https://github.com/albarji/mixture-of-diffusers) and [MultiDiffusion](https://multidiffusion.github.io) +1. Two SOTA diffusion tiling algorithms: [Mixture of Diffusers](https://github.com/albarji/mixture-of-diffusers) and [MultiDiffusion](https://multidiffusion.github.io), add [Demofusion](https://github.com/PRIS-CV/DemoFusion) 2. My original Tiled VAE algorithm. 3. My original TIled Noise Inversion for better upscaling. diff --git a/README_CN.md b/README_CN.md index 6633690..7a807db 100644 --- a/README_CN.md +++ b/README_CN.md @@ -15,7 +15,7 @@ 本插件通过以下三种技术实现了 **在有限的显存中进行大型图像绘制**: -1. 两种 SOTA diffusion tiling 算法:[Mixture of Diffusers](https://github.com/albarji/mixture-of-diffusers) 和 [MultiDiffusion](https://multidiffusion.github.io) +1. SOTA diffusion tiling 算法:[Mixture of Diffusers](https://github.com/albarji/mixture-of-diffusers) 和 [MultiDiffusion](https://multidiffusion.github.io),新增[Demofusion](https://github.com/PRIS-CV/DemoFusion) 2. 原创的 Tiled VAE 算法。 3. 
原创混合放大算法生成超高清图像 diff --git a/scripts/tileglobal.py b/scripts/tileglobal.py new file mode 100644 index 0000000..1e6ce96 --- /dev/null +++ b/scripts/tileglobal.py @@ -0,0 +1,508 @@ +import os +import json +import torch +import torch.nn.functional as F +import numpy as np +import gradio as gr + +from modules import sd_samplers, images, shared, devices, processing, scripts, sd_samplers_common, rng +from modules.shared import opts +from modules.processing import opt_f, get_fixed_seed +from modules.ui import gr_show + +from tile_methods.abstractdiffusion import AbstractDiffusion +from tile_methods.demofusion import DemoFusion +from tile_utils.utils import * + + +CFG_PATH = os.path.join(scripts.basedir(), 'region_configs') +BBOX_MAX_NUM = min(getattr(shared.cmd_opts, 'md_max_regions', 8), 16) + + + +class Script(scripts.Script): + def __init__(self): + self.controlnet_script: ModuleType = None + self.stablesr_script: ModuleType = None + self.delegate: AbstractDiffusion = None + self.noise_inverse_cache: NoiseInverseCache = None + + def title(self): + return 'demofusion' + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + tab = 't2i' if not is_img2img else 'i2i' + is_t2i = 'true' if not is_img2img else 'false' + uid = lambda name: f'MD-{tab}-{name}' + + with gr.Accordion('DemoFusion', open=False, elem_id=f'MD-{tab}'): + with gr.Row(variant='compact') as tab_enable: + enabled = gr.Checkbox(label='Enable DemoFusion(Do not open it with tilediffusion)', value=False, elem_id=uid('enabled')) + # overwrite_size = gr.Checkbox(label='Overwrite image size', value=False, visible=not is_img2img, elem_id=uid('overwrite-image-size')) + keep_input_size = gr.Checkbox(label='Keep input image size', value=True, visible=is_img2img, elem_id=uid('keep-input-size')) + random_jitter = gr.Checkbox(label='Random jitter windows', value=True, elem_id=uid('random-jitter')) + + # with gr.Row(variant='compact', visible=False) as tab_size: + # image_width = 
gr.Slider(minimum=256, maximum=16384, step=16, label='Image width', value=1024, elem_id=f'MD-overwrite-width-{tab}') + # image_height = gr.Slider(minimum=256, maximum=16384, step=16, label='Image height', value=1024, elem_id=f'MD-overwrite-height-{tab}') + # overwrite_size.change(fn=gr_show, inputs=overwrite_size, outputs=tab_size, show_progress=False) + + # with gr.Row(variant='compact', visible=True) as tab_size: + # c1 = gr.Slider(minimum=0.5, maximum=3, step=0.1, label='c1', value=3, elem_id=f'c1-{tab}') + # c2 = gr.Slider(minimum=0.5, maximum=3, step=0.1, label='c2', value=1, elem_id=f'c2-{tab}') + # c3 = gr.Slider(minimum=0.5, maximum=3, step=0.1, label='c3', value=1, elem_id=f'c3-{tab}') + + with gr.Row(variant='compact') as tab_param: + method = gr.Dropdown(label='Method', choices=[Method_2.DEMO_FU.value], value=Method_2.DEMO_FU.value, elem_id=uid('method-2')) + control_tensor_cpu = gr.Checkbox(label='Move ControlNet tensor to CPU (if applicable)', value=False, elem_id=uid('control-tensor-cpu-2')) + reset_status = gr.Button(value='Free GPU', variant='tool') + reset_status.click(fn=self.reset_and_gc, show_progress=False) + + with gr.Group() as tab_tile: + with gr.Row(variant='compact'): + window_size = gr.Slider(minimum=16, maximum=256, step=16, label='Latent window size', value=128, elem_id=uid('latent-window-size')) + # tile_height = gr.Slider(minimum=16, maximum=256, step=16, label='Latent tile height', value=96, elem_id=uid('latent-tile-height')) + + with gr.Row(variant='compact'): + overlap = gr.Slider(minimum=0, maximum=256, step=4, label='Latent window overlap', value=64, elem_id=uid('latent-tile-overlap-2')) + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Latent window batch size', value=4, elem_id=uid('latent-tile-batch-size-2')) + + with gr.Row(variant='compact', visible=True) as tab_size: + c1 = gr.Slider(minimum=0.5, maximum=3, step=0.1, label='c1', value=3, elem_id=f'c1-{tab}') + c2 = gr.Slider(minimum=0.5, maximum=3, step=0.1, 
label='c2', value=1, elem_id=f'c2-{tab}') + c3 = gr.Slider(minimum=0.5, maximum=3, step=0.1, label='c3', value=1, visible=False, elem_id=f'c3-{tab}') #XXX:this parameter is useless in current version + + with gr.Row(variant='compact') as tab_upscale: + # upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value='None', elem_id=uid('upscaler-index')) + scale_factor = gr.Slider(minimum=1.0, maximum=8.0, step=1, label='Scale_Factor', value=2.0, elem_id=uid('upscaler-factor-2')) + # scale_factor = gr.Slider(minimum=1.0, maximum=8.0, step=1, label='Overwrite Scale Factor', value=2.0,value=is_img2img, elem_id=uid('upscaler-factor')) + + with gr.Accordion('Noise Inversion', open=True, visible=is_img2img) as tab_noise_inv: + with gr.Row(variant='compact'): + noise_inverse = gr.Checkbox(label='Enable Noise Inversion', value=False, elem_id=uid('noise-inverse-2')) + noise_inverse_steps = gr.Slider(minimum=1, maximum=200, step=1, label='Inversion steps', value=10, elem_id=uid('noise-inverse-steps-2')) + gr.HTML('
Please test on small images before actual upscale. Default params require denoise <= 0.6
') + with gr.Row(variant='compact'): + noise_inverse_retouch = gr.Slider(minimum=1, maximum=100, step=0.1, label='Retouch', value=1, elem_id=uid('noise-inverse-retouch-2')) + noise_inverse_renoise_strength = gr.Slider(minimum=0, maximum=2, step=0.01, label='Renoise strength', value=1, elem_id=uid('noise-inverse-renoise-strength-2')) + noise_inverse_renoise_kernel = gr.Slider(minimum=2, maximum=512, step=1, label='Renoise kernel size', value=64, elem_id=uid('noise-inverse-renoise-kernel-2')) + + # The control includes txt2img and img2img, we use t2i and i2i to distinguish them + + return [ + enabled, method, + keep_input_size, + window_size, overlap, batch_size, + scale_factor, + noise_inverse, noise_inverse_steps, noise_inverse_retouch, noise_inverse_renoise_strength, noise_inverse_renoise_kernel, + control_tensor_cpu, + random_jitter, + c1,c2,c3 + ] + + + def process(self, p: Processing, + enabled: bool, method: str, + keep_input_size: bool, + window_size:int, overlap: int, tile_batch_size: int, + scale_factor: float, + noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch: float, noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int, + control_tensor_cpu: bool, + random_jitter:bool, + c1,c2,c3 + ): + + # unhijack & unhook, in case it broke at last time + self.reset() + + if not enabled: return + + ''' upscale ''' + # store canvas size settings + if hasattr(p, "init_images"): + p.init_images_original_md = [img.copy() for img in p.init_images] + + p.width_original_md = p.width + p.height_original_md = p.height + p.current_scale_num = 1 + p.scale_factor = int(scale_factor) + + is_img2img = hasattr(p, "init_images") and len(p.init_images) > 0 + if is_img2img: + init_img = p.init_images[0] + init_img = images.flatten(init_img, opts.img2img_background_color) + image = init_img + if keep_input_size: #若 scale factor为1则为真 + p.scale_factor = 1 + p.width = image.width + p.height = image.height + else: #XXX:To adapt to noise inversion, we do 
not multiply the scale factor here + p.width = p.width_original_md + p.height = p.height_original_md + else: # txt2img + p.width = p.width*(p.scale_factor) + p.height = p.height*(p.scale_factor) + + if 'png info': + info = {} + p.extra_generation_params["Tiled Diffusion"] = info + + info['Method'] = method + info['Window Size'] = window_size + info['Tile Overlap'] = overlap + info['Tile batch size'] = tile_batch_size + + if is_img2img: + info['Upscale factor'] = scale_factor + if keep_input_size: + info['Keep input size'] = keep_input_size + if noise_inverse: + info['NoiseInv'] = noise_inverse + info['NoiseInv Steps'] = noise_inverse_steps + info['NoiseInv Retouch'] = noise_inverse_retouch + info['NoiseInv Renoise strength'] = noise_inverse_renoise_strength + info['NoiseInv Kernel size'] = noise_inverse_renoise_kernel + + ''' ControlNet hackin ''' + try: + from scripts.cldm import ControlNet + + for script in p.scripts.scripts + p.scripts.alwayson_scripts: + if hasattr(script, "latest_network") and script.title().lower() == "controlnet": + self.controlnet_script = script + print("[Demo Fusion] ControlNet found, support is enabled.") + break + except ImportError: + pass + + ''' StableSR hackin ''' + for script in p.scripts.scripts: + if hasattr(script, "stablesr_model") and script.title().lower() == "stablesr": + if script.stablesr_model is not None: + self.stablesr_script = script + print("[Demo Fusion] StableSR found, support is enabled.") + break + + ''' hijack inner APIs, see unhijack in reset() ''' + Script.create_sampler_original_md = sd_samplers.create_sampler + + sd_samplers.create_sampler = lambda name, model: self.create_sampler_hijack( + name, model, p, Method_2(method), control_tensor_cpu,window_size, noise_inverse, noise_inverse_steps, noise_inverse_retouch, + noise_inverse_renoise_strength, noise_inverse_renoise_kernel, overlap, tile_batch_size,random_jitter + ) + + + p.sample = lambda conditioning, unconditional_conditioning,seeds, subseeds, 
subseed_strength, prompts: self.sample_hijack( + conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts,p, is_img2img, + window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3) + ## end + + + def postprocess_batch(self, p: Processing, enabled, *args, **kwargs): + if not enabled: return + + if self.delegate is not None: self.delegate.reset_controlnet_tensors() + + def postprocess(self, p: Processing, processed, enabled, *args): + if not enabled: return + # unhijack & unhook + self.reset() + + # restore canvas size settings + if hasattr(p, 'init_images') and hasattr(p, 'init_images_original_md'): + p.init_images.clear() # NOTE: do NOT change the list object, compatible with shallow copy of XYZ-plot + p.init_images.extend(p.init_images_original_md) + del p.init_images_original_md + p.width = p.width_original_md ; del p.width_original_md + p.height = p.height_original_md ; del p.height_original_md + + # clean up noise inverse latent for folder-based processing + if hasattr(p, 'noise_inverse_latent'): + del p.noise_inverse_latent + + ''' ↓↓↓ inner API hijack ↓↓↓ ''' + @torch.no_grad() + def sample_hijack(self, conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts,p,image_ori,window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3): + + if self.delegate==None: + p.denoising_strength=1 + # p.sampler = Script.create_sampler_original_md(p.sampler_name, p.sd_model) + p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) #NOTE:Wrong but very useful. If corrected, please replace with the content from the previous line + # 3. 
Encode input prompts + shared.state.sampling_step = 0 + noise = p.rng.next() + + if hasattr(p,'initial_noise_multiplier'): + if p.initial_noise_multiplier != 1.0: + p.extra_generation_params["Noise multiplier"] = p.initial_noise_multiplier + noise *= p.initial_noise_multiplier + + ################################################## Phase Initialization ###################################################### + + if not image_ori: + latents = p.rng.next() #Same with line 233. Replaced with the following lines + # latents = p.sampler.sample(p, x, conditioning, unconditional_conditioning, image_conditioning=p.txt2img_image_conditioning(x)) + # del x + # p.denoising_strength=1 + # p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) + else: # img2img + print("### Encoding Real Image ###") + latents = p.init_latent + + + anchor_mean = latents.mean() + anchor_std = latents.std() + + devices.torch_gc() + + ####################################################### Phase Upscaling ##################################################### + starting_scale = 1 + p.cosine_scale_1 = c1 # 3 + p.cosine_scale_2 = c2 # 1 + p.cosine_scale_3 = c3 # 1 + p.latents = latents + for current_scale_num in range(starting_scale, p.scale_factor+1): + p.current_scale_num = current_scale_num + print("### Phase {} Denoising ###".format(current_scale_num)) + p.current_height = p.height_original_md * current_scale_num + p.current_width = p.width_original_md * current_scale_num + + + p.latents = F.interpolate(p.latents, size=(int(p.current_height / opt_f), int(p.current_width / opt_f)), mode='bicubic') + p.rng = rng.ImageRNG(p.latents.shape[1:], p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) + + + self.delegate.w = int(p.current_width / opt_f) + self.delegate.h = int(p.current_height / opt_f) + if current_scale_num>1: + self.delegate.get_views(overlap, tile_batch_size) + + info = ', '.join([ + 
# f"{method.value} hooked into {name!r} sampler", + f"Tile size: {window_size}", + f"Tile count: {self.delegate.num_tiles}", + f"Batch size: {self.delegate.tile_bs}", + f"Tile batches: {len(self.delegate.batched_bboxes)}", + ]) + + print(info) + + noise = p.rng.next() + if hasattr(p,'initial_noise_multiplier'): + if p.initial_noise_multiplier != 1.0: + p.extra_generation_params["Noise multiplier"] = p.initial_noise_multiplier + noise *= p.initial_noise_multiplier + else: + p.image_conditioning = p.txt2img_image_conditioning(noise) + + p.noise = noise + p.x = p.latents.clone() + p.current_step=-1 + + p.latents = p.sampler.sample_img2img(p,p.latents, noise , conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) + if self.flag_noise_inverse: + self.delegate.sampler_raw.sample_img2img = self.delegate.sample_img2img_original + self.flag_noise_inverse = False + + p.latents = (p.latents - p.latents.mean()) / p.latents.std() * anchor_std + anchor_mean + ######################################################################################################################################### + p.width = p.width*p.scale_factor + p.height = p.height*p.scale_factor + return p.latents + + + def create_sampler_hijack( + self, name: str, model: LatentDiffusion, p: Processing, method: Method_2, control_tensor_cpu:bool,window_size, noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch:float, + noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int, overlap:int, tile_batch_size:int, random_jitter:bool + ): + if self.delegate is not None: + # samplers are stateless, we reuse it if possible + if self.delegate.sampler_name == name: + # before we reuse the sampler, we refresh the control tensor + # so that we are compatible with ControlNet batch processing + if self.controlnet_script: + self.delegate.prepare_controlnet_tensors(refresh=True) + return self.delegate.sampler_raw + else: + self.reset() + + self.flag_noise_inverse = 
hasattr(p, "init_images") and len(p.init_images) > 0 and noise_inverse + flag_noise_inverse = self.flag_noise_inverse + if flag_noise_inverse: + print('warn: noise inversion only supports the "Euler" sampler, switch to it sliently...') + name = 'Euler' + p.sampler_name = 'Euler' + if name is None: print('>> name is empty') + if model is None: print('>> model is empty') + sampler = Script.create_sampler_original_md(name, model) + if method ==Method_2.DEMO_FU: delegate_cls = DemoFusion + else: raise NotImplementedError(f"Method {method} not implemented.") + + delegate = delegate_cls(p, sampler) + delegate.window_size = window_size + p.random_jitter = random_jitter + + if flag_noise_inverse: + get_cache_callback = self.noise_inverse_get_cache + set_cache_callback = lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, noise_inverse_steps, noise_inverse_retouch) + delegate.init_noise_inverse(noise_inverse_steps, noise_inverse_retouch, get_cache_callback, set_cache_callback, noise_inverse_renoise_strength, noise_inverse_renoise_kernel) + + delegate.get_views(overlap,tile_batch_size) + if self.controlnet_script: + delegate.init_controlnet(self.controlnet_script, control_tensor_cpu) + if self.stablesr_script: + delegate.init_stablesr(self.stablesr_script) + + # init everything done, perform sanity check & pre-computations + # hijack the behaviours + delegate.hook() + + self.delegate = delegate + + info = ', '.join([ + f"{method.value} hooked into {name!r} sampler", + f"Tile size: {window_size}", + f"Tile count: {delegate.num_tiles}", + f"Batch size: {delegate.tile_bs}", + f"Tile batches: {len(delegate.batched_bboxes)}", + ]) + exts = [ + "ContrlNet" if self.controlnet_script else None, + "StableSR" if self.stablesr_script else None, + ] + ext_info = ', '.join([e for e in exts if e]) + if ext_info: ext_info = f' (ext: {ext_info})' + print(info + ext_info) + + return delegate.sampler_raw + + def create_random_tensors_hijack( + self, bbox_settings: Dict, 
region_info: Dict, + shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None, + ): + org_random_tensors = Script.create_random_tensors_original_md(shape, seeds, subseeds, subseed_strength, seed_resize_from_h, seed_resize_from_w, p) + height, width = shape[1], shape[2] + background_noise = torch.zeros_like(org_random_tensors) + background_noise_count = torch.zeros((1, 1, height, width), device=org_random_tensors.device) + foreground_noise = torch.zeros_like(org_random_tensors) + foreground_noise_count = torch.zeros((1, 1, height, width), device=org_random_tensors.device) + + for i, v in bbox_settings.items(): + seed = get_fixed_seed(v.seed) + x, y, w, h = v.x, v.y, v.w, v.h + # convert to pixel + x = int(x * width) + y = int(y * height) + w = math.ceil(w * width) + h = math.ceil(h * height) + # clamp + x = max(0, x) + y = max(0, y) + w = min(width - x, w) + h = min(height - y, h) + # create random tensor + torch.manual_seed(seed) + rand_tensor = torch.randn((1, org_random_tensors.shape[1], h, w), device=devices.cpu) + if BlendMode(v.blend_mode) == BlendMode.BACKGROUND: + background_noise [:, :, y:y+h, x:x+w] += rand_tensor.to(background_noise.device) + background_noise_count[:, :, y:y+h, x:x+w] += 1 + elif BlendMode(v.blend_mode) == BlendMode.FOREGROUND: + foreground_noise [:, :, y:y+h, x:x+w] += rand_tensor.to(foreground_noise.device) + foreground_noise_count[:, :, y:y+h, x:x+w] += 1 + else: + raise NotImplementedError + region_info['Region ' + str(i+1)]['seed'] = seed + + # average + background_noise = torch.where(background_noise_count > 1, background_noise / background_noise_count, background_noise) + foreground_noise = torch.where(foreground_noise_count > 1, foreground_noise / foreground_noise_count, foreground_noise) + # paste two layers to original random tensor + org_random_tensors = torch.where(background_noise_count > 0, background_noise, org_random_tensors) + org_random_tensors = 
torch.where(foreground_noise_count > 0, foreground_noise, org_random_tensors) + return org_random_tensors + # p.sd_model.sd_model_hash改为p.sd_model_hash + ''' ↓↓↓ helper methods ↓↓↓ ''' + + ''' ↓↓↓ helper methods ↓↓↓ ''' + + def dump_regions(self, cfg_name, *bbox_controls): + if not cfg_name: return gr_value(f'Config file name cannot be empty.', visible=True) + + bbox_settings = build_bbox_settings(bbox_controls) + data = {'bbox_controls': [v._asdict() for v in bbox_settings.values()]} + + if not os.path.exists(CFG_PATH): os.makedirs(CFG_PATH) + fp = os.path.join(CFG_PATH, cfg_name) + with open(fp, 'w', encoding='utf-8') as fh: + json.dump(data, fh, indent=2, ensure_ascii=False) + + return gr_value(f'Config saved to {fp}.', visible=True) + + def load_regions(self, ref_image, cfg_name, *bbox_controls): + if ref_image is None: + return [gr_value(v) for v in bbox_controls] + [gr_value(f'Please create or upload a ref image first.', visible=True)] + fp = os.path.join(CFG_PATH, cfg_name) + if not os.path.exists(fp): + return [gr_value(v) for v in bbox_controls] + [gr_value(f'Config {fp} not found.', visible=True)] + + try: + with open(fp, 'r', encoding='utf-8') as fh: + data = json.load(fh) + except Exception as e: + return [gr_value(v) for v in bbox_controls] + [gr_value(f'Failed to load config {fp}: {e}', visible=True)] + + num_boxes = len(data['bbox_controls']) + data_list = [] + for i in range(BBOX_MAX_NUM): + if i < num_boxes: + for k in BBoxSettings._fields: + if k in data['bbox_controls'][i]: + data_list.append(data['bbox_controls'][i][k]) + else: + data_list.append(None) + else: + data_list.extend(DEFAULT_BBOX_SETTINGS) + + return [gr_value(v) for v in data_list] + [gr_value(f'Config loaded from {fp}.', visible=True)] + + + def noise_inverse_set_cache(self, p: ProcessingImg2Img, x0: Tensor, xt: Tensor, prompts: List[str], steps: int, retouch:float): + self.noise_inverse_cache = NoiseInverseCache(p.sd_model.sd_model_hash, x0, xt, steps, retouch, prompts) + + def 
noise_inverse_get_cache(self): + return self.noise_inverse_cache + + + def reset(self): + ''' unhijack inner APIs, see hijack in process() ''' + if hasattr(Script, "create_sampler_original_md"): + sd_samplers.create_sampler = Script.create_sampler_original_md + del Script.create_sampler_original_md + if hasattr(Script, "create_random_tensors_original_md"): + processing.create_random_tensors = Script.create_random_tensors_original_md + del Script.create_random_tensors_original_md + DemoFusion.unhook() + self.delegate = None + + def reset_and_gc(self): + self.reset() + self.noise_inverse_cache = None + + import gc; gc.collect() + devices.torch_gc() + + try: + import os + import psutil + mem = psutil.Process(os.getpid()).memory_info() + print(f'[Mem] rss: {mem.rss/2**30:.3f} GB, vms: {mem.vms/2**30:.3f} GB') + from modules.shared import mem_mon as vram_mon + from modules.memmon import MemUsageMonitor + vram_mon: MemUsageMonitor + free, total = vram_mon.cuda_mem_get_info() + print(f'[VRAM] free: {free/2**30:.3f} GB, total: {total/2**30:.3f} GB') + except: + pass diff --git a/tile_methods/demofusion.py b/tile_methods/demofusion.py new file mode 100644 index 0000000..291bc1f --- /dev/null +++ b/tile_methods/demofusion.py @@ -0,0 +1,334 @@ +from tile_methods.abstractdiffusion import AbstractDiffusion +from tile_utils.utils import * +import torch.nn.functional as F +import random +from copy import deepcopy +import inspect +from modules import sd_samplers_common + + +class DemoFusion(AbstractDiffusion): + """ + DemoFusion Implementation + https://arxiv.org/abs/2311.16973 + """ + + def __init__(self, p:Processing, *args, **kwargs): + super().__init__(p, *args, **kwargs) + assert p.sampler_name != 'UniPC', 'Demofusion is not compatible with UniPC!' 
+ + def add_one(self): + self.p.current_step += 1 + return + + + def hook(self): + steps, self.t_enc = sd_samplers_common.setup_img2img_steps(self.p, None) + # print("ENC",self.t_enc) + + self.sampler.model_wrap_cfg.forward_ori = self.sampler.model_wrap_cfg.forward + self.sampler.model_wrap_cfg.forward = self.forward_one_step + self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward + if self.is_kdiff: + self.sampler: KDiffusionSampler + self.sampler.model_wrap_cfg: CFGDenoiserKDiffusion + self.sampler.model_wrap_cfg.inner_model: Union[CompVisDenoiser, CompVisVDenoiser] + sigmas = self.sampler.get_sigmas(self.p, steps) + # print("SIGMAS:",sigmas) + self.p.sigmas = sigmas[steps - self.t_enc - 1:] + else: + self.sampler: CompVisSampler + self.sampler.model_wrap_cfg: CFGDenoiserTimesteps + self.sampler.model_wrap_cfg.inner_model: Union[CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser] + self.timesteps = self.sampler.get_timesteps(self.p, steps) + + @staticmethod + def unhook(): + if hasattr(shared.sd_model, 'apply_model_ori'): + shared.sd_model.apply_model = shared.sd_model.apply_model_ori + del shared.sd_model.apply_model_ori + + def reset_buffer(self, x_in:Tensor): + super().reset_buffer(x_in) + + + + def repeat_tensor(self, x:Tensor, n:int) -> Tensor: + ''' repeat the tensor on it's first dim ''' + if n == 1: return x + B = x.shape[0] + r_dims = len(x.shape) - 1 + if B == 1: # batch_size = 1 (not `tile_batch_size`) + shape = [n] + [-1] * r_dims # [N, -1, ...] + return x.expand(shape) # `expand` is much lighter than `tile` + else: + shape = [n] + [1] * r_dims # [N, 1, ...] 
+ return x.repeat(shape) + + def repeat_cond_dict(self, cond_in:CondDict, bboxes:List[CustomBBox]) -> CondDict: + ''' repeat all tensors in cond_dict on it's first dim (for a batch of tiles), returns a new object ''' + # n_repeat + n_rep = len(bboxes) + # txt cond + tcond = self.get_tcond(cond_in) # [B=1, L, D] => [B*N, L, D] + tcond = self.repeat_tensor(tcond, n_rep) + # img cond + icond = self.get_icond(cond_in) + if icond.shape[2:] == (self.h, self.w): # img2img, [B=1, C, H, W] + icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0) + else: # txt2img, [B=1, C=5, H=1, W=1] + icond = self.repeat_tensor(icond, n_rep) + # vec cond (SDXL) + vcond = self.get_vcond(cond_in) # [B=1, D] + if vcond is not None: + vcond = self.repeat_tensor(vcond, n_rep) # [B*N, D] + return self.make_cond_dict(cond_in, tcond, icond, vcond) + + + def global_split_bboxes(self): + cols = self.p.current_scale_num + rows = cols + + bbox_list = [] + for row in range(rows): + y = row + for col in range(cols): + x = col + bbox = (x, y) + bbox_list.append(bbox) + + return bbox_list + + def split_bboxes_jitter(self,w_l:int, h_l:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]: + cols = math.ceil((w_l - overlap) / (tile_w - overlap)) + rows = math.ceil((h_l - overlap) / (tile_h - overlap)) + if rows==0: + rows=1 + if cols == 0: + cols=1 + dx = (w_l - tile_w) / (cols - 1) if cols > 1 else 0 + dy = (h_l - tile_h) / (rows - 1) if rows > 1 else 0 + if self.p.random_jitter: + self.jitter_range = max((min(self.w, self.h)-self.stride)//4,0) + else: + self.jitter_range=0 + bbox_list: List[BBox] = [] + for row in range(rows): + for col in range(cols): + h = min(int(row * dy), h_l - tile_h) + w = min(int(col * dx), w_l - tile_w) + if self.p.random_jitter: + self.jitter_range = min(max((min(self.w, self.h)-self.stride)//4,0),int(self.stride/2)) + jitter_range = self.jitter_range + w_jitter = 0 + h_jitter = 0 + if (w != 0) and (w+tile_w 
!= w_l): + w_jitter = random.randint(-jitter_range, jitter_range) + elif (w == 0) and (w + tile_w != w_l): + w_jitter = random.randint(-jitter_range, 0) + elif (w != 0) and (w + tile_w == w_l): + w_jitter = random.randint(0, jitter_range) + if (h != 0) and (h + tile_h != h_l): + h_jitter = random.randint(-jitter_range, jitter_range) + elif (h == 0) and (h + tile_h != h_l): + h_jitter = random.randint(-jitter_range, 0) + elif (h != 0) and (h + tile_h == h_l): + h_jitter = random.randint(0, jitter_range) + h +=(h_jitter + jitter_range) + w += (w_jitter + jitter_range) + + bbox = BBox(w, h, tile_w, tile_h) + bbox_list.append(bbox) + return bbox_list, None + + @grid_bbox + def get_views(self, overlap:int, tile_bs:int): + self.enable_grid_bbox = True + self.tile_w = self.window_size + self.tile_h = self.window_size + + self.overlap = max(0, min(overlap, self.window_size - 4)) + + self.stride = max(1,self.window_size - self.overlap) + + # split the latent into overlapped tiles, then batching + # weights basically indicate how many times a pixel is painted + bboxes, _ = self.split_bboxes_jitter(self.w, self.h, self.tile_w, self.tile_h, overlap, self.get_tile_weights()) + print("BBOX:",len(bboxes)) + self.num_tiles = len(bboxes) + self.num_batches = math.ceil(self.num_tiles / tile_bs) + self.tile_bs = math.ceil(len(bboxes) / self.num_batches) # optimal_batch_size + self.batched_bboxes = [bboxes[i*self.tile_bs:(i+1)*self.tile_bs] for i in range(self.num_batches)] + + global_bboxes = self.global_split_bboxes() + self.global_num_tiles = len(global_bboxes) + self.global_num_batches = math.ceil(self.global_num_tiles / tile_bs) + self.global_tile_bs = math.ceil(len(global_bboxes) / self.global_num_batches) + self.global_batched_bboxes = [global_bboxes[i*self.global_tile_bs:(i+1)*self.global_tile_bs] for i in range(self.global_num_batches)] + + def gaussian_kernel(self,kernel_size=3, sigma=1.0, channels=3): + x_coord = torch.arange(kernel_size, device=devices.device) + 
        # Build a separable 2-D Gaussian: 1-D profile -> outer product -> one
        # depthwise kernel per channel (shape [channels, 1, k, k]).
        gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
        gaussian_1d = gaussian_1d / gaussian_1d.sum()
        gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
        kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)

        return kernel

    def gaussian_filter(self, latents, kernel_size=3, sigma=1.0):
        """Blur `latents` with a depthwise Gaussian of the given size/sigma.

        Uses `groups=channels` so each latent channel is filtered
        independently; `padding=kernel_size//2` keeps the spatial size
        unchanged for odd kernel sizes (callers pass 2*n-1, always odd).
        """
        channels = latents.shape[1]
        kernel = self.gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
        blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)

        return blurred_latents


    ''' ↓↓↓ kernel hijacks ↓↓↓ '''
    @torch.no_grad()
    @keep_signature
    def forward_one_step(self, x_in, sigma, **kwarg):
        """One DemoFusion denoising step: local (tiled) pass + dilated global pass.

        Replaces the sampler's CFG forward. Computes a "skip residual" `self.xi`
        (the noised original latent at the current step), blends it with the
        running latent via cosine schedule c1, runs a windowed/tiled local
        denoise, then a dilated-sampling global denoise, and mixes the two with
        cosine schedule c2 (DemoFusion paper, Sec. 3).
        """
        self.add_one()
        # Skip residual: re-noise the reference latent to the current step's
        # noise level. K-diffusion parameterizes by sigma; DDIM-style samplers
        # by sqrt(alpha_cumprod) / sqrt(1-alpha_cumprod).
        if self.is_kdiff:
            self.xi = self.p.x + self.p.noise * self.p.sigmas[self.p.current_step]
        else:
            alphas_cumprod = self.p.sd_model.alphas_cumprod
            sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]])
            sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]])
            self.xi = self.p.x*sqrt_alpha_cumprod + self.p.noise * sqrt_one_minus_alpha_cumprod

        # Cosine decay from 1 -> 0 over the sampling schedule; raised to the
        # user-tunable cosine_scale_{1,2,3} exponents to control how fast the
        # skip-residual (c1), global/local mix (c2) and blur strength (c3) fade.
        self.cosine_factor = 0.5 * (1 + torch.cos(torch.pi *torch.tensor(((self.p.current_step + 1) / (self.t_enc+1)))))
        c2 = self.cosine_factor**self.p.cosine_scale_2

        self.c1 = self.cosine_factor ** self.p.cosine_scale_1

        # Blend current latent with the skip residual (early steps lean on xi).
        self.x_in_tmp = x_in*(1 - self.c1) + self.xi * self.c1

        # NOTE(review): when random_jitter is on, self.jitter_range is read
        # before it is assigned below — it must be initialized elsewhere
        # (presumably in the delegate's setup); confirm before refactoring.
        if self.p.random_jitter:
            jitter_range = self.jitter_range
        else:
            jitter_range = 0
        # Zero-pad so jittered tile windows never index out of bounds.
        self.x_in_tmp_ = F.pad(self.x_in_tmp,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0)
        _,_,H,W = self.x_in_tmp.shape

        # Gaussian-blur the latent, then renormalize to the unblurred latent's
        # mean/std so only high-frequency content is removed, not statistics.
        std_, mean_ = self.x_in_tmp.std(), self.x_in_tmp.mean()
        c3 = 0.99 * self.cosine_factor ** self.p.cosine_scale_3 + 1e-2
        latents_gaussian = self.gaussian_filter(self.x_in_tmp, kernel_size=(2*self.p.current_scale_num-1), sigma=0.8*c3)
        self.latents_gaussian = (latents_gaussian - latents_gaussian.mean()) / latents_gaussian.std() * std_ + mean_
        self.jitter_range = jitter_range

        # ---- Local (tiled) pass: temporarily reroute the inner model through
        # sample_one_step_local, run the original CFG forward, then restore.
        # NOTE(review): not exception-safe — if forward_ori raises, the hijack
        # stays installed; consider try/finally in a follow-up.
        self.sampler.model_wrap_cfg.inner_model.forward = self.sample_one_step_local
        self.repeat_3 = False
        x_local = self.sampler.model_wrap_cfg.forward_ori(self.x_in_tmp_,sigma, **kwarg)
        self.sampler.model_wrap_cfg.inner_model.forward = self.sampler_forward
        # Crop the padding back off to recover the original HxW.
        x_local = x_local[:,:,jitter_range:jitter_range+H,jitter_range:jitter_range+W]

        ############################################# Dilated Sampling #############################################
        # Global pass: denoise strided sub-lattices (every current_scale_num-th
        # pixel, offset by (w, h)) so each UNet call sees a full-image,
        # low-resolution view. apply_model is hijacked for the duration.
        if not hasattr(self.p.sd_model, 'apply_model_ori'):
            self.p.sd_model.apply_model_ori = self.p.sd_model.apply_model
        self.p.sd_model.apply_model = self.apply_model_hijack
        x_global = torch.zeros_like(x_local)

        for batch_id, bboxes in enumerate(self.global_batched_bboxes):
            for bbox in bboxes:
                # (w, h) is the sub-lattice offset within the stride window.
                w,h = bbox

                ######

                x_global_i = self.sampler.model_wrap_cfg.forward_ori(self.x_in_tmp[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num],sigma, **kwarg) # x_in_tmp could be changed to latents_gaussian
                # Each offset fills a disjoint sub-lattice, so += assembles the
                # full-resolution global estimate without overlap.
                x_global[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num] += x_global_i

                ######

                #NOTE: Predicting Noise on Gaussian Latent and Obtaining Denoised on Original Latent

                # self.x_out_list = []
                # self.x_out_idx = -1
                # self.flag = 1
                # self.sampler.model_wrap_cfg.forward_ori(self.latents_gaussian[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num],sigma,**kwarg)
                # self.flag = 0
                # x_global_i = self.sampler.model_wrap_cfg.forward_ori(self.x_in_tmp[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num],sigma,**kwarg)
                # x_global[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num] += x_global_i

        # Restore the un-hijacked UNet forward.
        self.p.sd_model.apply_model = self.p.sd_model.apply_model_ori

        # Cosine mix: early steps favor the global (structure) estimate,
        # late steps favor the local (detail) estimate.
        x_out= x_local*(1-c2)+ x_global*c2
        return x_out


    @torch.no_grad()
    @keep_signature
    def sample_one_step_local(self, x_in, sigma, cond):
        """Tiled inner-model forward: run each window through the model and
        average overlapping outputs.

        Installed as `inner_model.forward` by forward_one_step; `cond` is
        re-batched per tile via repeat_cond_dict/repeat_tensor. Three repeat
        strategies cover k-diffusion CFG, non-k-diffusion samplers, and the
        raw apply_model path used by get_noise (repeat_3).
        """
        assert LatentDiffusion.apply_model

        def repeat_func_1(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tensor:
            # k-diffusion path: sigma and cond dict repeated per tile.
            sigma_tile = self.repeat_tensor(sigma, len(bboxes))
            cond_tile = self.repeat_cond_dict(cond, bboxes)
            return self.sampler_forward(x_tile, sigma_tile, cond=cond_tile)

        def repeat_func_2(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tuple[Tensor, Tensor]:
            # Non-k-diffusion path: cond may be a plain tensor or a dict.
            n_rep = len(bboxes)
            ts_tile = self.repeat_tensor(sigma, n_rep)
            if isinstance(cond, dict):   # FIXME: when will enter this branch?
                cond_tile = self.repeat_cond_dict(cond, bboxes)
            else:
                cond_tile = self.repeat_tensor(cond, n_rep)
            return self.sampler_forward(x_tile, ts_tile, cond=cond_tile)

        def repeat_func_3(x_tile:Tensor, bboxes:List[CustomBBox]):
            # Direct UNet path: bypass the sampler wrapper entirely.
            sigma_in_tile = sigma.repeat(len(bboxes))
            cond_out = self.repeat_cond_dict(cond, bboxes)
            x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=cond_out)
            return x_tile_out

        # repeat_3 is a one-shot flag set by get_noise; consume and clear it.
        if self.repeat_3:
            repeat_func = repeat_func_3
            self.repeat_3 = False
        elif self.is_kdiff:
            repeat_func = repeat_func_1
        else:
            repeat_func = repeat_func_2

        N,_,_,_ = x_in.shape

        H = self.h
        W = self.w

        # Accumulate per-tile outputs and per-pixel coverage counts.
        self.x_buffer = torch.zeros_like(x_in)
        self.weights = torch.zeros_like(x_in)
        for batch_id, bboxes in enumerate(self.batched_bboxes):
            if state.interrupted: return x_in
            x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0)
            x_tile_out = repeat_func(x_tile, bboxes)
            # de-batching
            for i, bbox in enumerate(bboxes):
                self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :]
                self.weights[bbox.slicer] += 1
        # NOTE(review): torch.tensor(1) is created on CPU with default dtype;
        # if weights lives on CUDA this relies on torch.where's scalar-tensor
        # promotion — confirm on the supported torch versions.
        self.weights = torch.where(self.weights == 0, torch.tensor(1), self.weights) #Prevent NaN from appearing in random_jitter mode

        # Average where tiles overlapped (weight > 1); uncovered pixels keep 0.
        x_buffer = self.x_buffer/self.weights

        return x_buffer



    @torch.no_grad()
    @keep_signature
    def apply_model_hijack(self, x_in:Tensor, t_in:Tensor, cond:CondDict):
        """Pass-through replacement for sd_model.apply_model during the
        dilated-sampling (global) pass — currently just forwards to the
        original; the commented-out code below is a disabled experiment that
        predicted noise on the Gaussian-blurred latent instead."""
        assert LatentDiffusion.apply_model

        x_tile_out = self.p.sd_model.apply_model_ori(x_in,t_in,cond)
        return x_tile_out
        #NOTE: Using Gaussian Latent to Predict Noise on the Original Latent
        # if self.flag == 1:
        #     x_tile_out = self.p.sd_model.apply_model_ori(x_in,t_in,cond)
        #     self.x_out_list.append(x_tile_out)
        #     return x_tile_out
        # else:
        #     self.x_out_idx += 1
        #     return self.x_out_list[self.x_out_idx]


    def get_noise(self, x_in:Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor:
        """Noise-inversion hook: run one tiled pass through the raw UNet
        (repeat_func_3 path) instead of the sampler's CFG wrapper.

        `step` is accepted for signature compatibility but unused here.
        """
        # NOTE: The following code is analytically wrong but aesthetically beautiful
        # Shallow-copy cond so per-tile repetition cannot mutate the caller's dict.
        cond_in_original = cond_in.copy()

        # One-shot flag: makes sample_one_step_local use the direct
        # apply_model path for this single call.
        self.repeat_3 = True

        return self.sample_one_step_local(x_in, sigma_in, cond_in_original)
diff --git a/tile_utils/utils.py b/tile_utils/utils.py
index 59f1c67..4bf9b41 100644
--- a/tile_utils/utils.py
+++ b/tile_utils/utils.py
@@ -30,6 +30,9 @@ class Method(ComparableEnum):
     MULTI_DIFF = 'MultiDiffusion'
     MIX_DIFF = 'Mixture of Diffusers'
 
+class Method_2(ComparableEnum):
+    DEMO_FU = "DemoFusion"
+
 class BlendMode(Enum):  # i.e. LayerType
     FOREGROUND = 'Foreground'