fix & update
parent
0103762918
commit
b629e087a0
|
|
@ -21,6 +21,20 @@ CFG_PATH = os.path.join(scripts.basedir(), 'region_configs')
|
|||
BBOX_MAX_NUM = min(getattr(shared.cmd_opts, 'md_max_regions', 8), 16)
|
||||
|
||||
|
||||
def create_infotext_hijack(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=-1, all_negative_prompts=None):
|
||||
idx = index
|
||||
if index == -1:
|
||||
idx = None
|
||||
text = processing.create_infotext_ori(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt, idx, all_negative_prompts)
|
||||
start_index = text.find("Size")
|
||||
if start_index != -1:
|
||||
r_text = f"Size:{p.width_list[index]}x{p.height_list[index]}"
|
||||
end_index = text.find(",", start_index)
|
||||
if end_index != -1:
|
||||
replaced_string = text[:start_index] + r_text + text[end_index:]
|
||||
return replaced_string
|
||||
return text
|
||||
|
||||
class Script(scripts.Script):
|
||||
def __init__(self):
|
||||
self.controlnet_script: ModuleType = None
|
||||
|
|
@ -42,14 +56,14 @@ class Script(scripts.Script):
|
|||
|
||||
with gr.Accordion('DemoFusion', open=False, elem_id=f'MD-{tab}'):
|
||||
with gr.Row(variant='compact') as tab_enable:
|
||||
enabled = gr.Checkbox(label='Enable DemoFusion(Do not open it with tilediffusion)', value=False, elem_id=uid('enabled'))
|
||||
random_jitter = gr.Checkbox(label='Random jitter', value = True, elem_id=uid('random-jitter'))
|
||||
enabled = gr.Checkbox(label='Enable DemoFusion(Dont open with tilediffusion)', value=False, elem_id=uid('enabled'))
|
||||
random_jitter = gr.Checkbox(label='Random Jitter', value = True, elem_id=uid('random-jitter'))
|
||||
gaussian_filter = gr.Checkbox(label='Gaussian Filter', value=True, visible=False, elem_id=uid('gaussian'))
|
||||
keep_input_size = gr.Checkbox(label='Keep input-image size', value=False,visible=is_img2img, elem_id=uid('keep-input-size'))
|
||||
gaussian_filter = gr.Checkbox(label='Gaussian filter', value=False, elem_id=uid('gaussian'))
|
||||
|
||||
|
||||
with gr.Row(variant='compact') as tab_param:
|
||||
method = gr.Dropdown(label='Method', choices=[Method_2.DEMO_FU.value], value=Method_2.DEMO_FU.value, elem_id=uid('method'))
|
||||
method = gr.Dropdown(label='Method', choices=[Method_2.DEMO_FU.value], value=Method_2.DEMO_FU.value, visible= False, elem_id=uid('method'))
|
||||
control_tensor_cpu = gr.Checkbox(label='Move ControlNet tensor to CPU (if applicable)', value=False, elem_id=uid('control-tensor-cpu'))
|
||||
reset_status = gr.Button(value='Free GPU', variant='tool')
|
||||
reset_status.click(fn=self.reset_and_gc, show_progress=False)
|
||||
|
|
@ -62,13 +76,13 @@ class Script(scripts.Script):
|
|||
overlap = gr.Slider(minimum=0, maximum=256, step=4, label='Latent window overlap', value=64, elem_id=uid('latent-tile-overlap'))
|
||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Latent window batch size', value=4, elem_id=uid('latent-tile-batch-size'))
|
||||
with gr.Row(variant='compact', visible=True) as tab_c:
|
||||
c1 = gr.Slider(minimum=0, maximum=5, step=0.1, label='C1', value=3, elem_id=f'C1-{tab}')
|
||||
c2 = gr.Slider(minimum=0, maximum=5, step=0.1, label='C2', value=1, elem_id=f'C2-{tab}')
|
||||
c3 = gr.Slider(minimum=0, maximum=5, step=0.1, label='C3', value=1, elem_id=f'C3-{tab}')
|
||||
c1 = gr.Slider(minimum=0, maximum=5, step=0.01, label='Cosine Scale 1', value=3, elem_id=f'C1-{tab}')
|
||||
c2 = gr.Slider(minimum=0, maximum=5, step=0.01, label='Cosine Scale 2', value=1, elem_id=f'C2-{tab}')
|
||||
c3 = gr.Slider(minimum=0, maximum=5, step=0.01, label='Cosine Scale 3', value=1, elem_id=f'C3-{tab}')
|
||||
sigma = gr.Slider(minimum=0, maximum=1, step=0.01, label='Sigma', value=0.4, elem_id=f'Sigma-{tab}')
|
||||
with gr.Group() as tab_denoise:
|
||||
strength = gr.Slider(minimum=0, maximum=1, step=0.01, value = 0.9, label='Denoising strength substage',visible= not is_img2img, elem_id=f'strength-{tab}')
|
||||
strength = gr.Slider(minimum=0, maximum=1, step=0.01, value = 0.9,label='Denoising Strength for Substage',visible=not is_img2img, elem_id=f'strength-{tab}')
|
||||
with gr.Row(variant='compact') as tab_upscale:
|
||||
|
||||
scale_factor = gr.Slider(minimum=1.0, maximum=8.0, step=1, label='Scale Factor', value=2.0, elem_id=uid('upscaler-factor'))
|
||||
|
||||
|
||||
|
|
@ -92,7 +106,7 @@ class Script(scripts.Script):
|
|||
noise_inverse, noise_inverse_steps, noise_inverse_retouch, noise_inverse_renoise_strength, noise_inverse_renoise_kernel,
|
||||
control_tensor_cpu,
|
||||
random_jitter,
|
||||
c1,c2,c3,gaussian_filter,strength
|
||||
c1,c2,c3,gaussian_filter,strength,sigma
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -104,7 +118,7 @@ class Script(scripts.Script):
|
|||
noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch: float, noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int,
|
||||
control_tensor_cpu: bool,
|
||||
random_jitter:bool,
|
||||
c1,c2,c3,gaussian_filter,strength
|
||||
c1,c2,c3,gaussian_filter,strength,sigma
|
||||
):
|
||||
|
||||
# unhijack & unhook, in case it broke at last time
|
||||
|
|
@ -191,7 +205,14 @@ class Script(scripts.Script):
|
|||
|
||||
p.sample = lambda conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts: self.sample_hijack(
|
||||
conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts,p, is_img2img,
|
||||
window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength)
|
||||
window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma)
|
||||
|
||||
processing.create_infotext_ori = processing.create_infotext
|
||||
|
||||
p.width_list = [p.height]
|
||||
p.height_list = [p.height]
|
||||
|
||||
processing.create_infotext = create_infotext_hijack
|
||||
## end
|
||||
|
||||
|
||||
|
|
@ -200,6 +221,20 @@ class Script(scripts.Script):
|
|||
|
||||
if self.delegate is not None: self.delegate.reset_controlnet_tensors()
|
||||
|
||||
def postprocess_batch_list(self, p, pp, *args, **kwargs):
|
||||
for idx,image in enumerate(pp.images):
|
||||
idx_b = idx//p.batch_size
|
||||
pp.images[idx] = image[:,:image.shape[1]//(p.scale_factor)*(idx_b+1),:image.shape[2]//(p.scale_factor)*(idx_b+1)]
|
||||
p.seeds = [item for _ in range(p.scale_factor) for item in p.seeds]
|
||||
p.prompts = [item for _ in range(p.scale_factor) for item in p.prompts]
|
||||
p.all_negative_prompts = [item for _ in range(p.scale_factor) for item in p.all_negative_prompts]
|
||||
p.negative_prompts = [item for _ in range(p.scale_factor) for item in p.negative_prompts]
|
||||
if p.color_corrections != None:
|
||||
p.color_corrections = [item for _ in range(p.scale_factor) for item in p.color_corrections]
|
||||
p.width_list = [item*(idx+1) for idx in range(p.scale_factor) for item in [p.width for _ in range(p.batch_size)]]
|
||||
p.height_list = [item*(idx+1) for idx in range(p.scale_factor) for item in [p.height for _ in range(p.batch_size)]]
|
||||
return
|
||||
|
||||
def postprocess(self, p: Processing, processed, enabled, *args):
|
||||
if not enabled: return
|
||||
# unhijack & unhook
|
||||
|
|
@ -219,8 +254,10 @@ class Script(scripts.Script):
|
|||
|
||||
''' ↓↓↓ inner API hijack ↓↓↓ '''
|
||||
@torch.no_grad()
|
||||
def sample_hijack(self, conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts,p,image_ori,window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength):
|
||||
def sample_hijack(self, conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts,p,image_ori,window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma):
|
||||
################################################## Phase Initialization ######################################################
|
||||
# p.width = p.width_original_md
|
||||
# p.height = p.height_original_md
|
||||
|
||||
if not image_ori:
|
||||
p.current_step = 0
|
||||
|
|
@ -230,7 +267,10 @@ class Script(scripts.Script):
|
|||
|
||||
p.sampler = Script.create_sampler_original_md(p.sampler_name, p.sd_model) #scale
|
||||
x = p.rng.next()
|
||||
print("### Phase 1 Denoising ###")
|
||||
latents = p.sampler.sample(p, x, conditioning, unconditional_conditioning, image_conditioning=p.txt2img_image_conditioning(x))
|
||||
latents_ = F.pad(latents, (0, latents.shape[3]*(p.scale_factor-1), 0, latents.shape[2]*(p.scale_factor-1)))
|
||||
res = latents_
|
||||
del x
|
||||
p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
|
||||
starting_scale = 2
|
||||
|
|
@ -249,6 +289,7 @@ class Script(scripts.Script):
|
|||
p.cosine_scale_1 = c1
|
||||
p.cosine_scale_2 = c2
|
||||
p.cosine_scale_3 = c3
|
||||
self.delegate.sig = sigma
|
||||
p.latents = latents
|
||||
for current_scale_num in range(starting_scale, p.scale_factor+1):
|
||||
p.current_scale_num = current_scale_num
|
||||
|
|
@ -268,7 +309,7 @@ class Script(scripts.Script):
|
|||
|
||||
info = ', '.join([
|
||||
# f"{method.value} hooked into {name!r} sampler",
|
||||
f"Tile size: {window_size}",
|
||||
f"Tile size: {self.delegate.window_size}",
|
||||
f"Tile count: {self.delegate.num_tiles}",
|
||||
f"Batch size: {self.delegate.tile_bs}",
|
||||
f"Tile batches: {len(self.delegate.batched_bboxes)}",
|
||||
|
|
@ -294,10 +335,17 @@ class Script(scripts.Script):
|
|||
self.flag_noise_inverse = False
|
||||
|
||||
p.latents = (p.latents - p.latents.mean()) / p.latents.std() * anchor_std + anchor_mean
|
||||
latents_ = F.pad(p.latents, (0, p.latents.shape[3]//current_scale_num*(p.scale_factor-current_scale_num), 0, p.latents.shape[2]//current_scale_num*(p.scale_factor-current_scale_num)))
|
||||
if current_scale_num==1:
|
||||
res = latents_
|
||||
else:
|
||||
res = torch.concatenate((res,latents_),axis=0)
|
||||
|
||||
#########################################################################################################################################
|
||||
p.width = p.width*p.scale_factor
|
||||
p.height = p.height*p.scale_factor
|
||||
return p.latents
|
||||
# p.width = p.width*p.scale_factor
|
||||
# p.height = p.height*p.scale_factor
|
||||
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def callback_hijack(self_sampler,d,p):
|
||||
|
|
@ -341,7 +389,7 @@ class Script(scripts.Script):
|
|||
else: raise NotImplementedError(f"Method {method} not implemented.")
|
||||
|
||||
delegate = delegate_cls(p, sampler)
|
||||
delegate.window_size = window_size
|
||||
delegate.window_size = min(min(window_size,p.width//8),p.height//8)
|
||||
p.random_jitter = random_jitter
|
||||
|
||||
if flag_noise_inverse:
|
||||
|
|
@ -363,7 +411,7 @@ class Script(scripts.Script):
|
|||
|
||||
info = ', '.join([
|
||||
f"{method.value} hooked into {name!r} sampler",
|
||||
f"Tile size: {window_size}",
|
||||
f"Tile size: {delegate.window_size}",
|
||||
f"Tile count: {delegate.num_tiles}",
|
||||
f"Batch size: {delegate.tile_bs}",
|
||||
f"Tile batches: {len(delegate.batched_bboxes)}",
|
||||
|
|
@ -484,6 +532,9 @@ class Script(scripts.Script):
|
|||
if hasattr(sd_samplers_common.Sampler, "callback_ori"):
|
||||
sd_samplers_common.Sampler.callback_state = sd_samplers_common.Sampler.callback_ori
|
||||
del sd_samplers_common.Sampler.callback_ori
|
||||
if hasattr(processing, "create_infotext_ori"):
|
||||
processing.create_infotext = processing.create_infotext_ori
|
||||
del processing.create_infotext_ori
|
||||
DemoFusion.unhook()
|
||||
self.delegate = None
|
||||
|
||||
|
|
|
|||
|
|
@ -20,18 +20,16 @@ class DemoFusion(AbstractDiffusion):
|
|||
|
||||
def hook(self):
|
||||
steps, self.t_enc = sd_samplers_common.setup_img2img_steps(self.p, None)
|
||||
# print("ENC",self.t_enc)
|
||||
|
||||
self.sampler.model_wrap_cfg.forward_ori = self.sampler.model_wrap_cfg.forward
|
||||
self.sampler.model_wrap_cfg.forward = self.forward_one_step
|
||||
self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward
|
||||
self.sampler.model_wrap_cfg.forward = self.forward_one_step
|
||||
if self.is_kdiff:
|
||||
self.sampler: KDiffusionSampler
|
||||
self.sampler.model_wrap_cfg: CFGDenoiserKDiffusion
|
||||
self.sampler.model_wrap_cfg.inner_model: Union[CompVisDenoiser, CompVisVDenoiser]
|
||||
sigmas = self.sampler.get_sigmas(self.p, steps)
|
||||
# print("SIGMAS:",sigmas)
|
||||
self.p.sigmas = sigmas[steps - self.t_enc - 1:]
|
||||
# sigmas = self.sampler.get_sigmas(self.p, steps)
|
||||
# self.p.sigmas = sigmas[steps - self.t_enc - 1:]
|
||||
else:
|
||||
self.sampler: CompVisSampler
|
||||
self.sampler.model_wrap_cfg: CFGDenoiserTimesteps
|
||||
|
|
@ -105,12 +103,13 @@ class DemoFusion(AbstractDiffusion):
|
|||
dx = (w_l - tile_w) / (cols - 1) if cols > 1 else 0
|
||||
dy = (h_l - tile_h) / (rows - 1) if rows > 1 else 0
|
||||
bbox_list: List[BBox] = []
|
||||
self.jitter_range = 0
|
||||
for row in range(rows):
|
||||
for col in range(cols):
|
||||
h = min(int(row * dy), h_l - tile_h)
|
||||
w = min(int(col * dx), w_l - tile_w)
|
||||
if self.p.random_jitter:
|
||||
self.jitter_range = min(max((min(self.w, self.h)-self.stride)//4,0),int(self.stride/2))
|
||||
self.jitter_range = min(max((min(self.w, self.h)-self.stride)//4,0),min(int(self.window_size/2),int(self.overlap/2)))
|
||||
jitter_range = self.jitter_range
|
||||
w_jitter = 0
|
||||
h_jitter = 0
|
||||
|
|
@ -141,12 +140,11 @@ class DemoFusion(AbstractDiffusion):
|
|||
|
||||
self.overlap = max(0, min(overlap, self.window_size - 4))
|
||||
|
||||
self.stride = max(1,self.window_size - self.overlap)
|
||||
self.stride = max(4,self.window_size - self.overlap)
|
||||
|
||||
# split the latent into overlapped tiles, then batching
|
||||
# weights basically indicate how many times a pixel is painted
|
||||
bboxes, _ = self.split_bboxes_jitter(self.w, self.h, self.tile_w, self.tile_h, overlap, self.get_tile_weights())
|
||||
print("BBOX:",len(bboxes))
|
||||
bboxes, _ = self.split_bboxes_jitter(self.w, self.h, self.tile_w, self.tile_h, self.overlap, self.get_tile_weights())
|
||||
self.num_tiles = len(bboxes)
|
||||
self.num_batches = math.ceil(self.num_tiles / tile_bs)
|
||||
self.tile_bs = math.ceil(len(bboxes) / self.num_batches) # optimal_batch_size
|
||||
|
|
@ -181,77 +179,39 @@ class DemoFusion(AbstractDiffusion):
|
|||
@keep_signature
|
||||
def forward_one_step(self, x_in, sigma, **kwarg):
|
||||
if self.is_kdiff:
|
||||
self.xi = self.p.x + self.p.noise * self.p.sigmas[self.p.current_step]
|
||||
x_noisy = self.p.x + self.p.noise * sigma[0]
|
||||
else:
|
||||
alphas_cumprod = self.p.sd_model.alphas_cumprod
|
||||
sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]])
|
||||
sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]])
|
||||
self.xi = self.p.x*sqrt_alpha_cumprod + self.p.noise * sqrt_one_minus_alpha_cumprod
|
||||
x_noisy = self.p.x*sqrt_alpha_cumprod + self.p.noise * sqrt_one_minus_alpha_cumprod
|
||||
|
||||
self.cosine_factor = 0.5 * (1 + torch.cos(torch.pi *torch.tensor(((self.p.current_step + 1) / (self.t_enc+1)))))
|
||||
c2 = self.cosine_factor**self.p.cosine_scale_2
|
||||
|
||||
self.c1 = self.cosine_factor ** self.p.cosine_scale_1
|
||||
c1 = self.cosine_factor ** self.p.cosine_scale_1
|
||||
|
||||
x_in_tmp = x_in*(1 - self.c1) + self.xi * self.c1
|
||||
x_in = x_in*(1 - c1) + x_noisy * c1
|
||||
|
||||
if self.p.random_jitter:
|
||||
jitter_range = self.jitter_range
|
||||
else:
|
||||
jitter_range = 0
|
||||
x_in_tmp_ = F.pad(x_in_tmp,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0)
|
||||
_,_,H,W = x_in_tmp.shape
|
||||
x_in_ = F.pad(x_in,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0)
|
||||
_,_,H,W = x_in.shape
|
||||
|
||||
std_, mean_ = x_in_tmp.std(), x_in_tmp.mean()
|
||||
c3 = 0.99 * self.cosine_factor ** self.p.cosine_scale_3 + 1e-2
|
||||
latents_gaussian = self.gaussian_filter(x_in_tmp, kernel_size=(2*self.p.current_scale_num-1), sigma=0.8*c3)
|
||||
latents_gaussian = (latents_gaussian - latents_gaussian.mean()) / latents_gaussian.std() * std_ + mean_
|
||||
self.jitter_range = jitter_range
|
||||
self.sampler.model_wrap_cfg.inner_model.forward = self.sample_one_step_local
|
||||
self.sampler.model_wrap_cfg.inner_model.forward = self.sample_one_step
|
||||
self.repeat_3 = False
|
||||
|
||||
x_local = self.sampler.model_wrap_cfg.forward_ori(x_in_tmp_,sigma, **kwarg)
|
||||
x_out = self.sampler.model_wrap_cfg.forward_ori(x_in_,sigma, **kwarg)
|
||||
self.sampler.model_wrap_cfg.inner_model.forward = self.sampler_forward
|
||||
x_local = x_local[:,:,jitter_range:jitter_range+H,jitter_range:jitter_range+W]
|
||||
x_out = x_out[:,:,jitter_range:jitter_range+H,jitter_range:jitter_range+W]
|
||||
|
||||
############################################# Dilated Sampling #############################################
|
||||
if not hasattr(self.p.sd_model, 'apply_model_ori'):
|
||||
self.p.sd_model.apply_model_ori = self.p.sd_model.apply_model
|
||||
self.p.sd_model.apply_model = self.apply_model_hijack
|
||||
x_global = torch.zeros_like(x_local)
|
||||
|
||||
for batch_id, bboxes in enumerate(self.global_batched_bboxes):
|
||||
for bbox in bboxes:
|
||||
w,h = bbox
|
||||
|
||||
######
|
||||
if self.gaussian_filter:
|
||||
x_global_i = self.sampler.model_wrap_cfg.forward_ori(latents_gaussian[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num],sigma, **kwarg)
|
||||
else:
|
||||
x_global_i = self.sampler.model_wrap_cfg.forward_ori(x_in_tmp[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num],sigma, **kwarg) # x_in_tmp could be changed to latents_gaussian
|
||||
x_global[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num] += x_global_i
|
||||
|
||||
######
|
||||
|
||||
#NOTE: Predicting Noise on Gaussian Latent and Obtaining Denoised on Original Latent
|
||||
|
||||
# self.x_out_list = []
|
||||
# self.x_out_idx = -1
|
||||
# self.flag = 1
|
||||
# self.sampler.model_wrap_cfg.forward_ori(x_in_tmp[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num],sigma,**kwarg)
|
||||
# self.flag = 0
|
||||
# x_global_i = self.sampler.model_wrap_cfg.forward_ori(x_in_tmp[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num],sigma,**kwarg)
|
||||
# x_global[:,:,h::self.p.current_scale_num,w::self.p.current_scale_num] += x_global_i
|
||||
|
||||
self.p.sd_model.apply_model = self.p.sd_model.apply_model_ori
|
||||
|
||||
x_out= x_local*(1-c2)+ x_global*c2
|
||||
return x_out
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@keep_signature
|
||||
def sample_one_step_local(self, x_in, sigma, cond):
|
||||
def sample_one_step(self, x_in, sigma, cond):
|
||||
assert LatentDiffusion.apply_model
|
||||
def repeat_func_1(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tensor:
|
||||
sigma_tile = self.repeat_tensor(sigma, len(bboxes))
|
||||
|
|
@ -282,11 +242,12 @@ class DemoFusion(AbstractDiffusion):
|
|||
repeat_func = repeat_func_2
|
||||
N,_,_,_ = x_in.shape
|
||||
|
||||
H = self.h
|
||||
W = self.w
|
||||
# H = self.h
|
||||
# W = self.w
|
||||
|
||||
self.x_buffer = torch.zeros_like(x_in)
|
||||
self.weights = torch.zeros_like(x_in)
|
||||
|
||||
for batch_id, bboxes in enumerate(self.batched_bboxes):
|
||||
if state.interrupted: return x_in
|
||||
x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0)
|
||||
|
|
@ -297,9 +258,44 @@ class DemoFusion(AbstractDiffusion):
|
|||
self.weights[bbox.slicer] += 1
|
||||
self.weights = torch.where(self.weights == 0, torch.tensor(1), self.weights) #Prevent NaN from appearing in random_jitter mode
|
||||
|
||||
x_buffer = self.x_buffer/self.weights
|
||||
x_local = self.x_buffer/self.weights
|
||||
|
||||
return x_buffer
|
||||
self.x_buffer = torch.zeros_like(self.x_buffer)
|
||||
self.weights = torch.zeros_like(self.weights)
|
||||
|
||||
std_, mean_ = x_in.std(), x_in.mean()
|
||||
c3 = 0.99 * self.cosine_factor ** self.p.cosine_scale_3 + 1e-2
|
||||
if self.p.gaussian_filter:
|
||||
x_in_g = self.gaussian_filter(x_in, kernel_size=(2*self.p.current_scale_num-1), sigma=self.sig*c3)
|
||||
x_in_g = (x_in_g - x_in_g.mean()) / x_in_g.std() * std_ + mean_
|
||||
|
||||
if not hasattr(self.p.sd_model, 'apply_model_ori'):
|
||||
self.p.sd_model.apply_model_ori = self.p.sd_model.apply_model
|
||||
self.p.sd_model.apply_model = self.apply_model_hijack
|
||||
x_global = torch.zeros_like(x_local)
|
||||
jitter_range = self.jitter_range
|
||||
end = x_global.shape[3]-jitter_range
|
||||
|
||||
for batch_id, bboxes in enumerate(self.global_batched_bboxes):
|
||||
for bbox in bboxes:
|
||||
w,h = bbox
|
||||
# self.x_out_list = []
|
||||
# self.x_out_idx = -1
|
||||
# self.flag = 1
|
||||
x_global_i0 = self.sampler_forward(x_in_g[:,:,h+jitter_range:end:self.p.current_scale_num,w+jitter_range:end:self.p.current_scale_num],sigma,cond = cond)
|
||||
# self.flag = 0
|
||||
x_global_i1 = self.sampler_forward(x_in[:,:,h+jitter_range:end:self.p.current_scale_num,w+jitter_range:end:self.p.current_scale_num],sigma,cond = cond) #NOTE According to the original execution process, it would be very strange to use the predicted noise of gaussian latents to predict the denoised data in non Gaussian latents. Why?
|
||||
self.x_buffer[:,:,h+jitter_range:end:self.p.current_scale_num,w+jitter_range:end:self.p.current_scale_num] += (x_global_i0 + x_global_i1)/2
|
||||
self.weights[:,:,h+jitter_range:end:self.p.current_scale_num,w+jitter_range:end:self.p.current_scale_num] += 1
|
||||
|
||||
self.p.sd_model.apply_model = self.p.sd_model.apply_model_ori
|
||||
self.weights = torch.where(self.weights == 0, torch.tensor(1), self.weights) #Prevent NaN from appearing in random_jitter mode
|
||||
|
||||
x_global = self.x_buffer/self.weights
|
||||
c2 = self.cosine_factor**self.p.cosine_scale_2
|
||||
self.x_buffer= x_local*(1-c2)+ x_global*c2
|
||||
|
||||
return self.x_buffer
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue