fix for forge differential

main
hako-mikan 2025-02-06 00:02:45 +09:00
parent d2c8fce3c2
commit 92880a60bc
3 changed files with 13 additions and 11 deletions

View File

@ -379,7 +379,7 @@ def hook_forward(self, module):
db(self,f"tokens3 : {tl[0]*TOKENSCON}-{tl[1]*TOKENSCON}")
db(self,f"extra-tokens : {cnet_ext}")
userpp = self.pn and i == 0 and self.pfirst
userpp = pn and i == 0
negpip = negpipdealer(self.condi,pn) if "La" in self.calc else negpipdealer(i,pn)
@ -476,7 +476,6 @@ def hook_forward(self, module):
if self.count == limit:
self.pn = not self.pn
self.count = 0
self.pfirst = False
self.condi += 1
db(self,f"output : {ox.size()}")
return ox

View File

@ -97,19 +97,17 @@ def denoiser_callback_s(self, params: CFGDenoiserParams):
self.pn = self.pn_s
if self.only_r:
if self.only_r and not self.diff:
return
if "Pro" in self.mode: # in Prompt mode, make masks from sum of attension maps
if self.x == None : cloneparams(params,self) # return to step 0 if mask is ready
self.pfirst = True
lim = 1 if self.is_sdxl else 3
if len(att.pmaskshw) > lim:
self.filters = []
for b in range(self.batch_size):
allmask = []
basemask = None
for t, th, bratio in zip(self.pe, self.th, self.bratios):
@ -199,7 +197,7 @@ def denoised_callback_s(p1, p2 = None, p3 = None):
conds = c["c_crossattn"]
y = c["y"] if "y" in c else None
if not lactive:
if not lactive and not self.diff:
return p1(input_x, timestep, **c)
length = len(cond_or_uncond)
@ -231,6 +229,8 @@ def denoised_callback_s(p1, p2 = None, p3 = None):
xt = params.x.clone()
areas = xt.shape[0] // batch - 1
#print(f"Denoised call back: calc {self.calc}, mode {self.mode}, rps {self.rps}, diff {self.diff}, filters {len(self.filters)} ")
if "La" in self.calc:
# x.shape = [batch_size, C, H // 8, W // 8]
@ -269,10 +269,10 @@ def denoised_callback_s(p1, p2 = None, p3 = None):
if self.rps is not None and self.diff:
if self.rps.latent is None:
self.rps.latent = x.clone()
return
return x[orig_list] if p3 is not None else None
elif self.rps.latent.shape[2:] != x.shape[2:] and self.rps.latent_hr is None:
self.rps.latent_hr = x.clone()
return
return x[orig_list] if p3 is not None else None
else:
for b in range(batch):
for a in range(areas) :
@ -287,7 +287,7 @@ def denoised_callback_s(p1, p2 = None, p3 = None):
if self.rps is not None and self.diff:
if self.rps.latent is not None:
if self.rps.latent.shape[2:] != x.shape[2:]:
if self.rps.latent_hr is None: return
if self.rps.latent_hr is None: return x[orig_list] if p3 is not None else None
for b in range(batch):
for a in range(areas) :
fil = self.filters[a+1]
@ -299,7 +299,7 @@ def denoised_callback_s(p1, p2 = None, p3 = None):
if self.step == 0 and self.in_hr:
if self.rps is not None and self.diff:
if self.rps.latent is not None:
if self.rps.latent.shape[2:] != x.shape[2:] and self.rps.latent_hr is None: return
if self.rps.latent.shape[2:] != x.shape[2:] and self.rps.latent_hr is None: return x[orig_list] if p3 is not None else None
for b in range(batch):
for a in range(areas) :
fil = self.filters[a+1]
@ -602,6 +602,7 @@ class LoRARegioner:
refresh(lora_lorader, lora_patches=lora_patches, offload_device=offload_device)
def set_region(self, region):
if self.u_llist == [{}]: return
self.mlist = self.u_llist[region]
if labug:
print(f"Set LoRA for Region {region}, u_count",self.u_count ,"u_count '%' divide", self.u_count % len(self.u_llist))

View File

@ -510,6 +510,7 @@ class Script(modules.scripts.Script):
if forge or reforge:
self.isvanilla = not self.isvanilla
self.pn = False
self.pn_s = False
if self.h % ATTNSCALE != 0 or self.w % ATTNSCALE != 0:
# Testing shows a round down occurs in model.
@ -690,7 +691,8 @@ def denoiserdealer(self, only_r):
if self.diff:
if not hasattr(self,"dd_callbacks"):
self.dd_callbacks = on_cfg_denoised(self.denoised_callback)
if forge or reforge:
shared.sd_model.forge_objects.unet.set_model_unet_function_wrapper(lambda apply, params: denoised_callback_s(apply, params, p3=self))
############################################################
##### prompts, tokens