main
parent
6a8a69280b
commit
be62345220
|
|
@ -130,6 +130,10 @@ def hook_forwards_x(self, root_module: torch.nn.Module, remove=False):
|
|||
def hook_forward(self, module):
|
||||
def forward(x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0, value = None, transformer_options=None):
|
||||
pndealer(self,context)
|
||||
|
||||
if self.hr_returner():
|
||||
return main_forward(module, x, context, mask, x.shape[0] // self.batch_size, self.isvanilla,userpp =True,step = self.step, is_sdxl = self.is_sdxl)
|
||||
|
||||
if self.debug and self.count == 0:
|
||||
print("\ninput : ", x.size())
|
||||
print("tokens : ", context.size())
|
||||
|
|
|
|||
|
|
@ -94,13 +94,13 @@ def denoiser_callback_s(self, params: CFGDenoiserParams):
|
|||
|
||||
self.pn = self.pn_s
|
||||
|
||||
if hasattr(params,"text_cond"):
|
||||
if hasattr(params,"text_cond") and params.text_cond is not None:
|
||||
if "DictWithShape" in params.text_cond.__class__.__name__:
|
||||
self.cshape = params.text_cond[list(params.text_cond.keys())[0]].shape[1]
|
||||
else:
|
||||
self.cshape = params.text_cond.shape[1]
|
||||
|
||||
if hasattr(params,"text_uncond"):
|
||||
if hasattr(params,"text_uncond") and params.text_uncond is not None:
|
||||
if "DictWithShape" in params.text_uncond.__class__.__name__:
|
||||
self.ucshape = params.text_uncond[list(params.text_uncond.keys())[0]].shape[1]
|
||||
else:
|
||||
|
|
@ -109,6 +109,9 @@ def denoiser_callback_s(self, params: CFGDenoiserParams):
|
|||
if self.only_r and not self.diff:
|
||||
return
|
||||
|
||||
if self.hr_returner():
|
||||
return
|
||||
|
||||
if "Pro" in self.mode: # in Prompt mode, make masks from sum of attension maps
|
||||
if self.x == None : cloneparams(params,self) # return to step 0 if mask is ready
|
||||
|
||||
|
|
@ -270,7 +273,7 @@ def denoised_callback_s(p1, p2 = None, p3 = None):
|
|||
|
||||
for b in range(batch):
|
||||
for a in range(areas) :
|
||||
fil = self.filters[a + b*areas]
|
||||
fil = self.filters[a + b*areas] if not self.hr_returner() else 1
|
||||
if self.debug : print(f"x = {x.size()}i = {a + b*areas}, j = {b + a*batch}, cond = {a + b*areas},filsum = {fil if type(fil) is int else torch.sum(fil)}, uncon = {x.size()[0]+(b-batch)}")
|
||||
x[a + b * areas, :, :, :] = xt[b + a*batch, :, :, :] * fil + x[x.size()[0]+(b-batch), :, :, :] * (1 - fil)
|
||||
|
||||
|
|
@ -739,7 +742,7 @@ def changethedevice(module):
|
|||
if hasattr(module, 'bias') and module.bias != None:
|
||||
module.bias = torch.nn.Parameter(module.bias.to(devices.device, dtype=torch.float))
|
||||
|
||||
def unloadlorafowards(p):
|
||||
def unloadlorafowards(self):
|
||||
global orig_Linear_forward, lactive, labug
|
||||
lactive = labug = False
|
||||
|
||||
|
|
@ -751,7 +754,7 @@ def unloadlorafowards(p):
|
|||
import lora
|
||||
if forge:
|
||||
from backend.args import dynamic_args
|
||||
dynamic_args["online_lora"] = False
|
||||
dynamic_args["online_lora"] = self.orig_online_lora
|
||||
else:
|
||||
emb_db = sd_hijack.model_hijack.embedding_db
|
||||
for net in lora.loaded_loras:
|
||||
|
|
|
|||
|
|
@ -41,8 +41,10 @@ OPTUSEL = "Use LoHa or other"
|
|||
OPTBREAK = "Use BREAK to change chunks"
|
||||
OPTFLIP = "Flip prompts"
|
||||
OPTCOUT = "Comment Out `#`"
|
||||
OPTAHIRES = "Enabled only in Hires Fix"
|
||||
OPTDHIRES = "Disabled in Hires Fix"
|
||||
|
||||
OPTIONLIST = [OPTAND,OPTUSEL,OPTBREAK,OPTFLIP,OPTCOUT,"debug", "debug2"]
|
||||
OPTIONLIST = [OPTAND,OPTUSEL,OPTBREAK,OPTFLIP,OPTCOUT,OPTAHIRES,OPTDHIRES,"debug", "debug2"]
|
||||
|
||||
# Modules.basedir points to extension's dir. script_path or scripts.basedir points to root.
|
||||
PTPRESET = modules.scripts.basedir()
|
||||
|
|
@ -256,6 +258,8 @@ class Script(modules.scripts.Script):
|
|||
self.in_hr = False
|
||||
self.xsize = 0
|
||||
self.imgcount = 0
|
||||
self.hiresacts = [False, False]
|
||||
self.orig_online_lora = False
|
||||
# for latent mode
|
||||
self.filters = []
|
||||
self.lora_applied = False
|
||||
|
|
@ -502,6 +506,10 @@ class Script(modules.scripts.Script):
|
|||
if p.threshold is not None:threshold = str(p.threshold)
|
||||
else:
|
||||
diff = False
|
||||
|
||||
if forge:
|
||||
from backend.args import dynamic_args
|
||||
self.orig_online_lora = dynamic_args["online_lora"]
|
||||
|
||||
if not any(key in tprompt for key in ALLALLKEYS) or not active:
|
||||
return unloader(self,p)
|
||||
|
|
@ -538,6 +546,7 @@ class Script(modules.scripts.Script):
|
|||
self.all_prompts = p.all_prompts.copy()
|
||||
self.all_negative_prompts = p.all_negative_prompts.copy()
|
||||
self.optbreak = OPTBREAK in options
|
||||
self.hiresacts = [OPTAHIRES in options, OPTDHIRES in options]
|
||||
|
||||
# SBM ddim / plms detection.
|
||||
self.isvanilla = p.sampler_name in ["DDIM", "PLMS", "UniPC"]
|
||||
|
|
@ -583,7 +592,7 @@ class Script(modules.scripts.Script):
|
|||
shared.opts.batch_cond_uncond = orig_batch_cond_uncond
|
||||
else:
|
||||
shared.batch_cond_uncond = orig_batch_cond_uncond
|
||||
unloadlorafowards(p)
|
||||
unloadlorafowards(self)
|
||||
denoiserdealer(self, True)
|
||||
else:
|
||||
hook_forwards(self, p, remove = True)
|
||||
|
|
@ -689,6 +698,14 @@ class Script(modules.scripts.Script):
|
|||
|
||||
def denoised_callback(self, params):
|
||||
denoised_callback_s(self, params)
|
||||
|
||||
def hr_returner(self):
|
||||
#0: enabled only hires, 1: disabled in hires
|
||||
#print(self.in_hr,self.hiresacts,self.hiresacts[1] if self.in_hr else self.hiresacts[0])
|
||||
if self.in_hr:
|
||||
return self.hiresacts[1]
|
||||
else:
|
||||
return self.hiresacts[0]
|
||||
|
||||
def unloader(self,p):
|
||||
if self.hooked:
|
||||
|
|
@ -702,7 +719,7 @@ def unloader(self,p):
|
|||
else:
|
||||
shared.batch_cond_uncond = orig_batch_cond_uncond
|
||||
|
||||
unloadlorafowards(p)
|
||||
unloadlorafowards(self)
|
||||
|
||||
def denoiserdealer(self, only_r):
|
||||
if not hasattr(self,"dr_callbacks"):
|
||||
|
|
@ -832,10 +849,8 @@ def tokendealer(self, p):
|
|||
tokenizer = p.sd_model.text_processing_engine_l.tokenize_line
|
||||
else:
|
||||
tokenizer = p.sd_model.text_processing_engine.tokenize_line
|
||||
self.flux = flux = "flux" in str(type(p.sd_model.forge_objects.unet.model.diffusion_model))
|
||||
else:
|
||||
tokenizer = shared.sd_model.conditioner.embedders[0].tokenize_line if self.is_sdxl else shared.sd_model.cond_stage_model.tokenize_line
|
||||
self.flux = flux = False
|
||||
|
||||
for pp in ppl:
|
||||
tokens, tokensnum = tokenizer(pp)
|
||||
|
|
|
|||
Loading…
Reference in New Issue