main
parent
37fdf30cf2
commit
6c039b311f
|
|
@ -31,7 +31,7 @@ def default(val, d):
|
|||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
def main_forward(module,x,context,mask,divide,isvanilla = False,userpp = False,tokens=[],width = 64,height = 64,step = 0, isxl = False, negpip = None, inhr = None):
|
||||
def main_forward(module,x,context,mask,divide,isvanilla = False,userpp = False,tokens=[],width = 64,height = 64,step = 0, is_sdxl = False, negpip = None, inhr = None):
|
||||
|
||||
# Forward.
|
||||
|
||||
|
|
@ -84,25 +84,26 @@ def main_forward(module,x,context,mask,divide,isvanilla = False,userpp = False,t
|
|||
if inhr and not hiresfinished: hiresscaler(height,width,attn,h)
|
||||
|
||||
if userpp and step > 0:
|
||||
for b in range(attn.shape[0] // 8):
|
||||
for b in range(attn.shape[0] // h):
|
||||
if pmaskshw == []:
|
||||
pmaskshw = [(height,width)]
|
||||
elif (height,width) not in pmaskshw:
|
||||
pmaskshw.append((height,width))
|
||||
|
||||
for t in tokens:
|
||||
power = 4 if isxl else 1.2
|
||||
add = attn[8*b:8*(b+1),:,t[0]:t[0]+len(t)]**power
|
||||
power = 4 if is_sdxl else 1.2
|
||||
add = attn[h*b:h*(b+1),:,t[0]:t[0]+len(t)]**power
|
||||
add = torch.sum(add,dim = 2)
|
||||
t = f"{t}-{b}"
|
||||
if t not in pmasks:
|
||||
pmasks[t] = add
|
||||
else:
|
||||
if pmasks[t].shape[1] != add.shape[1]:
|
||||
add = add.view(8,height,width)
|
||||
add = add.view(h,height,width)
|
||||
add = F.resize(add,pmaskshw[0])
|
||||
if add.numel() != pmasks[t].numel():
|
||||
add = add.view(pmasks[t].shape[0], 2, add.shape[1], add.shape[2]).sum(dim=1) / 2
|
||||
add = add.reshape_as(pmasks[t])
|
||||
|
||||
pmasks[t] = pmasks[t] + add
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
|
|
@ -112,10 +113,11 @@ def main_forward(module,x,context,mask,divide,isvanilla = False,userpp = False,t
|
|||
return out
|
||||
|
||||
def hook_forwards(self, p, remove=False):
|
||||
self.need_hook = not remove
|
||||
if forge:
|
||||
self.handle = hook_forwards_x(self, p.sd_model.forge_objects.unet.model, remove)
|
||||
hook_forwards_x(self, p.sd_model.forge_objects.unet.model, remove)
|
||||
else:
|
||||
self.handle = hook_forwards_x(self, p.sd_model.model.diffusion_model, remove)
|
||||
hook_forwards_x(self, p.sd_model.model.diffusion_model, remove)
|
||||
|
||||
def hook_forwards_x(self, root_module: torch.nn.Module, remove=False):
|
||||
self.hooked = True if not remove else False
|
||||
|
|
@ -190,7 +192,7 @@ def hook_forward(self, module):
|
|||
|
||||
i = i + 1
|
||||
|
||||
out = main_forward(module, x, context, mask, divide, self.isvanilla,userpp =True,step = self.step, isxl = self.isxl, negpip = negpip)
|
||||
out = main_forward(module, x, context, mask, divide, self.isvanilla,userpp =True,step = self.step, is_sdxl = self.is_sdxl, negpip = negpip)
|
||||
|
||||
if len(self.nt) == 1 and not pn:
|
||||
db(self,"return out for NP")
|
||||
|
|
@ -221,7 +223,7 @@ def hook_forward(self, module):
|
|||
# if i >= contexts.size()[1]:
|
||||
# indlast = True
|
||||
|
||||
out = main_forward(module, x, context, mask, divide, self.isvanilla,userpp = self.pn, step = self.step, isxl = self.isxl,negpip = negpip)
|
||||
out = main_forward(module, x, context, mask, divide, self.isvanilla,userpp = self.pn, step = self.step, is_sdxl = self.is_sdxl,negpip = negpip)
|
||||
db(self,f" dcell.breaks : {dcell.breaks}, dcell.ed : {dcell.ed}, dcell.st : {dcell.st}")
|
||||
if len(self.nt) == 1 and not pn:
|
||||
db(self,"return out for NP")
|
||||
|
|
@ -307,7 +309,7 @@ def hook_forward(self, module):
|
|||
negpip = negpipdealer(i,pn)
|
||||
|
||||
i = i + 1
|
||||
out = main_forward(module, x, context, mask, divide, self.isvanilla, isxl = self.isxl, negpip = negpip)
|
||||
out = main_forward(module, x, context, mask, divide, self.isvanilla, is_sdxl = self.is_sdxl, negpip = negpip)
|
||||
|
||||
if len(self.nt) == 1 and not pn:
|
||||
db(self,"return out for NP")
|
||||
|
|
@ -342,7 +344,7 @@ def hook_forward(self, module):
|
|||
i = i + 1
|
||||
# if i >= contexts.size()[1]:
|
||||
# indlast = True
|
||||
out = main_forward(module, x, context, mask, divide, self.isvanilla, isxl = self.isxl)
|
||||
out = main_forward(module, x, context, mask, divide, self.isvanilla, is_sdxl = self.is_sdxl)
|
||||
if len(self.nt) == 1 and not pn:
|
||||
db(self,"return out for NP")
|
||||
return out
|
||||
|
|
@ -382,7 +384,7 @@ def hook_forward(self, module):
|
|||
negpip = negpipdealer(self.condi,pn) if "La" in self.calc else negpipdealer(i,pn)
|
||||
|
||||
out = main_forward(module, x, context, mask, divide, self.isvanilla, userpp = userpp, width = dsw, height = dsh,
|
||||
tokens = self.pe, step = self.step, isxl = self.isxl, negpip = negpip, inhr = self.in_hr)
|
||||
tokens = self.pe, step = self.step, is_sdxl = self.is_sdxl, negpip = negpip, inhr = self.in_hr)
|
||||
|
||||
if (len(self.nt) == 1 and not pn) or ("Pro" in self.mode and "La" in self.calc):
|
||||
db(self,"return out for NP or Latent")
|
||||
|
|
@ -469,7 +471,7 @@ def hook_forward(self, module):
|
|||
|
||||
self.count += 1
|
||||
|
||||
limit = 70 if self.isxl else 16
|
||||
limit = 70 if self.is_sdxl else 16
|
||||
|
||||
if self.count == limit:
|
||||
self.pn = not self.pn
|
||||
|
|
@ -548,7 +550,7 @@ def reset_pmasks(self): # init parameters in every batch
|
|||
|
||||
def savepmasks(self,processed):
|
||||
for mask ,th in zip(pmasks.values(),self.th):
|
||||
img, _ , _= makepmask(mask, self.h, self.w,th, self.step)
|
||||
img, _ , _= makepmask(mask, self.h, self.w,th, self.step, self.total_step, self.is_sdxl)
|
||||
processed.images.append(img)
|
||||
return processed
|
||||
|
||||
|
|
@ -580,11 +582,11 @@ def hiresmask(masks,oh,ow,nh,nw,head,at = None,i = None):
|
|||
else:
|
||||
masks[key][i] = mask
|
||||
|
||||
def makepmask(mask, h, w, th, step, bratio = 1): # make masks from attention cache return [for preview, for attention, for Latent]
|
||||
th = th - step * 0.005
|
||||
def makepmask(mask, h, w, th, step, total_step, is_sdxl, bratio = 1): # make masks from attention cache return [for preview, for attention, for Latent]
|
||||
th = th - step * 0.005
|
||||
bratio = 1 - bratio
|
||||
mask = torch.mean(mask,dim=0)
|
||||
mask = mask / mask.max().item()
|
||||
mask = mask / mask.max().item() * 4 if is_sdxl else 1
|
||||
mask = torch.where(mask > th ,1,0)
|
||||
mask = mask.float()
|
||||
mask = mask.view(1,pmaskshw[0][0],pmaskshw[0][1])
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ def denoiser_callback_s(self, params: CFGDenoiserParams):
|
|||
if self.x == None : cloneparams(params,self) # return to step 0 if mask is ready
|
||||
self.pfirst = True
|
||||
|
||||
lim = 1 if self.isxl else 3
|
||||
lim = 1 if self.is_sdxl else 3
|
||||
|
||||
if len(att.pmaskshw) > lim:
|
||||
self.filters = []
|
||||
|
|
@ -109,7 +109,7 @@ def denoiser_callback_s(self, params: CFGDenoiserParams):
|
|||
basemask = None
|
||||
for t, th, bratio in zip(self.pe, self.th, self.bratios):
|
||||
key = f"{t}-{b}"
|
||||
_, _, mask = att.makepmask(att.pmasks[key], params.x.shape[2], params.x.shape[3], th, self.step, bratio = bratio)
|
||||
_, _, mask = att.makepmask(att.pmasks[key], params.x.shape[2], params.x.shape[3], th, self.step,self.total_step, self.is_sdxl,bratio = bratio)
|
||||
mask = mask.repeat(params.x.shape[1],1,1)
|
||||
basemask = 1 - mask if basemask is None else basemask - mask
|
||||
if self.ex:
|
||||
|
|
@ -132,7 +132,7 @@ def denoiser_callback_s(self, params: CFGDenoiserParams):
|
|||
masks = None
|
||||
for b in range(self.batch_size):
|
||||
key = f"{t}-{b}"
|
||||
_, mask, _ = att.makepmask(att.pmasks[key], hw[0], hw[1], th, self.step, bratio = bratio)
|
||||
_, mask, _ = att.makepmask(att.pmasks[key], hw[0], hw[1], th, self.step,self.total_step, self.is_sdxl,bratio = bratio)
|
||||
mask = mask.unsqueeze(0).unsqueeze(-1)
|
||||
masks = mask if b ==0 else torch.cat((masks,mask),dim=0)
|
||||
allmask.append(mask)
|
||||
|
|
|
|||
|
|
@ -202,7 +202,7 @@ def compress_components(l):
|
|||
|
||||
class Script(modules.scripts.Script):
|
||||
def __init__(self,active = False,mode = "Matrix",calc = "Attention",h = 0, w =0, debug = False, debug2 = False, usebase = False,
|
||||
usecom = False, usencom = False, batch = 1,isxl = False, lstop=0, lstop_hr=0, diff = None):
|
||||
usecom = False, usencom = False, batch = 1,lstop=0, lstop_hr=0, diff = None):
|
||||
self.active = active
|
||||
if mode == "Columns": mode = "Horizontal"
|
||||
if mode == "Rows": mode = "Vertical"
|
||||
|
|
@ -216,8 +216,13 @@ class Script(modules.scripts.Script):
|
|||
self.usecom = usecom
|
||||
self.usencom = usencom
|
||||
self.batch_size = batch
|
||||
self.isxl = isxl
|
||||
|
||||
model = shared.sd_model
|
||||
self.is_sdxl = type(model).__name__ == "StableDiffusionXL" or getattr(model,'is_sdxl', False)
|
||||
self.is_sd2 = type(model).__name__ == "StableDiffusion2" or getattr(model,'is_sd2', False)
|
||||
self.is_sd1 = type(model).__name__ == "StableDiffusion" or getattr(model,'is_sd1', False)
|
||||
self.is_flux = type(model).__name__ == "Flux" or getattr(model,'is_flux', False)
|
||||
|
||||
self.aratios = []
|
||||
self.bratios = []
|
||||
self.divide = 0
|
||||
|
|
@ -243,6 +248,8 @@ class Script(modules.scripts.Script):
|
|||
#for prompt region
|
||||
self.pe = []
|
||||
self.step = 0
|
||||
|
||||
self.need_hook = False
|
||||
|
||||
#for Differential
|
||||
self.diff = diff
|
||||
|
|
@ -253,7 +260,7 @@ class Script(modules.scripts.Script):
|
|||
self.condi = 0
|
||||
|
||||
self.used_prompt = ""
|
||||
self.logprops = ["active","mode","usebase","usecom","usencom","batch_size","isxl","h","w","aratios",
|
||||
self.logprops = ["active","mode","usebase","usecom","usencom","batch_size","is_sdxl","h","w","aratios",
|
||||
"divide","count","eq","pn","hr","pe","step","diff","used_prompt"]
|
||||
self.log = {}
|
||||
|
||||
|
|
@ -490,7 +497,7 @@ class Script(modules.scripts.Script):
|
|||
if flipper:aratios = changecs(aratios)
|
||||
|
||||
self.__init__(active, tabs2mode(rp_selected_tab, mmode, xmode, pmode) ,calcmode ,p.height, p.width, debug, debug2,
|
||||
usebase, usecom, usencom, p.batch_size, hasattr(shared.sd_model,"conditioner"),lstop, lstop_hr, diff = diff)
|
||||
usebase, usecom, usencom, p.batch_size, lstop, lstop_hr, diff = diff)
|
||||
|
||||
self.all_prompts = p.all_prompts.copy()
|
||||
self.all_negative_prompts = p.all_negative_prompts.copy()
|
||||
|
|
@ -536,14 +543,14 @@ class Script(modules.scripts.Script):
|
|||
|
||||
##### calcmode
|
||||
if "Att" in calcmode:
|
||||
self.handle = hook_forwards(self, p)
|
||||
hook_forwards(self, p)
|
||||
if hasattr(shared.opts,"batch_cond_uncond"):
|
||||
shared.opts.batch_cond_uncond = orig_batch_cond_uncond
|
||||
else:
|
||||
shared.batch_cond_uncond = orig_batch_cond_uncond
|
||||
unloadlorafowards(p)
|
||||
else:
|
||||
self.handle = hook_forwards(self, p, remove = True)
|
||||
hook_forwards(self, p, remove = True)
|
||||
setuploras(self)
|
||||
# SBM It is vital to use local activation because callback registration is permanent,
|
||||
# and there are multiple script instances (txt2img / img2img).
|
||||
|
|
@ -551,7 +558,7 @@ class Script(modules.scripts.Script):
|
|||
elif "Pro" in self.mode: #Prompt mode use both calcmode
|
||||
self.ex = "Ex" in self.mode
|
||||
if not usebase : bratios = "0"
|
||||
self.handle = hook_forwards(self, p)
|
||||
hook_forwards(self, p)
|
||||
denoiserdealer(self, p)
|
||||
|
||||
if OPTCOUT in options: commentouter(p)
|
||||
|
|
@ -578,6 +585,10 @@ class Script(modules.scripts.Script):
|
|||
self.current_prompts = kwargs["prompts"].copy()
|
||||
p.disable_extra_networks = False
|
||||
|
||||
def process_before_every_sampling(self, p, *args, **kwargs):
|
||||
if self.active and forge and self.need_hook:
|
||||
hook_forwards(self, p)
|
||||
|
||||
def before_hr(self, p, active, _, rp_selected_tab, mmode, xmode, pmode, aratios, bratios,
|
||||
usebase, usecom, usencom, calcmode,nchangeand, lnter, lnur, threshold, polymask,lstop, lstop_hr, flipper):
|
||||
if self.active:
|
||||
|
|
@ -644,10 +655,9 @@ class Script(modules.scripts.Script):
|
|||
denoised_callback_s(self, params)
|
||||
|
||||
def unloader(self,p):
|
||||
if hasattr(self,"handle"):
|
||||
if self.hooked:
|
||||
#print("unloaded")
|
||||
hook_forwards(self, p, remove=True)
|
||||
del self.handle
|
||||
|
||||
self.__init__()
|
||||
|
||||
|
|
@ -783,7 +793,7 @@ def tokendealer(self, p):
|
|||
tokenizer = p.sd_model.text_processing_engine.tokenize_line
|
||||
self.flux = flux = "flux" in str(type(p.sd_model.forge_objects.unet.model.diffusion_model))
|
||||
else:
|
||||
tokenizer = shared.sd_model.conditioner.embedders[0].tokenize_line if self.isxl else shared.sd_model.cond_stage_model.tokenize_line
|
||||
tokenizer = shared.sd_model.conditioner.embedders[0].tokenize_line if self.is_sdxl else shared.sd_model.cond_stage_model.tokenize_line
|
||||
self.flux = flux = False
|
||||
|
||||
for pp in ppl:
|
||||
|
|
@ -1160,7 +1170,7 @@ def debugall(self):
|
|||
print(f"tokens : {self.ppt},{self.pnt},{self.pt},{self.nt}")
|
||||
print(f"ratios : {self.aratios}\n")
|
||||
print(f"prompt : {self.pe}")
|
||||
print(f"env : before15:{self.isbefore15},isxl:{self.isxl}")
|
||||
print(f"env : before15:{self.isbefore15},isxl:{self.is_sdxl}")
|
||||
print(f"loras{self.log}")
|
||||
|
||||
def bckeydealer(self, p):
|
||||
|
|
|
|||
Loading…
Reference in New Issue