hako-mikan 2025-01-30 19:06:44 +09:00
parent 37fdf30cf2
commit 6c039b311f
3 changed files with 44 additions and 32 deletions

View File

@@ -31,7 +31,7 @@ def default(val, d):
         return val
     return d() if isfunction(d) else d
 
-def main_forward(module,x,context,mask,divide,isvanilla = False,userpp = False,tokens=[],width = 64,height = 64,step = 0, isxl = False, negpip = None, inhr = None):
+def main_forward(module,x,context,mask,divide,isvanilla = False,userpp = False,tokens=[],width = 64,height = 64,step = 0, is_sdxl = False, negpip = None, inhr = None):
     # Forward.
@@ -84,25 +84,26 @@ def main_forward(module,x,context,mask,divide,isvanilla = False,userpp = False,t
         if inhr and not hiresfinished: hiresscaler(height,width,attn,h)
         if userpp and step > 0:
-            for b in range(attn.shape[0] // 8):
+            for b in range(attn.shape[0] // h):
                 if pmaskshw == []:
                     pmaskshw = [(height,width)]
                 elif (height,width) not in pmaskshw:
                     pmaskshw.append((height,width))
                 for t in tokens:
-                    power = 4 if isxl else 1.2
-                    add = attn[8*b:8*(b+1),:,t[0]:t[0]+len(t)]**power
+                    power = 4 if is_sdxl else 1.2
+                    add = attn[h*b:h*(b+1),:,t[0]:t[0]+len(t)]**power
                     add = torch.sum(add,dim = 2)
                     t = f"{t}-{b}"
                     if t not in pmasks:
                         pmasks[t] = add
                     else:
                         if pmasks[t].shape[1] != add.shape[1]:
-                            add = add.view(8,height,width)
+                            add = add.view(h,height,width)
                             add = F.resize(add,pmaskshw[0])
+                            if add.numel() != pmasks[t].numel():
+                                add = add.view(pmasks[t].shape[0], 2, add.shape[1], add.shape[2]).sum(dim=1) / 2
                             add = add.reshape_as(pmasks[t])
                         pmasks[t] = pmasks[t] + add
 
         out = einsum('b i j, b j d -> b i d', attn, v)
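Note on the hunk above: the hardcoded head count 8 is replaced by `h`, so per-token attention maps are sliced by the model's actual number of attention heads (SDXL blocks do not all use 8). A minimal sketch of that head-aware slicing, assuming illustrative shapes and names (`token_heatmaps`, the `(batch * heads, query, key)` layout) that are not the extension's real tensors:

import torch

def token_heatmaps(attn: torch.Tensor, heads: int, token_span: tuple[int, int]):
    """Sum attention over a token span, per sample, keeping all heads."""
    bh, q, _ = attn.shape
    batch = bh // heads                      # works for any head count, not just 8
    start, end = token_span
    maps = []
    for b in range(batch):
        # slice out this sample's heads, then the key columns for the token span
        block = attn[heads * b : heads * (b + 1), :, start:end]
        maps.append(block.sum(dim=2))        # (heads, q) heat per query position
    return maps

heat = token_heatmaps(torch.rand(16, 64, 77), heads=8, token_span=(5, 8))
print(heat[0].shape)  # torch.Size([8, 64])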
@@ -112,10 +113,11 @@ def main_forward(module,x,context,mask,divide,isvanilla = False,userpp = False,t
     return out
 
 def hook_forwards(self, p, remove=False):
+    self.need_hook = not remove
     if forge:
-        self.handle = hook_forwards_x(self, p.sd_model.forge_objects.unet.model, remove)
+        hook_forwards_x(self, p.sd_model.forge_objects.unet.model, remove)
     else:
-        self.handle = hook_forwards_x(self, p.sd_model.model.diffusion_model, remove)
+        hook_forwards_x(self, p.sd_model.model.diffusion_model, remove)
 
 def hook_forwards_x(self, root_module: torch.nn.Module, remove=False):
     self.hooked = True if not remove else False
@@ -190,7 +192,7 @@ def hook_forward(self, module):
             i = i + 1
 
-        out = main_forward(module, x, context, mask, divide, self.isvanilla, userpp = True, step = self.step, isxl = self.isxl, negpip = negpip)
+        out = main_forward(module, x, context, mask, divide, self.isvanilla, userpp = True, step = self.step, is_sdxl = self.is_sdxl, negpip = negpip)
 
         if len(self.nt) == 1 and not pn:
             db(self,"return out for NP")
@@ -221,7 +223,7 @@ def hook_forward(self, module):
                 # if i >= contexts.size()[1]:
                 #     indlast = True
-                out = main_forward(module, x, context, mask, divide, self.isvanilla, userpp = self.pn, step = self.step, isxl = self.isxl, negpip = negpip)
+                out = main_forward(module, x, context, mask, divide, self.isvanilla, userpp = self.pn, step = self.step, is_sdxl = self.is_sdxl, negpip = negpip)
                 db(self,f" dcell.breaks : {dcell.breaks}, dcell.ed : {dcell.ed}, dcell.st : {dcell.st}")
             if len(self.nt) == 1 and not pn:
                 db(self,"return out for NP")
@@ -307,7 +309,7 @@ def hook_forward(self, module):
             negpip = negpipdealer(i,pn)
             i = i + 1
-            out = main_forward(module, x, context, mask, divide, self.isvanilla, isxl = self.isxl, negpip = negpip)
+            out = main_forward(module, x, context, mask, divide, self.isvanilla, is_sdxl = self.is_sdxl, negpip = negpip)
 
         if len(self.nt) == 1 and not pn:
             db(self,"return out for NP")
@@ -342,7 +344,7 @@ def hook_forward(self, module):
             i = i + 1
             # if i >= contexts.size()[1]:
             #     indlast = True
-            out = main_forward(module, x, context, mask, divide, self.isvanilla, isxl = self.isxl)
+            out = main_forward(module, x, context, mask, divide, self.isvanilla, is_sdxl = self.is_sdxl)
         if len(self.nt) == 1 and not pn:
             db(self,"return out for NP")
             return out
@@ -382,7 +384,7 @@ def hook_forward(self, module):
             negpip = negpipdealer(self.condi,pn) if "La" in self.calc else negpipdealer(i,pn)
 
             out = main_forward(module, x, context, mask, divide, self.isvanilla, userpp = userpp, width = dsw, height = dsh,
-                               tokens = self.pe, step = self.step, isxl = self.isxl, negpip = negpip, inhr = self.in_hr)
+                               tokens = self.pe, step = self.step, is_sdxl = self.is_sdxl, negpip = negpip, inhr = self.in_hr)
 
         if (len(self.nt) == 1 and not pn) or ("Pro" in self.mode and "La" in self.calc):
             db(self,"return out for NP or Latent")
@@ -469,7 +471,7 @@ def hook_forward(self, module):
         self.count += 1
 
-        limit = 70 if self.isxl else 16
+        limit = 70 if self.is_sdxl else 16
         if self.count == limit:
             self.pn = not self.pn
@@ -548,7 +550,7 @@ def reset_pmasks(self): # init parameters in every batch
 
 def savepmasks(self,processed):
     for mask, th in zip(pmasks.values(),self.th):
-        img, _ , _ = makepmask(mask, self.h, self.w, th, self.step)
+        img, _ , _ = makepmask(mask, self.h, self.w, th, self.step, self.total_step, self.is_sdxl)
         processed.images.append(img)
     return processed
@@ -580,11 +582,11 @@ def hiresmask(masks,oh,ow,nh,nw,head,at = None,i = None):
         else:
             masks[key][i] = mask
 
-def makepmask(mask, h, w, th, step, bratio = 1): # make masks from attention cache return [for preview, for attention, for Latent]
-    th = th - step * 0.005
+def makepmask(mask, h, w, th, step, total_step, is_sdxl, bratio = 1): # make masks from attention cache return [for preview, for attention, for Latent]
+    th = th - step * 0.005
     bratio = 1 - bratio
     mask = torch.mean(mask,dim=0)
-    mask = mask / mask.max().item()
+    mask = mask / mask.max().item() * (4 if is_sdxl else 1)
     mask = torch.where(mask > th ,1,0)
     mask = mask.float()
     mask = mask.view(1,pmaskshw[0][0],pmaskshw[0][1])
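The `makepmask` change above threads `total_step` and `is_sdxl` down to mask creation and boosts the normalized attention for SDXL, whose maps come out flatter. A hedged sketch of the thresholding idea: the gain of 4 and the 0.005-per-step threshold decay mirror the diff, while the names and shapes (`binarize_attention`, a `(heads, h*w)` cache) are purely illustrative:

import torch

def binarize_attention(mask: torch.Tensor, th: float, step: int, is_sdxl: bool):
    """mask: (heads, h*w) accumulated attention for one token/batch slot."""
    th = th - step * 0.005                   # loosen the threshold as sampling proceeds
    m = mask.mean(dim=0)                     # average over heads -> (h*w,)
    m = m / m.max().item()                   # normalize to [0, 1]
    m = m * (4 if is_sdxl else 1)            # SDXL maps are flatter; boost before cutoff
    return (m > th).float()                  # 1 inside the region, 0 outside

region = binarize_attention(torch.rand(8, 64 * 64), th=0.35, step=10, is_sdxl=True)
print(region.mean())  # fraction of latent pixels assigned to the region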

View File

@@ -99,7 +99,7 @@ def denoiser_callback_s(self, params: CFGDenoiserParams):
             if self.x == None : cloneparams(params,self) # return to step 0 if mask is ready
             self.pfirst = True
 
-        lim = 1 if self.isxl else 3
+        lim = 1 if self.is_sdxl else 3
         if len(att.pmaskshw) > lim:
             self.filters = []
@@ -109,7 +109,7 @@ def denoiser_callback_s(self, params: CFGDenoiserParams):
                 basemask = None
                 for t, th, bratio in zip(self.pe, self.th, self.bratios):
                     key = f"{t}-{b}"
-                    _, _, mask = att.makepmask(att.pmasks[key], params.x.shape[2], params.x.shape[3], th, self.step, bratio = bratio)
+                    _, _, mask = att.makepmask(att.pmasks[key], params.x.shape[2], params.x.shape[3], th, self.step, self.total_step, self.is_sdxl, bratio = bratio)
                     mask = mask.repeat(params.x.shape[1],1,1)
                     basemask = 1 - mask if basemask is None else basemask - mask
                 if self.ex:
@@ -132,7 +132,7 @@ def denoiser_callback_s(self, params: CFGDenoiserParams):
                 masks = None
                 for b in range(self.batch_size):
                     key = f"{t}-{b}"
-                    _, mask, _ = att.makepmask(att.pmasks[key], hw[0], hw[1], th, self.step, bratio = bratio)
+                    _, mask, _ = att.makepmask(att.pmasks[key], hw[0], hw[1], th, self.step, self.total_step, self.is_sdxl, bratio = bratio)
                     mask = mask.unsqueeze(0).unsqueeze(-1)
                     masks = mask if b == 0 else torch.cat((masks,mask),dim=0)
                     allmask.append(mask)
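Both call sites above pass the new `total_step`/`is_sdxl` arguments and then subtract each region mask from a base mask, so the base prompt keeps whatever no region claims. A rough sketch of that composition with invented shapes and names (`compose_base_mask`, the `clamp` guard against overlapping regions are assumptions, not the extension's code):

import torch

def compose_base_mask(region_masks: list[torch.Tensor]) -> torch.Tensor:
    """region_masks: list of (C, H, W) binary masks, one per prompt region."""
    basemask = None
    for mask in region_masks:
        # start from 1 (everything is base) and carve out each region
        basemask = 1 - mask if basemask is None else basemask - mask
    return basemask.clamp(min=0)             # overlapping regions would go negative

masks = [torch.zeros(4, 8, 8) for _ in range(2)]
masks[0][:, :4] = 1                          # region 0 claims the top half
print(compose_base_mask(masks)[0])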

View File

@@ -202,7 +202,7 @@ def compress_components(l):
 class Script(modules.scripts.Script):
     def __init__(self,active = False,mode = "Matrix",calc = "Attention",h = 0, w = 0, debug = False, debug2 = False, usebase = False,
-                 usecom = False, usencom = False, batch = 1, isxl = False, lstop=0, lstop_hr=0, diff = None):
+                 usecom = False, usencom = False, batch = 1, lstop=0, lstop_hr=0, diff = None):
         self.active = active
         if mode == "Columns": mode = "Horizontal"
         if mode == "Rows": mode = "Vertical"
@@ -216,8 +216,13 @@ class Script(modules.scripts.Script):
         self.usecom = usecom
         self.usencom = usencom
         self.batch_size = batch
-        self.isxl = isxl
+        model = shared.sd_model
+        self.is_sdxl = type(model).__name__ == "StableDiffusionXL" or getattr(model, 'is_sdxl', False)
+        self.is_sd2 = type(model).__name__ == "StableDiffusion2" or getattr(model, 'is_sd2', False)
+        self.is_sd1 = type(model).__name__ == "StableDiffusion" or getattr(model, 'is_sd1', False)
+        self.is_flux = type(model).__name__ == "Flux" or getattr(model, 'is_flux', False)
         self.aratios = []
         self.bratios = []
         self.divide = 0
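The constructor no longer receives an `isxl` flag from the caller; it detects the model family at runtime from `shared.sd_model`, matching either a backend-specific class name or a boolean capability attribute. A sketch of that detection pattern, assuming a stand-in model object (`detect_family` and the toy class are hypothetical; the real code checks `shared.sd_model` the same two ways):

def detect_family(model) -> str:
    checks = [
        ("sdxl", "StableDiffusionXL", "is_sdxl"),
        ("sd2",  "StableDiffusion2",  "is_sd2"),
        ("sd1",  "StableDiffusion",   "is_sd1"),
        ("flux", "Flux",              "is_flux"),
    ]
    for family, class_name, flag in checks:
        # exact class-name match covers backends with distinct model classes;
        # getattr covers backends that expose boolean capability flags
        if type(model).__name__ == class_name or getattr(model, flag, False):
            return family
    return "unknown"

class StableDiffusionXL: pass
print(detect_family(StableDiffusionXL()))  # sdxl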
@@ -243,6 +248,8 @@ class Script(modules.scripts.Script):
         #for prompt region
         self.pe = []
         self.step = 0
+        self.need_hook = False
 
         #for Differential
         self.diff = diff
@@ -253,7 +260,7 @@ class Script(modules.scripts.Script):
         self.condi = 0
         self.used_prompt = ""
 
-        self.logprops = ["active","mode","usebase","usecom","usencom","batch_size","isxl","h","w","aratios",
+        self.logprops = ["active","mode","usebase","usecom","usencom","batch_size","is_sdxl","h","w","aratios",
                          "divide","count","eq","pn","hr","pe","step","diff","used_prompt"]
         self.log = {}
@@ -490,7 +497,7 @@ class Script(modules.scripts.Script):
         if flipper: aratios = changecs(aratios)
 
         self.__init__(active, tabs2mode(rp_selected_tab, mmode, xmode, pmode), calcmode, p.height, p.width, debug, debug2,
-                      usebase, usecom, usencom, p.batch_size, hasattr(shared.sd_model,"conditioner"), lstop, lstop_hr, diff = diff)
+                      usebase, usecom, usencom, p.batch_size, lstop, lstop_hr, diff = diff)
 
         self.all_prompts = p.all_prompts.copy()
         self.all_negative_prompts = p.all_negative_prompts.copy()
@@ -536,14 +543,14 @@ class Script(modules.scripts.Script):
         ##### calcmode
         if "Att" in calcmode:
-            self.handle = hook_forwards(self, p)
+            hook_forwards(self, p)
             if hasattr(shared.opts,"batch_cond_uncond"):
                 shared.opts.batch_cond_uncond = orig_batch_cond_uncond
             else:
                 shared.batch_cond_uncond = orig_batch_cond_uncond
             unloadlorafowards(p)
         else:
-            self.handle = hook_forwards(self, p, remove = True)
+            hook_forwards(self, p, remove = True)
             setuploras(self)
            # SBM It is vital to use local activation because callback registration is permanent,
            # and there are multiple script instances (txt2img / img2img).
@@ -551,7 +558,7 @@ class Script(modules.scripts.Script):
         elif "Pro" in self.mode: #Prompt mode use both calcmode
             self.ex = "Ex" in self.mode
             if not usebase : bratios = "0"
-            self.handle = hook_forwards(self, p)
+            hook_forwards(self, p)
             denoiserdealer(self, p)
 
         if OPTCOUT in options: commentouter(p)
@@ -578,6 +585,10 @@ class Script(modules.scripts.Script):
         self.current_prompts = kwargs["prompts"].copy()
         p.disable_extra_networks = False
 
+    def process_before_every_sampling(self, p, *args, **kwargs):
+        if self.active and forge and self.need_hook:
+            hook_forwards(self, p)
+
     def before_hr(self, p, active, _, rp_selected_tab, mmode, xmode, pmode, aratios, bratios,
                   usebase, usecom, usencom, calcmode, nchangeand, lnter, lnur, threshold, polymask, lstop, lstop_hr, flipper):
         if self.active:
@@ -644,10 +655,9 @@ class Script(modules.scripts.Script):
             denoised_callback_s(self, params)
 
     def unloader(self,p):
-        if hasattr(self,"handle"):
+        if self.hooked:
             #print("unloaded")
             hook_forwards(self, p, remove=True)
-            del self.handle
 
         self.__init__()
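Taken together with the `hook_forwards` hunk in the first file, the stored `self.handle` gives way to two flags: `hooked` records whether forwards are currently patched, and `need_hook` records whether a hook is still wanted, so Forge (which rebuilds the UNet before each sampling and drops earlier patches) can re-hook in `process_before_every_sampling`. A toy sketch of that flag-based lifecycle; the method names mirror the diff, but the `Hooker` class and its bodies are invented:

class Hooker:
    def __init__(self):
        self.hooked = False       # set while forwards are actually patched
        self.need_hook = False    # set while a (re)hook is wanted before sampling

    def hook_forwards(self, remove=False):
        self.need_hook = not remove
        self.hooked = not remove  # patch / unpatch module forwards here

    def process_before_every_sampling(self):
        # Forge rebuilds the UNet each sampling, dropping earlier patches,
        # so re-apply whenever the flag says a hook is still wanted.
        if self.need_hook:
            self.hook_forwards()

    def unloader(self):
        if self.hooked:           # no handle object left to delete
            self.hook_forwards(remove=True)

h = Hooker()
h.hook_forwards()
h.process_before_every_sampling()
h.unloader()
print(h.hooked, h.need_hook)  # False False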
@@ -783,7 +793,7 @@ def tokendealer(self, p):
         tokenizer = p.sd_model.text_processing_engine.tokenize_line
         self.flux = flux = "flux" in str(type(p.sd_model.forge_objects.unet.model.diffusion_model))
     else:
-        tokenizer = shared.sd_model.conditioner.embedders[0].tokenize_line if self.isxl else shared.sd_model.cond_stage_model.tokenize_line
+        tokenizer = shared.sd_model.conditioner.embedders[0].tokenize_line if self.is_sdxl else shared.sd_model.cond_stage_model.tokenize_line
         self.flux = flux = False
 
     for pp in ppl:
@@ -1160,7 +1170,7 @@ def debugall(self):
     print(f"tokens : {self.ppt},{self.pnt},{self.pt},{self.nt}")
     print(f"ratios : {self.aratios}\n")
     print(f"prompt : {self.pe}")
-    print(f"env : before15:{self.isbefore15},isxl:{self.isxl}")
+    print(f"env : before15:{self.isbefore15},is_sdxl:{self.is_sdxl}")
     print(f"loras{self.log}")
 
 def bckeydealer(self, p):