hako-mikan 2025-06-23 19:06:41 +09:00
parent 6a8a69280b
commit be62345220
3 changed files with 32 additions and 10 deletions

View File

@ -130,6 +130,10 @@ def hook_forwards_x(self, root_module: torch.nn.Module, remove=False):
def hook_forward(self, module):
def forward(x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0, value = None, transformer_options=None):
pndealer(self,context)
if self.hr_returner():
return main_forward(module, x, context, mask, x.shape[0] // self.batch_size, self.isvanilla,userpp =True,step = self.step, is_sdxl = self.is_sdxl)
if self.debug and self.count == 0:
print("\ninput : ", x.size())
print("tokens : ", context.size())

View File

@ -94,13 +94,13 @@ def denoiser_callback_s(self, params: CFGDenoiserParams):
self.pn = self.pn_s
if hasattr(params,"text_cond"):
if hasattr(params,"text_cond") and params.text_cond is not None:
if "DictWithShape" in params.text_cond.__class__.__name__:
self.cshape = params.text_cond[list(params.text_cond.keys())[0]].shape[1]
else:
self.cshape = params.text_cond.shape[1]
if hasattr(params,"text_uncond"):
if hasattr(params,"text_uncond") and params.text_uncond is not None:
if "DictWithShape" in params.text_uncond.__class__.__name__:
self.ucshape = params.text_uncond[list(params.text_uncond.keys())[0]].shape[1]
else:
@ -109,6 +109,9 @@ def denoiser_callback_s(self, params: CFGDenoiserParams):
if self.only_r and not self.diff:
return
if self.hr_returner():
return
if "Pro" in self.mode: # in Prompt mode, make masks from sum of attention maps
if self.x == None : cloneparams(params,self) # return to step 0 if mask is ready
@ -270,7 +273,7 @@ def denoised_callback_s(p1, p2 = None, p3 = None):
for b in range(batch):
for a in range(areas) :
fil = self.filters[a + b*areas]
fil = self.filters[a + b*areas] if not self.hr_returner() else 1
if self.debug : print(f"x = {x.size()}i = {a + b*areas}, j = {b + a*batch}, cond = {a + b*areas},filsum = {fil if type(fil) is int else torch.sum(fil)}, uncon = {x.size()[0]+(b-batch)}")
x[a + b * areas, :, :, :] = xt[b + a*batch, :, :, :] * fil + x[x.size()[0]+(b-batch), :, :, :] * (1 - fil)
@ -739,7 +742,7 @@ def changethedevice(module):
if hasattr(module, 'bias') and module.bias != None:
module.bias = torch.nn.Parameter(module.bias.to(devices.device, dtype=torch.float))
def unloadlorafowards(p):
def unloadlorafowards(self):
global orig_Linear_forward, lactive, labug
lactive = labug = False
@ -751,7 +754,7 @@ def unloadlorafowards(p):
import lora
if forge:
from backend.args import dynamic_args
dynamic_args["online_lora"] = False
dynamic_args["online_lora"] = self.orig_online_lora
else:
emb_db = sd_hijack.model_hijack.embedding_db
for net in lora.loaded_loras:

View File

@ -41,8 +41,10 @@ OPTUSEL = "Use LoHa or other"
OPTBREAK = "Use BREAK to change chunks"
OPTFLIP = "Flip prompts"
OPTCOUT = "Comment Out `#`"
OPTAHIRES = "Enabled only in Hires Fix"
OPTDHIRES = "Disabled in Hires Fix"
OPTIONLIST = [OPTAND,OPTUSEL,OPTBREAK,OPTFLIP,OPTCOUT,"debug", "debug2"]
OPTIONLIST = [OPTAND,OPTUSEL,OPTBREAK,OPTFLIP,OPTCOUT,OPTAHIRES,OPTDHIRES,"debug", "debug2"]
# Modules.basedir points to extension's dir. script_path or scripts.basedir points to root.
PTPRESET = modules.scripts.basedir()
@ -256,6 +258,8 @@ class Script(modules.scripts.Script):
self.in_hr = False
self.xsize = 0
self.imgcount = 0
self.hiresacts = [False, False]
self.orig_online_lora = False
# for latent mode
self.filters = []
self.lora_applied = False
@ -502,6 +506,10 @@ class Script(modules.scripts.Script):
if p.threshold is not None:threshold = str(p.threshold)
else:
diff = False
if forge:
from backend.args import dynamic_args
self.orig_online_lora = dynamic_args["online_lora"]
if not any(key in tprompt for key in ALLALLKEYS) or not active:
return unloader(self,p)
@ -538,6 +546,7 @@ class Script(modules.scripts.Script):
self.all_prompts = p.all_prompts.copy()
self.all_negative_prompts = p.all_negative_prompts.copy()
self.optbreak = OPTBREAK in options
self.hiresacts = [OPTAHIRES in options, OPTDHIRES in options]
# SBM ddim / plms detection.
self.isvanilla = p.sampler_name in ["DDIM", "PLMS", "UniPC"]
@ -583,7 +592,7 @@ class Script(modules.scripts.Script):
shared.opts.batch_cond_uncond = orig_batch_cond_uncond
else:
shared.batch_cond_uncond = orig_batch_cond_uncond
unloadlorafowards(p)
unloadlorafowards(self)
denoiserdealer(self, True)
else:
hook_forwards(self, p, remove = True)
@ -689,6 +698,14 @@ class Script(modules.scripts.Script):
def denoised_callback(self, params):
    # Script-level hook: forwards the "denoised" event to the module-level
    # implementation denoised_callback_s, passing this script instance along.
    denoised_callback_s(self, params)
def hr_returner(self):
    """Report whether processing should be skipped for the current pass.

    ``self.hiresacts`` holds two option flags:
    index 0 — "Enabled only in Hires Fix" (skip while NOT in the hires pass),
    index 1 — "Disabled in Hires Fix" (skip while IN the hires pass).
    ``self.in_hr`` selects which flag applies right now.
    """
    # During the hires pass the "disabled in hires" flag decides; outside
    # of it the "enabled only in hires" flag does.
    return self.hiresacts[1] if self.in_hr else self.hiresacts[0]
def unloader(self,p):
if self.hooked:
@ -702,7 +719,7 @@ def unloader(self,p):
else:
shared.batch_cond_uncond = orig_batch_cond_uncond
unloadlorafowards(p)
unloadlorafowards(self)
def denoiserdealer(self, only_r):
if not hasattr(self,"dr_callbacks"):
@ -832,10 +849,8 @@ def tokendealer(self, p):
tokenizer = p.sd_model.text_processing_engine_l.tokenize_line
else:
tokenizer = p.sd_model.text_processing_engine.tokenize_line
self.flux = flux = "flux" in str(type(p.sd_model.forge_objects.unet.model.diffusion_model))
else:
tokenizer = shared.sd_model.conditioner.embedders[0].tokenize_line if self.is_sdxl else shared.sd_model.cond_stage_model.tokenize_line
self.flux = flux = False
for pp in ppl:
tokens, tokensnum = tokenizer(pp)