fix & update

pull/371/head
Jaylen Lee 2024-03-27 05:27:22 +08:00
parent 0103762918
commit b629e087a0
6 changed files with 134 additions and 87 deletions

View File

@ -21,6 +21,20 @@ CFG_PATH = os.path.join(scripts.basedir(), 'region_configs')
BBOX_MAX_NUM = min(getattr(shared.cmd_opts, 'md_max_regions', 8), 16)
def create_infotext_hijack(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=-1, all_negative_prompts=None):
    """Wrapper around the original ``create_infotext`` that rewrites the
    ``Size`` field to the per-image dimensions stored on ``p``.

    The original function is expected at ``processing.create_infotext_ori``
    (installed by the hijack setup elsewhere in this script).  A sentinel
    ``index == -1`` is normalized to ``None`` for the original call, while
    the raw ``index`` is still used to pick an entry from ``p.width_list`` /
    ``p.height_list`` (so ``-1`` selects the last substage's size).
    Returns the infotext unchanged when no ``Size`` field (or no trailing
    comma after it) is found.
    """
    effective_index = None if index == -1 else index
    text = processing.create_infotext_ori(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt, effective_index, all_negative_prompts)
    size_pos = text.find("Size")
    if size_pos == -1:
        return text
    comma_pos = text.find(",", size_pos)
    if comma_pos == -1:
        return text
    # Splice the per-substage size over the original "Size...," span.
    replacement = f"Size:{p.width_list[index]}x{p.height_list[index]}"
    return text[:size_pos] + replacement + text[comma_pos:]
class Script(scripts.Script):
def __init__(self):
self.controlnet_script: ModuleType = None
@ -42,14 +56,14 @@ class Script(scripts.Script):
with gr.Accordion('DemoFusion', open=False, elem_id=f'MD-{tab}'):
with gr.Row(variant='compact') as tab_enable:
enabled = gr.Checkbox(label='Enable DemoFusion(Do not open it with tilediffusion)', value=False, elem_id=uid('enabled'))
random_jitter = gr.Checkbox(label='Random jitter', value = True, elem_id=uid('random-jitter'))
enabled = gr.Checkbox(label='Enable DemoFusion(Dont open with tilediffusion)', value=False, elem_id=uid('enabled'))
random_jitter = gr.Checkbox(label='Random Jitter', value = True, elem_id=uid('random-jitter'))
gaussian_filter = gr.Checkbox(label='Gaussian Filter', value=True, visible=False, elem_id=uid('gaussian'))
keep_input_size = gr.Checkbox(label='Keep input-image size', value=False,visible=is_img2img, elem_id=uid('keep-input-size'))
gaussian_filter = gr.Checkbox(label='Gaussian filter', value=False, elem_id=uid('gaussian'))
with gr.Row(variant='compact') as tab_param:
method = gr.Dropdown(label='Method', choices=[Method_2.DEMO_FU.value], value=Method_2.DEMO_FU.value, elem_id=uid('method'))
method = gr.Dropdown(label='Method', choices=[Method_2.DEMO_FU.value], value=Method_2.DEMO_FU.value, visible= False, elem_id=uid('method'))
control_tensor_cpu = gr.Checkbox(label='Move ControlNet tensor to CPU (if applicable)', value=False, elem_id=uid('control-tensor-cpu'))
reset_status = gr.Button(value='Free GPU', variant='tool')
reset_status.click(fn=self.reset_and_gc, show_progress=False)
@ -62,13 +76,13 @@ class Script(scripts.Script):
overlap = gr.Slider(minimum=0, maximum=256, step=4, label='Latent window overlap', value=64, elem_id=uid('latent-tile-overlap'))
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Latent window batch size', value=4, elem_id=uid('latent-tile-batch-size'))
with gr.Row(variant='compact', visible=True) as tab_c:
c1 = gr.Slider(minimum=0, maximum=5, step=0.1, label='C1', value=3, elem_id=f'C1-{tab}')
c2 = gr.Slider(minimum=0, maximum=5, step=0.1, label='C2', value=1, elem_id=f'C2-{tab}')
c3 = gr.Slider(minimum=0, maximum=5, step=0.1, label='C3', value=1, elem_id=f'C3-{tab}')
c1 = gr.Slider(minimum=0, maximum=5, step=0.01, label='Cosine Scale 1', value=3, elem_id=f'C1-{tab}')
c2 = gr.Slider(minimum=0, maximum=5, step=0.01, label='Cosine Scale 2', value=1, elem_id=f'C2-{tab}')
c3 = gr.Slider(minimum=0, maximum=5, step=0.01, label='Cosine Scale 3', value=1, elem_id=f'C3-{tab}')
sigma = gr.Slider(minimum=0, maximum=1, step=0.01, label='Sigma', value=0.4, elem_id=f'Sigma-{tab}')
with gr.Group() as tab_denoise:
strength = gr.Slider(minimum=0, maximum=1, step=0.01, value = 0.9, label='Denoising strength substage',visible= not is_img2img, elem_id=f'strength-{tab}')
strength = gr.Slider(minimum=0, maximum=1, step=0.01, value = 0.9,label='Denoising Strength for Substage',visible=not is_img2img, elem_id=f'strength-{tab}')
with gr.Row(variant='compact') as tab_upscale:
scale_factor = gr.Slider(minimum=1.0, maximum=8.0, step=1, label='Scale Factor', value=2.0, elem_id=uid('upscaler-factor'))
@ -92,7 +106,7 @@ class Script(scripts.Script):
noise_inverse, noise_inverse_steps, noise_inverse_retouch, noise_inverse_renoise_strength, noise_inverse_renoise_kernel,
control_tensor_cpu,
random_jitter,
c1,c2,c3,gaussian_filter,strength
c1,c2,c3,gaussian_filter,strength,sigma
]
@ -104,7 +118,7 @@ class Script(scripts.Script):
noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch: float, noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int,
control_tensor_cpu: bool,
random_jitter:bool,
c1,c2,c3,gaussian_filter,strength
c1,c2,c3,gaussian_filter,strength,sigma
):
# unhijack & unhook, in case it broke at last time
@ -191,7 +205,14 @@ class Script(scripts.Script):
p.sample = lambda conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts: self.sample_hijack(
conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts,p, is_img2img,
window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength)
window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma)
processing.create_infotext_ori = processing.create_infotext
p.width_list = [p.height]
p.height_list = [p.height]
processing.create_infotext = create_infotext_hijack
## end
@ -200,6 +221,20 @@ class Script(scripts.Script):
if self.delegate is not None: self.delegate.reset_controlnet_tensors()
def postprocess_batch_list(self, p, pp, *args, **kwargs):
    """Crop each substage image to its own resolution and expand per-image
    metadata lists to match.

    Images in ``pp.images`` are grouped in runs of ``p.batch_size``; group
    ``idx_b`` keeps only the top-left ``(idx_b + 1) / p.scale_factor``
    fraction of each spatial dimension.  The seed/prompt/negative-prompt
    (and, when present, color-correction) lists are each repeated
    ``p.scale_factor`` times so their lengths track the substage outputs,
    and ``p.width_list`` / ``p.height_list`` record the reported output
    size for every image (substage k reports ``(k + 1) *`` base size).

    NOTE(review): assumes pp.images entries are CHW tensors and that
    p.scale_factor is an int — confirm against the sampling hijack.
    """
    for idx, image in enumerate(pp.images):
        idx_b = idx // p.batch_size
        # Top-left crop: substage idx_b covers (idx_b+1)/scale_factor of the canvas.
        pp.images[idx] = image[:, :image.shape[1] // (p.scale_factor) * (idx_b + 1), :image.shape[2] // (p.scale_factor) * (idx_b + 1)]
    # Duplicate per-image metadata once per substage so list lengths match.
    p.seeds = [item for _ in range(p.scale_factor) for item in p.seeds]
    p.prompts = [item for _ in range(p.scale_factor) for item in p.prompts]
    p.all_negative_prompts = [item for _ in range(p.scale_factor) for item in p.all_negative_prompts]
    p.negative_prompts = [item for _ in range(p.scale_factor) for item in p.negative_prompts]
    # Fixed: compare to None with `is not` (PEP 8), not `!=`.
    if p.color_corrections is not None:
        p.color_corrections = [item for _ in range(p.scale_factor) for item in p.color_corrections]
    # Reported dimensions per image, consumed by create_infotext_hijack.
    p.width_list = [item * (idx + 1) for idx in range(p.scale_factor) for item in [p.width for _ in range(p.batch_size)]]
    p.height_list = [item * (idx + 1) for idx in range(p.scale_factor) for item in [p.height for _ in range(p.batch_size)]]
    return
def postprocess(self, p: Processing, processed, enabled, *args):
if not enabled: return
# unhijack & unhook
@ -219,8 +254,10 @@ class Script(scripts.Script):
''' ↓↓↓ inner API hijack ↓↓↓ '''
@torch.no_grad()
def sample_hijack(self, conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts,p,image_ori,window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength):
def sample_hijack(self, conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts,p,image_ori,window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma):
################################################## Phase Initialization ######################################################
# p.width = p.width_original_md
# p.height = p.height_original_md
if not image_ori:
p.current_step = 0
@ -230,7 +267,10 @@ class Script(scripts.Script):
p.sampler = Script.create_sampler_original_md(p.sampler_name, p.sd_model) #scale
x = p.rng.next()
print("### Phase 1 Denoising ###")
latents = p.sampler.sample(p, x, conditioning, unconditional_conditioning, image_conditioning=p.txt2img_image_conditioning(x))
latents_ = F.pad(latents, (0, latents.shape[3]*(p.scale_factor-1), 0, latents.shape[2]*(p.scale_factor-1)))
res = latents_
del x
p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
starting_scale = 2
@ -249,6 +289,7 @@ class Script(scripts.Script):
p.cosine_scale_1 = c1
p.cosine_scale_2 = c2
p.cosine_scale_3 = c3
self.delegate.sig = sigma
p.latents = latents
for current_scale_num in range(starting_scale, p.scale_factor+1):
p.current_scale_num = current_scale_num
@ -268,7 +309,7 @@ class Script(scripts.Script):
info = ', '.join([
# f"{method.value} hooked into {name!r} sampler",
f"Tile size: {window_size}",
f"Tile size: {self.delegate.window_size}",
f"Tile count: {self.delegate.num_tiles}",
f"Batch size: {self.delegate.tile_bs}",
f"Tile batches: {len(self.delegate.batched_bboxes)}",
@ -294,10 +335,17 @@ class Script(scripts.Script):
self.flag_noise_inverse = False
p.latents = (p.latents - p.latents.mean()) / p.latents.std() * anchor_std + anchor_mean
latents_ = F.pad(p.latents, (0, p.latents.shape[3]//current_scale_num*(p.scale_factor-current_scale_num), 0, p.latents.shape[2]//current_scale_num*(p.scale_factor-current_scale_num)))
if current_scale_num==1:
res = latents_
else:
res = torch.concatenate((res,latents_),axis=0)
#########################################################################################################################################
p.width = p.width*p.scale_factor
p.height = p.height*p.scale_factor
return p.latents
# p.width = p.width*p.scale_factor
# p.height = p.height*p.scale_factor
return res
@staticmethod
def callback_hijack(self_sampler,d,p):
@ -341,7 +389,7 @@ class Script(scripts.Script):
else: raise NotImplementedError(f"Method {method} not implemented.")
delegate = delegate_cls(p, sampler)
delegate.window_size = window_size
delegate.window_size = min(min(window_size,p.width//8),p.height//8)
p.random_jitter = random_jitter
if flag_noise_inverse:
@ -363,7 +411,7 @@ class Script(scripts.Script):
info = ', '.join([
f"{method.value} hooked into {name!r} sampler",
f"Tile size: {window_size}",
f"Tile size: {delegate.window_size}",
f"Tile count: {delegate.num_tiles}",
f"Batch size: {delegate.tile_bs}",
f"Tile batches: {len(delegate.batched_bboxes)}",
@ -484,6 +532,9 @@ class Script(scripts.Script):
if hasattr(sd_samplers_common.Sampler, "callback_ori"):
sd_samplers_common.Sampler.callback_state = sd_samplers_common.Sampler.callback_ori
del sd_samplers_common.Sampler.callback_ori
if hasattr(processing, "create_infotext_ori"):
processing.create_infotext = processing.create_infotext_ori
del processing.create_infotext_ori
DemoFusion.unhook()
self.delegate = None

View File

@ -20,18 +20,16 @@ class DemoFusion(AbstractDiffusion):
def hook(self):
steps, self.t_enc = sd_samplers_common.setup_img2img_steps(self.p, None)
# print("ENC",self.t_enc)
self.sampler.model_wrap_cfg.forward_ori = self.sampler.model_wrap_cfg.forward
self.sampler.model_wrap_cfg.forward = self.forward_one_step
self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward
self.sampler.model_wrap_cfg.forward = self.forward_one_step
if self.is_kdiff:
self.sampler: KDiffusionSampler
self.sampler.model_wrap_cfg: CFGDenoiserKDiffusion
self.sampler.model_wrap_cfg.inner_model: Union[CompVisDenoiser, CompVisVDenoiser]
sigmas = self.sampler.get_sigmas(self.p, steps)
# print("SIGMAS:",sigmas)
self.p.sigmas = sigmas[steps - self.t_enc - 1:]
# sigmas = self.sampler.get_sigmas(self.p, steps)
# self.p.sigmas = sigmas[steps - self.t_enc - 1:]
else:
self.sampler: CompVisSampler
self.sampler.model_wrap_cfg: CFGDenoiserTimesteps
@ -105,12 +103,13 @@ class DemoFusion(AbstractDiffusion):
dx = (w_l - tile_w) / (cols - 1) if cols > 1 else 0
dy = (h_l - tile_h) / (rows - 1) if rows > 1 else 0
bbox_list: List[BBox] = []
self.jitter_range = 0
for row in range(rows):
for col in range(cols):
h = min(int(row * dy), h_l - tile_h)
w = min(int(col * dx), w_l - tile_w)
if self.p.random_jitter:
self.jitter_range = min(max((min(self.w, self.h)-self.stride)//4,0),int(self.stride/2))
self.jitter_range = min(max((min(self.w, self.h)-self.stride)//4,0),min(int(self.window_size/2),int(self.overlap/2)))
jitter_range = self.jitter_range
w_jitter = 0
h_jitter = 0
@ -141,12 +140,11 @@ class DemoFusion(AbstractDiffusion):
self.overlap = max(0, min(overlap, self.window_size - 4))
self.stride = max(1,self.window_size - self.overlap)
self.stride = max(4,self.window_size - self.overlap)
# split the latent into overlapped tiles, then batching
# weights basically indicate how many times a pixel is painted
bboxes, _ = self.split_bboxes_jitter(self.w, self.h, self.tile_w, self.tile_h, overlap, self.get_tile_weights())
print("BBOX:",len(bboxes))
bboxes, _ = self.split_bboxes_jitter(self.w, self.h, self.tile_w, self.tile_h, self.overlap, self.get_tile_weights())
self.num_tiles = len(bboxes)
self.num_batches = math.ceil(self.num_tiles / tile_bs)
self.tile_bs = math.ceil(len(bboxes) / self.num_batches) # optimal_batch_size
@ -181,77 +179,39 @@ class DemoFusion(AbstractDiffusion):
@keep_signature
def forward_one_step(self, x_in, sigma, **kwarg):
if self.is_kdiff:
self.xi = self.p.x + self.p.noise * self.p.sigmas[self.p.current_step]
x_noisy = self.p.x + self.p.noise * sigma[0]
else:
alphas_cumprod = self.p.sd_model.alphas_cumprod
sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]])
sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]])
self.xi = self.p.x*sqrt_alpha_cumprod + self.p.noise * sqrt_one_minus_alpha_cumprod
x_noisy = self.p.x*sqrt_alpha_cumprod + self.p.noise * sqrt_one_minus_alpha_cumprod
self.cosine_factor = 0.5 * (1 + torch.cos(torch.pi *torch.tensor(((self.p.current_step + 1) / (self.t_enc+1)))))
c2 = self.cosine_factor**self.p.cosine_scale_2
self.c1 = self.cosine_factor ** self.p.cosine_scale_1
c1 = self.cosine_factor ** self.p.cosine_scale_1
x_in_tmp = x_in*(1 - self.c1) + self.xi * self.c1
x_in = x_in*(1 - c1) + x_noisy * c1
if self.p.random_jitter:
jitter_range = self.jitter_range
else:
jitter_range = 0
x_in_tmp_ = F.pad(x_in_tmp,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0)
_,_,H,W = x_in_tmp.shape
x_in_ = F.pad(x_in,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0)
_,_,H,W = x_in.shape
std_, mean_ = x_in_tmp.std(), x_in_tmp.mean()
c3 = 0.99 * self.cosine_factor ** self.p.cosine_scale_3 + 1e-2
latents_gaussian = self.gaussian_filter(x_in_tmp, kernel_size=(2*self.p.current_scale_num-1), sigma=0.8*c3)
latents_gaussian = (latents_gaussian - latents_gaussian.mean()) / latents_gaussian.std() * std_ + mean_
self.jitter_range = jitter_range
self.sampler.model_wrap_cfg.inner_model.forward = self.sample_one_step_local
self.sampler.model_wrap_cfg.inner_model.forward = self.sample_one_step
self.repeat_3 = False
x_local = self.sampler.model_wrap_cfg.forward_ori(x_in_tmp_,sigma, **kwarg)
x_out = self.sampler.model_wrap_cfg.forward_ori(x_in_,sigma, **kwarg)
self.sampler.model_wrap_cfg.inner_model.forward = self.sampler_forward
x_local = x_local[:,:,jitter_range:jitter_range+H,jitter_range:jitter_range+W]
x_out = x_out[:,:,jitter_range:jitter_range+H,jitter_range:jitter_range+W]
############################################# Dilated Sampling #############################################
if not hasattr(self.p.sd_model, 'apply_model_ori'):
self.p.sd_model.apply_model_ori = self.p.sd_model.apply_model
self.p.sd_model.apply_model = self.apply_model_hijack
x_global = torch.zeros_like(x_local)
for batch_id, bboxes in enumerate(self.global_batched_bboxes):
for bbox in bboxes:
w,h = bbox
######
if self.gaussian_filter:
x_global_i = self.sampler.model_wrap_cfg.forward_ori(latents_gaussian[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num],sigma, **kwarg)
else:
x_global_i = self.sampler.model_wrap_cfg.forward_ori(x_in_tmp[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num],sigma, **kwarg) # x_in_tmp could be changed to latents_gaussian
x_global[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num] += x_global_i
######
#NOTE: Predicting Noise on Gaussian Latent and Obtaining Denoised on Original Latent
# self.x_out_list = []
# self.x_out_idx = -1
# self.flag = 1
# self.sampler.model_wrap_cfg.forward_ori(x_in_tmp[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num],sigma,**kwarg)
# self.flag = 0
# x_global_i = self.sampler.model_wrap_cfg.forward_ori(x_in_tmp[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num],sigma,**kwarg)
# x_global[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num] += x_global_i
self.p.sd_model.apply_model = self.p.sd_model.apply_model_ori
x_out= x_local*(1-c2)+ x_global*c2
return x_out
@torch.no_grad()
@keep_signature
def sample_one_step_local(self, x_in, sigma, cond):
def sample_one_step(self, x_in, sigma, cond):
assert LatentDiffusion.apply_model
def repeat_func_1(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tensor:
sigma_tile = self.repeat_tensor(sigma, len(bboxes))
@ -282,11 +242,12 @@ class DemoFusion(AbstractDiffusion):
repeat_func = repeat_func_2
N,_,_,_ = x_in.shape
H = self.h
W = self.w
# H = self.h
# W = self.w
self.x_buffer = torch.zeros_like(x_in)
self.weights = torch.zeros_like(x_in)
for batch_id, bboxes in enumerate(self.batched_bboxes):
if state.interrupted: return x_in
x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0)
@ -297,9 +258,44 @@ class DemoFusion(AbstractDiffusion):
self.weights[bbox.slicer] += 1
self.weights = torch.where(self.weights == 0, torch.tensor(1), self.weights) #Prevent NaN from appearing in random_jitter mode
x_buffer = self.x_buffer/self.weights
x_local = self.x_buffer/self.weights
return x_buffer
self.x_buffer = torch.zeros_like(self.x_buffer)
self.weights = torch.zeros_like(self.weights)
std_, mean_ = x_in.std(), x_in.mean()
c3 = 0.99 * self.cosine_factor ** self.p.cosine_scale_3 + 1e-2
if self.p.gaussian_filter:
x_in_g = self.gaussian_filter(x_in, kernel_size=(2*self.p.current_scale_num-1), sigma=self.sig*c3)
x_in_g = (x_in_g - x_in_g.mean()) / x_in_g.std() * std_ + mean_
if not hasattr(self.p.sd_model, 'apply_model_ori'):
self.p.sd_model.apply_model_ori = self.p.sd_model.apply_model
self.p.sd_model.apply_model = self.apply_model_hijack
x_global = torch.zeros_like(x_local)
jitter_range = self.jitter_range
end = x_global.shape[3]-jitter_range
for batch_id, bboxes in enumerate(self.global_batched_bboxes):
for bbox in bboxes:
w,h = bbox
# self.x_out_list = []
# self.x_out_idx = -1
# self.flag = 1
x_global_i0 = self.sampler_forward(x_in_g[:,:,h+jitter_range:end:self.p.current_scale_num,w+jitter_range:end:self.p.current_scale_num],sigma,cond = cond)
# self.flag = 0
x_global_i1 = self.sampler_forward(x_in[:,:,h+jitter_range:end:self.p.current_scale_num,w+jitter_range:end:self.p.current_scale_num],sigma,cond = cond) #NOTE According to the original execution process, it would be very strange to use the predicted noise of gaussian latents to predict the denoised data in non Gaussian latents. Why?
self.x_buffer[:,:,h+jitter_range:end:self.p.current_scale_num,w+jitter_range:end:self.p.current_scale_num] += (x_global_i0 + x_global_i1)/2
self.weights[:,:,h+jitter_range:end:self.p.current_scale_num,w+jitter_range:end:self.p.current_scale_num] += 1
self.p.sd_model.apply_model = self.p.sd_model.apply_model_ori
self.weights = torch.where(self.weights == 0, torch.tensor(1), self.weights) #Prevent NaN from appearing in random_jitter mode
x_global = self.x_buffer/self.weights
c2 = self.cosine_factor**self.p.cosine_scale_2
self.x_buffer= x_local*(1-c2)+ x_global*c2
return self.x_buffer