fix for forge differential
parent
d2c8fce3c2
commit
92880a60bc
|
|
@ -379,7 +379,7 @@ def hook_forward(self, module):
|
|||
db(self,f"tokens3 : {tl[0]*TOKENSCON}-{tl[1]*TOKENSCON}")
|
||||
db(self,f"extra-tokens : {cnet_ext}")
|
||||
|
||||
userpp = self.pn and i == 0 and self.pfirst
|
||||
userpp = pn and i == 0
|
||||
|
||||
negpip = negpipdealer(self.condi,pn) if "La" in self.calc else negpipdealer(i,pn)
|
||||
|
||||
|
|
@ -476,7 +476,6 @@ def hook_forward(self, module):
|
|||
if self.count == limit:
|
||||
self.pn = not self.pn
|
||||
self.count = 0
|
||||
self.pfirst = False
|
||||
self.condi += 1
|
||||
db(self,f"output : {ox.size()}")
|
||||
return ox
|
||||
|
|
|
|||
|
|
@ -97,19 +97,17 @@ def denoiser_callback_s(self, params: CFGDenoiserParams):
|
|||
|
||||
self.pn = self.pn_s
|
||||
|
||||
if self.only_r:
|
||||
if self.only_r and not self.diff:
|
||||
return
|
||||
|
||||
if "Pro" in self.mode: # in Prompt mode, make masks from sum of attension maps
|
||||
if self.x == None : cloneparams(params,self) # return to step 0 if mask is ready
|
||||
self.pfirst = True
|
||||
|
||||
lim = 1 if self.is_sdxl else 3
|
||||
|
||||
if len(att.pmaskshw) > lim:
|
||||
self.filters = []
|
||||
for b in range(self.batch_size):
|
||||
|
||||
allmask = []
|
||||
basemask = None
|
||||
for t, th, bratio in zip(self.pe, self.th, self.bratios):
|
||||
|
|
@ -199,7 +197,7 @@ def denoised_callback_s(p1, p2 = None, p3 = None):
|
|||
conds = c["c_crossattn"]
|
||||
y = c["y"] if "y" in c else None
|
||||
|
||||
if not lactive:
|
||||
if not lactive and not self.diff:
|
||||
return p1(input_x, timestep, **c)
|
||||
|
||||
length = len(cond_or_uncond)
|
||||
|
|
@ -231,6 +229,8 @@ def denoised_callback_s(p1, p2 = None, p3 = None):
|
|||
xt = params.x.clone()
|
||||
areas = xt.shape[0] // batch - 1
|
||||
|
||||
#print(f"Denoised call back: calc {self.calc}, mode {self.mode}, rps {self.rps}, diff {self.diff}, filters {len(self.filters)} ")
|
||||
|
||||
if "La" in self.calc:
|
||||
# x.shape = [batch_size, C, H // 8, W // 8]
|
||||
|
||||
|
|
@ -269,10 +269,10 @@ def denoised_callback_s(p1, p2 = None, p3 = None):
|
|||
if self.rps is not None and self.diff:
|
||||
if self.rps.latent is None:
|
||||
self.rps.latent = x.clone()
|
||||
return
|
||||
return x[orig_list] if p3 is not None else None
|
||||
elif self.rps.latent.shape[2:] != x.shape[2:] and self.rps.latent_hr is None:
|
||||
self.rps.latent_hr = x.clone()
|
||||
return
|
||||
return x[orig_list] if p3 is not None else None
|
||||
else:
|
||||
for b in range(batch):
|
||||
for a in range(areas) :
|
||||
|
|
@ -287,7 +287,7 @@ def denoised_callback_s(p1, p2 = None, p3 = None):
|
|||
if self.rps is not None and self.diff:
|
||||
if self.rps.latent is not None:
|
||||
if self.rps.latent.shape[2:] != x.shape[2:]:
|
||||
if self.rps.latent_hr is None: return
|
||||
if self.rps.latent_hr is None: return x[orig_list] if p3 is not None else None
|
||||
for b in range(batch):
|
||||
for a in range(areas) :
|
||||
fil = self.filters[a+1]
|
||||
|
|
@ -299,7 +299,7 @@ def denoised_callback_s(p1, p2 = None, p3 = None):
|
|||
if self.step == 0 and self.in_hr:
|
||||
if self.rps is not None and self.diff:
|
||||
if self.rps.latent is not None:
|
||||
if self.rps.latent.shape[2:] != x.shape[2:] and self.rps.latent_hr is None: return
|
||||
if self.rps.latent.shape[2:] != x.shape[2:] and self.rps.latent_hr is None: return x[orig_list] if p3 is not None else None
|
||||
for b in range(batch):
|
||||
for a in range(areas) :
|
||||
fil = self.filters[a+1]
|
||||
|
|
@ -602,6 +602,7 @@ class LoRARegioner:
|
|||
refresh(lora_lorader, lora_patches=lora_patches, offload_device=offload_device)
|
||||
|
||||
def set_region(self, region):
|
||||
if self.u_llist == [{}]: return
|
||||
self.mlist = self.u_llist[region]
|
||||
if labug:
|
||||
print(f"Set LoRA for Region {region}, u_count",self.u_count ,"u_count '%' divide", self.u_count % len(self.u_llist))
|
||||
|
|
|
|||
|
|
@ -510,6 +510,7 @@ class Script(modules.scripts.Script):
|
|||
if forge or reforge:
|
||||
self.isvanilla = not self.isvanilla
|
||||
self.pn = False
|
||||
self.pn_s = False
|
||||
|
||||
if self.h % ATTNSCALE != 0 or self.w % ATTNSCALE != 0:
|
||||
# Testing shows a round down occurs in model.
|
||||
|
|
@ -690,7 +691,8 @@ def denoiserdealer(self, only_r):
|
|||
if self.diff:
|
||||
if not hasattr(self,"dd_callbacks"):
|
||||
self.dd_callbacks = on_cfg_denoised(self.denoised_callback)
|
||||
|
||||
if forge or reforge:
|
||||
shared.sd_model.forge_objects.unet.set_model_unet_function_wrapper(lambda apply, params: denoised_callback_s(apply, params, p3=self))
|
||||
|
||||
############################################################
|
||||
##### prompts, tokens
|
||||
|
|
|
|||
Loading…
Reference in New Issue