Merge pull request #376 from Jaylen-Lee/main

Improving efficiency for demo& fix disable error
pull/396/head
Kahsolt 2024-03-30 16:08:58 +08:00 committed by GitHub
commit d29227bf70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 97 additions and 63 deletions

View File

@ -195,23 +195,25 @@ The extension enables **large image drawing & upscaling with limited VRAM** via
If you set the image size to 512 * 512, the appropriate window size and overlap are 64 and 32 or smaller. If it is 1024, it is recommended to double it, and so on.
Recommend using a higher denoising strength in img2img, maybe 0.8-1, and try to use the original model, seeds, and prompt as much as possible
Recommend using a higher denoising strength in img2img, and try to use the original model, seeds, and prompt as much as possible
Do not enable it together with tilediffusion. It supports operations such as tilevae, noise inversion, etc.
Due to differences in implementation details, parameters such as c1, c2, c3 and sigma can be set with reference to [demofusion](https://ruoyidu.github.io/demofusion/demofusion.html), but they may not always be fully effective. If the images come out blurred, it is recommended to increase c3 and reduce Sigma.
There is a slight difference in the results of Mixture mode, but the inference time of UNet will increase by about 50%. It is not recommended to enable it under normal circumstances.
![demo-example](https://github.com/Jaylen-Lee/image-demo/blob/main/example.png?raw=true)
#### Example: txt2img, 1024 * 1024 image 3x upscale
- Params
- Step=45, sampler=Euler, same prompt as official demofusion, random seed 35, model SDXL1.0
- Step=45, sampler=Euler, same prompt as official demofusion, model SDXL1.0
- Denoising Strength for Substage = 0.85, Sigma = 0.5, use default values for the rest, enable tilevae
- Sigma = 0.5, use default values for the rest, enable tilevae
- Device 4060ti 16GB
- Device: 4060ti 16GB
- The following are the images obtained at a resolution of 1024, 2048, and 3072
![demo-result](https://github.com/Jaylen-Lee/image-demo/blob/main/3.png?raw=true)

View File

@ -190,20 +190,22 @@
如果你设定的生图大小是512*512,合适的window size和overlap是64和32或者更小。如果是1024,则推荐翻倍,以此类推
img2img推荐使用较高的重绘幅度,比如0.8-1,并尽可能地使用原图的生图模型、随机种子以及prompt等
img2img请尽可能地使用较高的重绘幅度以及原图的生图模型、随机种子以及prompt等
不要同时开启tilediffusion. 但该组件支持tilevae、noise inversion等常用功能
由于实现细节的差异,c1、c2、c3等参数可以参考[demofusion](https://ruoyidu.github.io/demofusion/demofusion.html),但不一定完全奏效. 如果出现模糊图像,建议提高c3并降低Sigma
Mixture mode结果略有差异,但unet推理时间会额外增加约50%,正常情况下不建议开启
![demo-example](https://github.com/Jaylen-Lee/image-demo/blob/main/example.png?raw=true)
#### 示例:txt2img,1024*1024图片,3倍放大
- 参数:
- 步数 = 45,采样器 = Euler,与demofusion示例同样的提示语,随机种子35,模型为SDXL1.0
- 子阶段降噪 = 0.85,Sigma=0.5,其余采用默认值,开启tilevae
- 设备 4060ti 16GB
- 步数 = 45,采样器 = Euler,与demofusion示例同样的提示语,模型为SDXL1.0
- Sigma=0.5,其余采用默认值,开启tilevae
- 设备:4060ti 16GB
- 以下为获得的1024、2048、3072分辨率的图片
![demo-result](https://github.com/Jaylen-Lee/image-demo/blob/main/3.png?raw=true)

View File

@ -58,8 +58,10 @@ class Script(scripts.Script):
with gr.Row(variant='compact') as tab_enable:
enabled = gr.Checkbox(label='Enable DemoFusion(Dont open with tilediffusion)', value=False, elem_id=uid('enabled'))
random_jitter = gr.Checkbox(label='Random Jitter', value = True, elem_id=uid('random-jitter'))
gaussian_filter = gr.Checkbox(label='Gaussian Filter', value=True, visible=False, elem_id=uid('gaussian'))
keep_input_size = gr.Checkbox(label='Keep input-image size', value=False,visible=is_img2img, elem_id=uid('keep-input-size'))
mixture_mode = gr.Checkbox(label='Mixture mode', value=False,elem_id=uid('mixture-mode'))
gaussian_filter = gr.Checkbox(label='Gaussian Filter', value=True, visible=False, elem_id=uid('gaussian'))
with gr.Row(variant='compact') as tab_param:
@ -75,11 +77,12 @@ class Script(scripts.Script):
with gr.Row(variant='compact'):
overlap = gr.Slider(minimum=0, maximum=256, step=4, label='Latent window overlap', value=64, elem_id=uid('latent-tile-overlap'))
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Latent window batch size', value=4, elem_id=uid('latent-tile-batch-size'))
batch_size_g = gr.Slider(minimum=1, maximum=8, step=1, label='Global window batch size', value=4, elem_id=uid('Global-tile-batch-size'))
with gr.Row(variant='compact', visible=True) as tab_c:
c1 = gr.Slider(minimum=0, maximum=5, step=0.01, label='Cosine Scale 1', value=3, elem_id=f'C1-{tab}')
c2 = gr.Slider(minimum=0, maximum=5, step=0.01, label='Cosine Scale 2', value=1, elem_id=f'C2-{tab}')
c3 = gr.Slider(minimum=0, maximum=5, step=0.01, label='Cosine Scale 3', value=1, elem_id=f'C3-{tab}')
sigma = gr.Slider(minimum=0, maximum=1, step=0.01, label='Sigma', value=0.5, elem_id=f'Sigma-{tab}')
sigma = gr.Slider(minimum=0, maximum=2, step=0.01, label='Sigma', value=0.6, elem_id=f'Sigma-{tab}')
with gr.Group() as tab_denoise:
strength = gr.Slider(minimum=0, maximum=1, step=0.01, value = 0.85,label='Denoising Strength for Substage',visible=not is_img2img, elem_id=f'strength-{tab}')
with gr.Row(variant='compact') as tab_upscale:
@ -106,7 +109,7 @@ class Script(scripts.Script):
noise_inverse, noise_inverse_steps, noise_inverse_retouch, noise_inverse_renoise_strength, noise_inverse_renoise_kernel,
control_tensor_cpu,
random_jitter,
c1,c2,c3,gaussian_filter,strength,sigma
c1,c2,c3,gaussian_filter,strength,sigma,batch_size_g,mixture_mode
]
@ -118,12 +121,14 @@ class Script(scripts.Script):
noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch: float, noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int,
control_tensor_cpu: bool,
random_jitter:bool,
c1,c2,c3,gaussian_filter,strength,sigma
c1,c2,c3,gaussian_filter,strength,sigma,batch_size_g,mixture_mode
):
# unhijack & unhook, in case it broke at last time
self.reset()
p.mixture = mixture_mode
if not mixture_mode:
sigma = sigma/2
if not enabled: return
''' upscale '''
@ -162,6 +167,7 @@ class Script(scripts.Script):
info['Window Size'] = window_size
info['Tile Overlap'] = overlap
info['Tile batch size'] = tile_batch_size
info["Global batch size"] = batch_size_g
if is_img2img:
info['Upscale factor'] = scale_factor
@ -199,19 +205,19 @@ class Script(scripts.Script):
sd_samplers.create_sampler = lambda name, model: self.create_sampler_hijack(
name, model, p, Method_2(method), control_tensor_cpu,window_size, noise_inverse, noise_inverse_steps, noise_inverse_retouch,
noise_inverse_renoise_strength, noise_inverse_renoise_kernel, overlap, tile_batch_size,random_jitter
noise_inverse_renoise_strength, noise_inverse_renoise_kernel, overlap, tile_batch_size,random_jitter,batch_size_g
)
p.sample = lambda conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts: self.sample_hijack(
conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts,p, is_img2img,
window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma)
window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma,batch_size_g)
processing.create_infotext_ori = processing.create_infotext
p.width_list = [p.height]
p.height_list = [p.height]
processing.create_infotext = create_infotext_hijack
## end
@ -255,7 +261,7 @@ class Script(scripts.Script):
''' ↓↓↓ inner API hijack ↓↓↓ '''
@torch.no_grad()
def sample_hijack(self, conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts,p,image_ori,window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma):
def sample_hijack(self, conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts,p,image_ori,window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma,batch_size_g):
################################################## Phase Initialization ######################################################
if not image_ori:
@ -303,18 +309,19 @@ class Script(scripts.Script):
self.delegate.w = int(p.current_width / opt_f)
self.delegate.h = int(p.current_height / opt_f)
if current_scale_num>1:
self.delegate.get_views(overlap, tile_batch_size)
self.delegate.get_views(overlap, tile_batch_size,batch_size_g)
info = ', '.join([
# f"{method.value} hooked into {name!r} sampler",
f"Tile size: {self.delegate.window_size}",
f"Tile count: {self.delegate.num_tiles}",
f"Batch size: {self.delegate.tile_bs}",
f"Tile batches: {len(self.delegate.batched_bboxes)}",
])
info = ', '.join([
# f"{method.value} hooked into {name!r} sampler",
f"Tile size: {self.delegate.window_size}",
f"Tile count: {self.delegate.num_tiles}",
f"Batch size: {self.delegate.tile_bs}",
f"Tile batches: {len(self.delegate.batched_bboxes)}",
f"Global batch size: {self.delegate.global_tile_bs}",
f"Global batches: {len(self.delegate.global_batched_bboxes)}",
])
print(info)
print(info)
noise = p.rng.next()
if hasattr(p,'initial_noise_multiplier'):
@ -343,7 +350,7 @@ class Script(scripts.Script):
#########################################################################################################################################
return res
@staticmethod
def callback_hijack(self_sampler,d,p):
p.current_step = d['i']
@ -358,7 +365,7 @@ class Script(scripts.Script):
def create_sampler_hijack(
self, name: str, model: LatentDiffusion, p: Processing, method: Method_2, control_tensor_cpu:bool,window_size, noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch:float,
noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int, overlap:int, tile_batch_size:int, random_jitter:bool
noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int, overlap:int, tile_batch_size:int, random_jitter:bool,batch_size_g:int
):
if self.delegate is not None:
# samplers are stateless, we reuse it if possible
@ -394,7 +401,7 @@ class Script(scripts.Script):
set_cache_callback = lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, noise_inverse_steps, noise_inverse_retouch)
delegate.init_noise_inverse(noise_inverse_steps, noise_inverse_retouch, get_cache_callback, set_cache_callback, noise_inverse_renoise_strength, noise_inverse_renoise_kernel)
delegate.get_views(overlap,tile_batch_size)
# delegate.get_views(overlap,tile_batch_size,batch_size_g)
if self.controlnet_script:
delegate.init_controlnet(self.controlnet_script, control_tensor_cpu)
if self.stablesr_script:
@ -406,20 +413,13 @@ class Script(scripts.Script):
self.delegate = delegate
info = ', '.join([
f"{method.value} hooked into {name!r} sampler",
f"Tile size: {delegate.window_size}",
f"Tile count: {delegate.num_tiles}",
f"Batch size: {delegate.tile_bs}",
f"Tile batches: {len(delegate.batched_bboxes)}",
])
exts = [
"ContrlNet" if self.controlnet_script else None,
"StableSR" if self.stablesr_script else None,
]
ext_info = ', '.join([e for e in exts if e])
if ext_info: ext_info = f' (ext: {ext_info})'
print(info + ext_info)
print(ext_info)
return delegate.sampler_raw

View File

@ -57,7 +57,7 @@ class DemoFusion(AbstractDiffusion):
shape = [n] + [1] * r_dims # [N, 1, ...]
return x.repeat(shape)
def repeat_cond_dict(self, cond_in:CondDict, bboxes:List[CustomBBox]) -> CondDict:
def repeat_cond_dict(self, cond_in:CondDict, bboxes,mode) -> CondDict:
''' repeat all tensors in cond_dict on it's first dim (for a batch of tiles), returns a new object '''
# n_repeat
n_rep = len(bboxes)
@ -67,9 +67,16 @@ class DemoFusion(AbstractDiffusion):
# img cond
icond = self.get_icond(cond_in)
if icond.shape[2:] == (self.h, self.w): # img2img, [B=1, C, H, W]
icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0)
if mode == 0:
if self.p.random_jitter:
jitter_range = self.jitter_range
icond = F.pad(icond,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0)
icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0)
else:
icond = torch.cat([icond[:,:,bbox[1]::self.p.current_scale_num,bbox[0]::self.p.current_scale_num] for bbox in bboxes], dim=0)
else: # txt2img, [B=1, C=5, H=1, W=1]
icond = self.repeat_tensor(icond, n_rep)
# vec cond (SDXL)
vcond = self.get_vcond(cond_in) # [B=1, D]
if vcond is not None:
@ -89,7 +96,7 @@ class DemoFusion(AbstractDiffusion):
bbox = (x, y)
bbox_list.append(bbox)
return bbox_list
return bbox_list+bbox_list if self.p.mixture else bbox_list
def split_bboxes_jitter(self,w_l:int, h_l:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]:
cols = math.ceil((w_l - overlap) / (tile_w - overlap))
@ -131,7 +138,7 @@ class DemoFusion(AbstractDiffusion):
return bbox_list, None
@grid_bbox
def get_views(self, overlap:int, tile_bs:int):
def get_views(self, overlap:int, tile_bs:int,tile_bs_g:int):
self.enable_grid_bbox = True
self.tile_w = self.window_size
self.tile_h = self.window_size
@ -150,7 +157,7 @@ class DemoFusion(AbstractDiffusion):
global_bboxes = self.global_split_bboxes()
self.global_num_tiles = len(global_bboxes)
self.global_num_batches = math.ceil(self.global_num_tiles / tile_bs)
self.global_num_batches = math.ceil(self.global_num_tiles / tile_bs_g)
self.global_tile_bs = math.ceil(len(global_bboxes) / self.global_num_batches)
self.global_batched_bboxes = [global_bboxes[i*self.global_tile_bs:(i+1)*self.global_tile_bs] for i in range(self.global_num_batches)]
@ -169,7 +176,7 @@ class DemoFusion(AbstractDiffusion):
blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
return blurred_latents
''' ↓↓↓ kernel hijacks ↓↓↓ '''
@ -211,23 +218,23 @@ class DemoFusion(AbstractDiffusion):
@keep_signature
def sample_one_step(self, x_in, sigma, cond):
assert LatentDiffusion.apply_model
def repeat_func_1(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tensor:
def repeat_func_1(x_tile:Tensor, bboxes,mode=0) -> Tensor:
sigma_tile = self.repeat_tensor(sigma, len(bboxes))
cond_tile = self.repeat_cond_dict(cond, bboxes)
cond_tile = self.repeat_cond_dict(cond, bboxes,mode)
return self.sampler_forward(x_tile, sigma_tile, cond=cond_tile)
def repeat_func_2(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tuple[Tensor, Tensor]:
def repeat_func_2(x_tile:Tensor, bboxes,mode=0) -> Tuple[Tensor, Tensor]:
n_rep = len(bboxes)
ts_tile = self.repeat_tensor(sigma, n_rep)
if isinstance(cond, dict): # FIXME: when will enter this branch?
cond_tile = self.repeat_cond_dict(cond, bboxes)
cond_tile = self.repeat_cond_dict(cond, bboxes,mode)
else:
cond_tile = self.repeat_tensor(cond, n_rep)
return self.sampler_forward(x_tile, ts_tile, cond=cond_tile)
def repeat_func_3(x_tile:Tensor, bboxes:List[CustomBBox]):
def repeat_func_3(x_tile:Tensor, bboxes,mode=0):
sigma_in_tile = sigma.repeat(len(bboxes))
cond_out = self.repeat_cond_dict(cond, bboxes)
cond_out = self.repeat_cond_dict(cond, bboxes,mode)
x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=cond_out)
return x_tile_out
@ -256,7 +263,7 @@ class DemoFusion(AbstractDiffusion):
x_local = self.x_buffer/self.weights
self.x_buffer = torch.zeros_like(self.x_buffer)
self.x_buffer = torch.zeros_like(self.x_buffer)
self.weights = torch.zeros_like(self.weights)
std_, mean_ = x_in.std(), x_in.mean()
@ -272,20 +279,43 @@ class DemoFusion(AbstractDiffusion):
jitter_range = self.jitter_range
end = x_global.shape[3]-jitter_range
for batch_id, bboxes in enumerate(self.global_batched_bboxes):
for bbox in bboxes:
w,h = bbox
# self.x_out_list = []
# self.x_out_idx = -1
# self.flag = 1
x_global_i0 = self.sampler_forward(x_in_g[:,:,h+jitter_range:end:self.p.current_scale_num,w+jitter_range:end:self.p.current_scale_num],sigma,cond = cond)
# self.flag = 0
x_global_i1 = self.sampler_forward(x_in[:,:,h+jitter_range:end:self.p.current_scale_num,w+jitter_range:end:self.p.current_scale_num],sigma,cond = cond) #NOTE According to the original execution process, it would be very strange to use the predicted noise of gaussian latents to predict the denoised data in non Gaussian latents. Why?
self.x_buffer[:,:,h+jitter_range:end:self.p.current_scale_num,w+jitter_range:end:self.p.current_scale_num] += (x_global_i0 + x_global_i1)/2
self.weights[:,:,h+jitter_range:end:self.p.current_scale_num,w+jitter_range:end:self.p.current_scale_num] += 1
current_num = 0
if self.p.mixture:
for batch_id, bboxes in enumerate(self.global_batched_bboxes):
current_num += len(bboxes)
if current_num > (self.global_num_tiles//2) and (current_num-self.global_tile_bs) < (self.global_num_tiles//2):
res = len(bboxes) - (current_num - self.global_num_tiles//2)
x_in_i = torch.cat([x_in[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] if idx<res else x_in_g[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] for bbox in bboxes],dim=0)
elif current_num > (self.global_num_tiles//2):
x_in_i = torch.cat([x_in_g[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] for bbox in bboxes],dim=0)
else:
x_in_i = torch.cat([x_in[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] for bbox in bboxes],dim=0)
x_global_i = repeat_func(x_in_i,bboxes,mode=1)
if current_num > (self.global_num_tiles//2) and (current_num-self.global_tile_bs) < (self.global_num_tiles//2):
for idx,bbox in enumerate(bboxes):
x_global[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] += x_global_i[idx*N:(idx+1)*N,:,:,:]
elif current_num > (self.global_num_tiles//2):
for idx,bbox in enumerate(bboxes):
x_global[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] += x_global_i[idx*N:(idx+1)*N,:,:,:]
else:
for idx,bbox in enumerate(bboxes):
x_global[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] += x_global_i[idx*N:(idx+1)*N,:,:,:]
else:
for batch_id, bboxes in enumerate(self.global_batched_bboxes):
x_in_i = torch.cat([x_in_g[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] for bbox in bboxes],dim=0)
x_global_i = repeat_func(x_in_i,bboxes,mode=1)
for idx,bbox in enumerate(bboxes):
x_global[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] += x_global_i[idx*N:(idx+1)*N,:,:,:]
#NOTE According to the original execution process, it would be very strange to use the predicted noise of gaussian latents to predict the denoised data in non Gaussian latents. Why?
if self.p.mixture:
self.x_buffer +=x_global/2
else:
self.x_buffer += x_global
self.weights += 1
self.p.sd_model.apply_model = self.p.sd_model.apply_model_ori
self.weights = torch.where(self.weights == 0, torch.tensor(1), self.weights) #Prevent NaN from appearing in random_jitter mode
x_global = self.x_buffer/self.weights
c2 = self.cosine_factor**self.p.cosine_scale_2