import copy
from pprint import pprint

import torch

from modules import devices, shared, extra_networks, sd_hijack
from modules.script_callbacks import CFGDenoisedParams, CFGDenoiserParams
from torchvision.transforms import InterpolationMode, Resize  # Mask resizing.

import scripts.attention as att
from scripts.regions import floatdef
from scripts.attention import makerrandman

from modules import launch_utils

# Detect the Forge ("f2" / "neo") and reForge ("f1" / "classic") variants of the web-ui.
forge = launch_utils.git_tag()[0:2] == "f2" or launch_utils.git_tag() == "neo"
reforge = launch_utils.git_tag()[0:2] == "f1" or launch_utils.git_tag() == "classic"

if forge:
    from modules.script_callbacks import AfterCFGCallbackParams, on_cfg_after_cfg

denoised_params = AfterCFGCallbackParams if forge else CFGDenoisedParams

islora = True
in_hr = False
layer_name = "lora_layer_name"
orig_Linear_forward = None

orig_lora_functional = False

lactive = False
labug = False
MINID = 1000
MAXID = 10000
LORAID = MINID  # Discriminator for repeated lora usage across gens, presumably.

def setuploras(self):
    global lactive, labug, islora, orig_Linear_forward, orig_lora_functional, layer_name
    lactive = True
    labug = self.debug
    islora = self.isbefore15
    layer_name = self.layer_name
    orig_lora_functional = shared.opts.lora_functional if hasattr(shared.opts, "lora_functional") else False

    try:
        if 150 <= self.ui_version <= 159 or self.slowlora:
            shared.opts.lora_functional = False
        else:
            shared.opts.lora_functional = True
    except Exception:
        pass

    # Hook Linear.forward so the LoRA multipliers can be swapped per region.
    is15 = 150 <= self.ui_version <= 159
    orig_Linear_forward = torch.nn.Linear.forward
    torch.nn.Linear.forward = h15_Linear_forward if is15 else h_Linear_forward

    if forge:
        shared.sd_model.forge_objects.unet.set_model_unet_function_wrapper(lambda apply, params: denoised_callback_s(apply, params, p3=self))
        from backend.args import dynamic_args
        dynamic_args["online_lora"] = True
        import networks as net
        net.load_networks = load_networks

        for name, module in shared.sd_model.forge_objects.clip.cond_stage_model.clip_l.named_modules():
            if name == "transformer.text_model.encoder.layers.0.self_attn.q_proj":
                module.forward = forge_linear_forward.__get__(module)
    if reforge:
        shared.sd_model.forge_objects.unet.set_model_unet_function_wrapper(lambda apply, params: denoised_callback_s(apply, params, p3=self))

def cloneparams(orig, target):
    # Copy the sampler state tensors so they can be restored later.
    target.x = orig.x.clone()
    target.image_cond = orig.image_cond.clone()
    target.sigma = orig.sigma.clone()

###################################################
###### Latent method: denoiser callback
# With the AND syntax and shared.batch_cond_uncond = False,
# the U-Net is evaluated (number of AND-separated prompts) + 1 times,
# i.e. once per area plus once for the negative prompt.
# Regional LoRA works by changing the LoRA application rate on each of these U-Net calls.
# The problem is that in the web-ui, with a batch size of 2 or more, prompts are
# packed into batches without regard to areas, so when the number of areas and the
# batch size differ, area boundaries and batch boundaries no longer line up.
# With a batch of 1 and 3 areas, the U-Net runs 4 times: Area1, Area2, Area3, Negative.
# With a batch of 2, however, the calls are packed as
# [Batch1-Area1, Batch1-Area2]
# [Batch1-Area3, Batch2-Area1]
# [Batch2-Area2, Batch2-Area3]
# [Batch1-Negative, Batch2-Negative]
# so each call mixes different areas.
# Therefore the order must be rearranged in advance:
# [Batch1-Area1, Batch1-Area2] -> [Batch1-Area1, Batch2-Area1]
# [Batch1-Area3, Batch2-Area1] -> [Batch1-Area2, Batch2-Area2]
# [Batch2-Area2, Batch2-Area3] -> [Batch1-Area3, Batch2-Area3]
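#
# To make that concrete, a small sketch (illustrative only, not executed here):
# with batch = 2 and areas = 3 the flat cond order
#   [B1A1, B1A2, B1A3, B2A1, B2A2, B2A3]   # source index a + b*areas
# is rewritten in place to
#   [B1A1, B2A1, B1A2, B2A2, B1A3, B2A3]   # target index b + a*batch
# which is exactly the assignment params.x[b + a*batch] = xt[a + b*areas]
# performed in denoiser_callback_s below.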

def denoiser_callback_s(self, params: CFGDenoiserParams):
    self.step = params.sampling_step
    self.total_step = params.total_sampling_steps

    self.pn = self.pn_s

    if hasattr(params, "text_cond") and params.text_cond is not None:
        if "DictWithShape" in params.text_cond.__class__.__name__:
            self.cshape = params.text_cond[list(params.text_cond.keys())[0]].shape[1]
        else:
            self.cshape = params.text_cond.shape[1]

    if hasattr(params, "text_uncond") and params.text_uncond is not None:
        if "DictWithShape" in params.text_uncond.__class__.__name__:
            self.ucshape = params.text_uncond[list(params.text_uncond.keys())[0]].shape[1]
        else:
            self.ucshape = params.text_uncond.shape[1]

    if self.only_r and not self.diff:
        return

    if self.hr_returner():
        return

    if "Pro" in self.mode:  # in Prompt mode, make masks from the sum of attention maps
        if self.x is None: cloneparams(params, self)  # save the step-0 state so sampling can restart once the masks are ready

        lim = 1 if self.is_sdxl else 3

        if len(att.pmaskshw) > lim:
            self.filters = []
            for b in range(self.batch_size):
                allmask = []
                basemask = None
                for t, th, bratio in zip(self.pe, self.th, self.bratios):
                    key = f"{t}-{b}"
                    _, _, mask = att.makepmask(att.pmasks[key], params.x.shape[2], params.x.shape[3], th, self.step, self.total_step, self.is_sdxl, bratio=bratio)
                    mask = mask.repeat(params.x.shape[1], 1, 1)
                    basemask = 1 - mask if basemask is None else basemask - mask
                    if self.ex:
                        for l in range(len(allmask)):
                            mt = allmask[l] - mask
                            allmask[l] = torch.where(mt > 0, 1, 0)
                    allmask.append(mask)
                if not self.ex:
                    masksum = torch.stack(allmask, dim=0).sum(dim=0)
                    masksum = torch.where(masksum == 0, 1, masksum)
                    allmask = [mask / masksum for mask in allmask]
                basemask = torch.where(basemask > 0, 1, 0)
                allmask.insert(0, basemask)
                self.filters.extend(allmask)
            att.maskready = True

            for t, th, bratio in zip(self.pe, self.th, self.bratios):
                allmask = []
                for hw in att.pmaskshw:
                    masks = None
                    for b in range(self.batch_size):
                        key = f"{t}-{b}"
                        _, mask, _ = att.makepmask(att.pmasks[key], hw[0], hw[1], th, self.step, self.total_step, self.is_sdxl, bratio=bratio)
                        mask = mask.unsqueeze(0).unsqueeze(-1)
                        masks = mask if b == 0 else torch.cat((masks, mask), dim=0)
                    allmask.append(mask)
                att.pmasksf[key] = allmask
            att.maskready = True

            if not self.rebacked:
                cloneparams(self, params)
                params.sampling_step = 0
                self.rebacked = True

    if "La" in self.calc:
        self.condi = 0
        global in_hr, regioner
        regioner.step = params.sampling_step
        in_hr = self.in_hr
        regioner.u_count = 0
        if "u_list" not in self.log.keys() and hasattr(regioner, "u_llist"):
            self.log["u_list"] = regioner.u_llist.copy()
        if "u_list_hr" not in self.log.keys() and hasattr(regioner, "u_llist") and in_hr:
            self.log["u_list_hr"] = regioner.u_llist.copy()
        xt = params.x.clone()
        ict = params.image_cond.clone()
        st = params.sigma.clone()
        batch = self.batch_size
        areas = xt.shape[0] // batch - 1
        if forge: return
        # SBM Stale version workaround.
        if hasattr(params, "text_cond"):
            if "DictWithShape" in params.text_cond.__class__.__name__:
                ct = {}
                for key in params.text_cond.keys():
                    ct[key] = params.text_cond[key].clone()
            else:
                ct = params.text_cond.clone()

        # Rearrange so that each chunk of `batch` entries holds one area across the batch.
        for a in range(areas):
            for b in range(batch):
                params.x[b + a * batch] = xt[a + b * areas]
                params.image_cond[b + a * batch] = ict[a + b * areas]
                params.sigma[b + a * batch] = st[a + b * areas]
                # SBM Stale version workaround.
                if hasattr(params, "text_cond"):
                    if "DictWithShape" in params.text_cond.__class__.__name__:
                        for key in params.text_cond.keys():
                            params.text_cond[key][b + a * batch] = ct[key][a + b * areas]
                    else:
                        params.text_cond[b + a * batch] = ct[a + b * areas]

def denoised_callback_s(p1, p2=None, p3=None):
    # if forge: p1 = model apply function, p2 = params dict, p3 = the script instance
    # if A1111: p1 = the script instance, p2 = CFGDenoisedParams, p3 = None
    if p3 is not None:
        self = p3
        input_x = p2["input"]
        timestep = p2["timestep"]
        cond_or_uncond = p2["cond_or_uncond"]
        c = p2["c"]
        conds = c["c_crossattn"]
        y = c["y"] if "y" in c else None

        if not lactive and not self.diff:
            return p1(input_x, timestep, **c)

        length = len(cond_or_uncond)
        batch = input_x.shape[0] // length

        cond_or_uncond = cond_or_uncond * batch

        region_list, orig_list = forge_make_change_list(batch, length)

        outs = []
        for i in range(length):
            regioner.set_region(length - i - 1)
            c["c_crossattn"] = conds[i * batch:i * batch + batch]
            if y is not None:
                c["y"] = y[i * batch:i * batch + batch]
            outs.append(p1(input_x[i * batch:i * batch + batch], torch.cat([timestep[i:i + 1]] * batch), **c))

        output = torch.cat(outs)

        x = output[region_list]
        xt = x.clone()
        areas = length - 1
    else:
        self = p1
        params = p2

        batch = self.batch_size
        x = params.x
        xt = params.x.clone()
        areas = xt.shape[0] // batch - 1

    #print(f"Denoised callback: calc {self.calc}, mode {self.mode}, rps {self.rps}, diff {self.diff}, filters {len(self.filters)}")

    if "La" in self.calc:
        # x.shape = [batch_size, C, H // 8, W // 8]

        if "Pro" not in self.mode:
            indrebuild = self.filters == [] or self.filters[0].size() != x[0].size()

            if indrebuild:
                if "Ran" in self.mode:
                    if self.filters == []:
                        self.filters = [self.ranbase] + self.ransors if self.usebase else self.ransors
                    elif self.filters[0][:, :].size() != x[0, 0, :, :].size():
                        self.filters = hrchange(self.ransors, x.shape[2], x.shape[3])
                else:
                    if "Mask" in self.mode:
                        masks = (self.regmasks, self.regbase)
                    else:
                        masks = self.aratios
                    # makefilters(c, h, w, masks, mode, usebase, bratios, indmask)
                    self.filters = makefilters(x.shape[1], x.shape[2], x.shape[3], masks, self.mode, self.usebase, self.bratios, "Mas" in self.mode)
                self.filters = [f for f in self.filters] * batch
        else:
            if not att.maskready:
                # Until the attention masks are ready, use a pass-through filter set.
                self.filters = [1, *[0 for a in range(areas - 1)]] * batch

        if self.debug:
            print("filterlength : ", len(self.filters))
            print("x : ", x.shape)
            print("areas : ", areas)

        for b in range(batch):
            for a in range(areas):
                fil = self.filters[a + b * areas] if not self.hr_returner() else 1
                if self.debug: print(f"x = {x.size()}, i = {a + b * areas}, j = {b + a * batch}, cond = {a + b * areas}, filsum = {fil if type(fil) is int else torch.sum(fil)}, uncond = {x.size()[0] + (b - batch)}")
                x[a + b * areas, :, :, :] = xt[b + a * batch, :, :, :] * fil + x[x.size()[0] + (b - batch), :, :, :] * (1 - fil)

    if self.total_step == self.step + 2:
        if self.rps is not None and self.diff:
            if self.rps.latent is None:
                self.rps.latent = x.clone()
                return x[orig_list] if p3 is not None else None
            elif self.rps.latent.shape[2:] != x.shape[2:] and self.rps.latent_hr is None:
                self.rps.latent_hr = x.clone()
                return x[orig_list] if p3 is not None else None
            else:
                for b in range(batch):
                    for a in range(areas):
                        fil = self.filters[a + 1]
                        orig = self.rps.latent if self.rps.latent.shape[2:] == x.shape[2:] else self.rps.latent_hr
                        if self.debug: print(f"x = {x.size()}, i = {a + b * areas}, j = {b + a * batch}, cond = {a + b * areas}, filsum = {fil if type(fil) is int else torch.sum(fil)}, uncond = {x.size()[0] + (b - batch)}")
                        #print("1", type(self.rps.latent), type(fil))
                        x[:, :, :, :] = orig[:, :, :, :] * (1 - fil) + x[:, :, :, :] * fil

    #if params.total_sampling_steps - 7 == params.sampling_step + 2:
    if att.maskready:
        if self.rps is not None and self.diff:
            if self.rps.latent is not None:
                if self.rps.latent.shape[2:] != x.shape[2:]:
                    if self.rps.latent_hr is None: return x[orig_list] if p3 is not None else None
                for b in range(batch):
                    for a in range(areas):
                        fil = self.filters[a + 1]
                        orig = self.rps.latent if self.rps.latent.shape[2:] == x.shape[2:] else self.rps.latent_hr
                        if self.debug: print(f"x = {x.size()}, i = {a + b * areas}, j = {b + a * batch}, cond = {a + b * areas}, filsum = {fil if type(fil) is int else torch.sum(fil)}, uncond = {x.size()[0] + (b - batch)}")
                        #print("2", type(self.rps.latent), type(fil))
                        x[:, :, :, :] = orig[:, :, :, :] * (1 - fil) + x[:, :, :, :] * fil

    if self.step == 0 and self.in_hr:
        if self.rps is not None and self.diff:
            if self.rps.latent is not None:
                if self.rps.latent.shape[2:] != x.shape[2:] and self.rps.latent_hr is None: return x[orig_list] if p3 is not None else None
                for b in range(batch):
                    for a in range(areas):
                        fil = self.filters[a + 1]
                        orig = self.rps.latent if self.rps.latent.shape[2:] == x.shape[2:] else self.rps.latent_hr
                        if self.debug: print(f"x = {x.size()}, i = {a + b * areas}, j = {b + a * batch}, cond = {a + b * areas}, filsum = {fil if type(fil) is int else torch.sum(fil)}, uncond = {x.size()[0] + (b - batch)}")
                        #print("3", type(self.rps.latent), type(fil))
                        x[:, :, :, :] = orig[:, :, :, :] * (1 - fil) + x[:, :, :, :] * fil

    if p3 is not None:  # forge
        out = x[orig_list]
        return out

def forge_make_change_list(batch, length):
    # Build the index permutations between the web-ui's batch order and the
    # region-major order used while the regional LoRA weights are applied.
    orig = [x for x in range(batch * length)]

    chunks_1 = [orig[i:i + batch] for i in range(0, len(orig), batch)]
    chunks_2 = [[i + (length - 1) * x for x in range(batch)] for i in range(length)]

    out1, out2 = [], []

    for c1 in chunks_1[::-1]:
        out1.extend(c1)

    out2.extend(chunks_1[-1])

    for c2 in chunks_2[:-1][::-1]:
        out2.extend(c2)

    return out1, out2
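
# Illustrative only: with batch = 2 and length = 4 (three regions plus the negative),
# forge_make_change_list(2, 4) returns
#   out1 = [6, 7, 4, 5, 2, 3, 0, 1]
#   out2 = [6, 7, 2, 5, 1, 4, 0, 3]
# out1 reverses the order of the per-call chunks (used as region_list in
# denoised_callback_s above), and out2 (orig_list) maps the mixed result back
# to the order the caller expects on return.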

######################################################
##### Latent Method

def hrchange(filters, h, w):
    # Rescale random-mode filters to the new (hires) latent resolution.
    out = []
    for f in filters:
        out.append(makerrandman(f, h, w, True))
    return out

# Remove tags from called lora names.
flokey = lambda x: (x.split("added_by_regional_prompter")[0]
                    .split("added_by_lora_block_weight")[0].split("_in_LBW")[0].split("_in_RP")[0])

def lora_namer(self, p, lnter, lnur):
    ldict_u = {}
    ldict_te = {}
    lorder = []  # LoRA call order, for matching with the unet / te lists.
    import lora as loraclass
    name_to_hash = {}
    for lora in loraclass.loaded_loras:
        ldict_u[lora.network_on_disk.filename if forge else lora.name] = lora.multiplier if self.isbefore15 else lora.unet_multiplier
        ldict_te[lora.network_on_disk.filename if forge else lora.name] = lora.multiplier if self.isbefore15 else lora.te_multiplier
        name_to_hash[lora.network_on_disk.alias] = lora.network_on_disk.filename
        name_to_hash[lora.network_on_disk.name] = lora.network_on_disk.filename

    subprompts = self.current_prompts[0].split("AND")
    ldictlist_u = [ldict_u.copy() for i in range(len(subprompts) + 1)]
    ldictlist_te = [ldict_te.copy() for i in range(len(subprompts) + 1)]

    for i, prompt in enumerate(subprompts):
        _, extranets = extra_networks.parse_prompts([prompt])
        calledloras = extranets["lora"]

        names = ""
        tdict = {}

        for called in calledloras:
            names = names + name_to_hash[called.items[0]] if forge else names + called.items[0]
            tdict[name_to_hash[called.items[0]] if forge else called.items[0]] = syntaxdealer(called.items, "unet=", 1)

        for key in ldictlist_u[i].keys():
            shin_key = flokey(key)
            if shin_key in names:
                ldictlist_u[i + 1][key] = float(tdict[shin_key])
                ldictlist_te[i + 1][key] = float(tdict[shin_key])
                if key not in lorder:
                    lorder.append(key)
            else:
                ldictlist_u[i + 1][key] = 0
                ldictlist_te[i + 1][key] = 0

    if self.debug: print("Regioner lorder: ", lorder)
    global regioner
    regioner.__init__(self.lstop, self.lstop_hr)
    u_llist = [d.copy() for d in ldictlist_u[1:]]
    u_llist.append(ldictlist_u[0].copy())
    regioner.te_llist = ldictlist_te
    regioner.u_llist = u_llist
    regioner.ndeleter(lnter, lnur, lorder)
    if self.debug:
        print("LoRA regioner : TE list", regioner.te_llist)
        print("LoRA regioner : U list", regioner.u_llist)

def syntaxdealer(items, prefix, index):  # prefix is "unet=", "x=", "lwbe="
    for item in items:
        if prefix in item:
            if "@" in item: return 1  # for loractl
            return item.replace(prefix, "")
    return items[index] if "@" not in items[index] else 1
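
# Illustrative only: for a prompt tag like <lora:mylora:0.8:unet=0.5>,
# extra_networks yields items == ["mylora", "0.8", "unet=0.5"], so
# syntaxdealer(items, "unet=", 1) -> "0.5"; with no unet= argument it falls
# back to items[1] ("0.8"). Values containing "@" (loractl schedules) return 1.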

def makefilters(c, h, w, masks, mode, usebase, bratios, indmask):
    if indmask:
        (regmasks, regbase) = masks

    filters = []
    x = torch.zeros(c, h, w).to(devices.device)
    if usebase:
        x0 = torch.zeros(c, h, w).to(devices.device)
    i = 0
    if indmask:
        ftrans = Resize((h, w), interpolation=InterpolationMode("nearest"))
        for rmask, bratio in zip(regmasks, bratios[0]):
            # Resize the mask to the current dims. Since it's a mask, we prefer
            # a binary value, so nearest is the only option.
            rmask2 = ftrans(rmask.reshape([1, *rmask.shape]))  # Requires dimensions N,C,{d}.
            rmask2 = rmask2.reshape([1, h, w])
            fx = x.clone()
            if usebase:
                fx[:, :, :] = fx + rmask2 * (1 - bratio)
                x0[:, :, :] = x0 + rmask2 * bratio
            else:
                fx[:, :, :] = fx + rmask2 * 1
            filters.append(fx)

        if usebase:  # Add base to x0.
            rmask = regbase
            rmask2 = ftrans(rmask.reshape([1, *rmask.shape]))  # Requires dimensions N,C,{d}.
            rmask2 = rmask2.reshape([1, h, w])
            x0 = x0 + rmask2
    else:
        for drow in masks:
            for dcell in drow.cols:
                fx = x.clone()
                if "Horizontal" in mode:
                    if usebase:
                        fx[:, int(h * drow.st):int(h * drow.ed), int(w * dcell.st):int(w * dcell.ed)] = 1 - dcell.base
                        x0[:, int(h * drow.st):int(h * drow.ed), int(w * dcell.st):int(w * dcell.ed)] = dcell.base
                    else:
                        fx[:, int(h * drow.st):int(h * drow.ed), int(w * dcell.st):int(w * dcell.ed)] = 1
                elif "Vertical" in mode:
                    if usebase:
                        fx[:, int(h * dcell.st):int(h * dcell.ed), int(w * drow.st):int(w * drow.ed)] = 1 - dcell.base
                        x0[:, int(h * dcell.st):int(h * dcell.ed), int(w * drow.st):int(w * drow.ed)] = dcell.base
                    else:
                        fx[:, int(h * dcell.st):int(h * dcell.ed), int(w * drow.st):int(w * drow.ed)] = 1
                filters.append(fx)
                i += 1
    if usebase: filters.insert(0, x0)
    if labug: print(i, len(filters))

    return filters
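
# Illustrative only: for a two-region horizontal split, makefilters returns two
# (c, h, w) tensors that are 1 inside their own region and 0 elsewhere; with
# usebase=True a base filter built from the bratio weights is prepended, so for
# grid modes that tile the whole canvas the filters sum to 1 at every pixel.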

######################################################
##### Latent Method LoRA changer

TE_START_NAME = "transformer_text_model_encoder_layers_0_self_attn_q_proj"
UNET_START_NAME = "diffusion_model_time_embed_0"

TE_START_NAME_XL = "0_transformer_text_model_encoder_layers_0_self_attn_q_proj"

class LoRARegioner:

    def __init__(self, stop=0, stop_hr=0):
        self.te_count = 0
        self.u_count = 0
        self.te_llist = [{}]
        self.u_llist = [{}]
        self.mlist = {}
        self.ctl = False
        self.step = 0
        self.stop = stop
        self.stop_hr = stop_hr
        self.stopped = False
        self.stopped_hr = False
        self.orig_weight = {}

        try:
            import lora_ctl_network as ctl
            self.ctlweight = copy.deepcopy(ctl.lora_weights)
            for weights in self.ctlweight.values():
                for weight in weights.values():
                    if isinstance(weight, list):
                        self.ctl = True
        except Exception:
            pass

    def expand_del(self, val, lorder):
        """Broadcast a single or comma-separated val to the lora list."""
        lval = val.split(",")
        if len(lval) > len(lorder):
            lval = lval[:len(lorder)]
        lval = [floatdef(v, 0) for v in lval]
        if len(lval) < len(lorder):  # Propagate the last value to cover the difference.
            lval.extend([lval[-1]] * (len(lorder) - len(lval)))
        return lval
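
    # Illustrative only: with lorder = ["a", "b", "c"],
    # expand_del("0.5,0.8", lorder) -> [0.5, 0.8, 0.8] (the last value propagates),
    # and expand_del("0", lorder) -> [0.0, 0.0, 0.0].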

    def ndeleter(self, lnter, lnur, lorder=None):
        """Multiply the global weights by a 0:1 factor.

        Can be any value, negative too, but that doesn't help much.
        """
        if lorder is None:
            lkeys = self.te_llist[0].keys()
        else:
            lkeys = lorder
        lnter = self.expand_del(lnter, lkeys)
        for (key, val) in zip(lkeys, lnter):
            self.te_llist[0][key] *= val
        if lorder is None:
            lkeys = self.u_llist[-1].keys()
        else:
            lkeys = lorder
        lnur = self.expand_del(lnur, lkeys)
        for (key, val) in zip(lkeys, lnur):
            self.u_llist[-1][key] *= val

    def search_key(self, lora, i, xlist):
        lorakey = lora.loaded_loras[i].name
        if lorakey not in xlist.keys():
            shin_key = flokey(lorakey)
            picked = False
            for mlkey in xlist.keys():
                if mlkey.startswith(shin_key):
                    lorakey = mlkey
                    picked = True
            if not picked:
                print(f"key is not found in: {xlist.keys()}")
        return lorakey

    def te_start(self):
        self.mlist = self.te_llist[self.te_count % len(self.te_llist)]
        if self.mlist == {}: return
        self.te_count += 1
        import lora
        for i in range(len(lora.loaded_loras)):
            lorakey = self.search_key(lora, i, self.mlist)
            lora.loaded_loras[i].multiplier = self.mlist[lorakey]
            lora.loaded_loras[i].te_multiplier = self.mlist[lorakey]

    def u_start(self):
        if labug: print("u_count", self.u_count, "u_count % len(u_llist)", self.u_count % len(self.u_llist))
        self.mlist = self.u_llist[self.u_count % len(self.u_llist)]
        if self.mlist == {}: return
        self.u_count += 1

        stopstep = self.stop_hr if in_hr else self.stop

        import lora
        for i in range(len(lora.loaded_loras)):
            lorakey = self.search_key(lora, i, self.mlist)
            lora.loaded_loras[i].multiplier = 0 if self.step + 2 > stopstep and stopstep else self.mlist[lorakey]
            lora.loaded_loras[i].unet_multiplier = 0 if self.step + 2 > stopstep and stopstep else self.mlist[lorakey]
            if labug: print(lorakey, lora.loaded_loras[i].multiplier, lora.loaded_loras[i].unet_multiplier)
            if self.ctl:
                import lora_ctl_network as ctl
                key = "hrunet" if in_hr else "unet"
                if self.mlist[lorakey] == 0 or (self.step + 2 > stopstep and stopstep):
                    ctl.lora_weights[lorakey][key] = [[0], [0]]
                    if labug: print(ctl.lora_weights[lorakey])
                else:
                    if key in self.ctlweight[lorakey].keys():
                        ctl.lora_weights[lorakey][key] = self.ctlweight[lorakey][key]
                    else:
                        ctl.lora_weights[lorakey][key] = self.ctlweight[lorakey]["unet"]
                    if labug: print(ctl.lora_weights[lorakey])

    def reset(self):
        self.te_count = 0
        self.u_count = 0
        self.stopped = False
        self.stopped_hr = False
        self.orig_weight = {}

    def te_start_f(self):
        self.mlist = self.te_llist[self.te_count % len(self.te_llist)]
        if self.mlist == {}: return
        self.te_count += 1

        if labug:
            print(f"Set LoRA for Region {self.te_count % len(self.te_llist)}, u_count", self.u_count, "u_count % len(u_llist)", self.u_count % len(self.u_llist))
            print(self.mlist)

        lora_loader = shared.sd_model.forge_objects.clip.patcher.lora_loader
        lora_patches = shared.sd_model.forge_objects.clip.patcher.lora_patches
        offload_device = shared.sd_model.forge_objects.clip.patcher.offload_device

        for lora_key, patch in lora_patches.items():
            for list_key in self.mlist:
                if list_key in lora_key[0]:
                    if labug:
                        print(f"LoRA {lora_key} detected in {self.mlist}")
                    for patch_key in patch:
                        if patch_key + list_key not in self.orig_weight:
                            self.orig_weight[patch_key + list_key] = patch[patch_key][0][0]
                        patch[patch_key][0][0] = self.orig_weight[patch_key + list_key] * self.mlist[list_key]

        refresh(lora_loader, lora_patches=lora_patches, offload_device=offload_device)

    def set_region(self, region):
        if self.u_llist == [{}]: return
        self.mlist = self.u_llist[region]
        if labug:
            print(f"Set LoRA for Region {region}, u_count", self.u_count, "u_count % len(u_llist)", self.u_count % len(self.u_llist))
            print(self.mlist)
        if self.mlist == {}: return

        strengths = list(self.mlist.values())

        def set_strengths(strengths):
            for name, module in shared.sd_model.forge_objects.unet.model.named_modules():
                patches = getattr(module, 'forge_online_loras', None)
                weight_patches, bias_patches = None, None
                if patches is not None:
                    weight_patches = patches.get('weight', None)
                if weight_patches:
                    if len(weight_patches) != len(strengths):
                        continue
                    for i in range(len(strengths)):
                        if name not in self.orig_weight:
                            self.orig_weight[name] = [x[0] for x in weight_patches]
                        weight_patches[i][0] = strengths[i] * self.orig_weight[name][i]

        stopstep = self.stop_hr if in_hr else self.stop
        if stopstep and self.step >= stopstep:  # A stop step of 0 means no stop (mirrors u_start).
            if (self.stopped_hr if in_hr else self.stopped):
                return
            else:
                set_strengths([0] * len(strengths))  # Zero all LoRA strengths once the stop step is reached.
                if in_hr:
                    self.stopped_hr = True
                else:
                    self.stopped = True
                return

        set_strengths(strengths)


regioner = LoRARegioner()

############################################################
##### for new lora apply method in web-ui

def h_Linear_forward(self, input):
    changethelora(getattr(self, layer_name, None))
    if islora:
        import lora
        return lora.lora_forward(self, input, torch.nn.Linear_forward_before_lora)
    elif forge or reforge:
        return orig_Linear_forward(self, input)
    else:
        import networks
        if shared.opts.lora_functional:
            return networks.network_forward(self, input, networks.originals.Linear_forward)
        networks.network_apply_weights(self)
        return networks.originals.Linear_forward(self, input)

def h15_Linear_forward(self, input):
    changethelora(getattr(self, layer_name, None))
    if islora:
        import lora
        return lora.lora_forward(self, input, torch.nn.Linear_forward_before_lora)
    else:
        import networks
        if shared.opts.lora_functional:
            return networks.network_forward(self, input, networks.network_Linear_forward)
        networks.network_apply_weights(self)
        return torch.nn.Linear_forward_before_network(self, input)

def forge_linear_forward(self, x):
    # Bound to the first text-encoder projection layer, so every TE pass
    # retriggers the regioner before the usual linear computation.
    regioner.te_start_f()
    from backend import operations as op
    if self.parameters_manual_cast:
        weight, bias, signal = op.weights_manual_cast(self, x)
        with op.main_stream_worker(weight, bias, signal):
            return torch.nn.functional.linear(x, weight, bias)
    else:
        weight, bias = op.get_weight_and_bias(self)
        return torch.nn.functional.linear(x, weight, bias)

def changethelora(name):
    # Dispatch on the hooked layer: the first TE layer marks a new text-encoder
    # pass, the first U-Net layer marks a new U-Net pass.
    if lactive:
        global regioner
        if name == TE_START_NAME or name == TE_START_NAME_XL:
            regioner.te_start()
        elif name == UNET_START_NAME:
            regioner.u_start()

# Maps LoRA module class names to the attribute holding their primary weight.
LORAANDSOON = {
    "LoraHadaModule": "w1a",
    "LycoHadaModule": "w1a",
    "NetworkModuleHada": "w1a",
    "FullModule": "weight",
    "NetworkModuleFull": "weight",
    "IA3Module": "w",
    "NetworkModuleIa3": "w",
    "LoraKronModule": "w1",
    "LycoKronModule": "w1",
    "NetworkModuleLokr": "w1",
}

def changethedevice(module):
    # Move the weights of a LoRA module to the active device as float32.
    ltype = type(module).__name__
    if ltype == "LoraUpDownModule" or ltype == "LycoUpDownModule":
        if hasattr(module, "up_model"):
            module.up_model.weight = torch.nn.Parameter(module.up_model.weight.to(devices.device, dtype=torch.float))
            module.down_model.weight = torch.nn.Parameter(module.down_model.weight.to(devices.device, dtype=torch.float))
        else:
            module.up.weight = torch.nn.Parameter(module.up.weight.to(devices.device, dtype=torch.float))
            if hasattr(module.down, "weight"):
                module.down.weight = torch.nn.Parameter(module.down.weight.to(devices.device, dtype=torch.float))

    elif ltype == "LoraHadaModule" or ltype == "LycoHadaModule" or ltype == "NetworkModuleHada":
        module.w1a = torch.nn.Parameter(module.w1a.to(devices.device, dtype=torch.float))
        module.w1b = torch.nn.Parameter(module.w1b.to(devices.device, dtype=torch.float))
        module.w2a = torch.nn.Parameter(module.w2a.to(devices.device, dtype=torch.float))
        module.w2b = torch.nn.Parameter(module.w2b.to(devices.device, dtype=torch.float))

        if module.t1 is not None:
            module.t1 = torch.nn.Parameter(module.t1.to(devices.device, dtype=torch.float))

        if module.t2 is not None:
            module.t2 = torch.nn.Parameter(module.t2.to(devices.device, dtype=torch.float))

    elif ltype == "FullModule" or ltype == "NetworkModuleFull":
        module.weight = torch.nn.Parameter(module.weight.to(devices.device, dtype=torch.float))

    if hasattr(module, 'bias') and module.bias is not None:
        module.bias = torch.nn.Parameter(module.bias.to(devices.device, dtype=torch.float))

def unloadlorafowards(self):
    global orig_Linear_forward, lactive, labug
    lactive = labug = False

    try:
        shared.opts.lora_functional = orig_lora_functional
    except Exception:
        pass

    import lora
    if forge:
        from backend.args import dynamic_args
        dynamic_args["online_lora"] = self.orig_online_lora
    else:
        emb_db = sd_hijack.model_hijack.embedding_db
        for net in lora.loaded_loras:
            if hasattr(net, "bundle_embeddings"):
                for emb_name, embedding in net.bundle_embeddings.items():
                    if embedding.loaded:
                        emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

    lora.loaded_loras.clear()
    if orig_Linear_forward is not None:
        torch.nn.Linear.forward = orig_Linear_forward
        orig_Linear_forward = None

def refresh(self, lora_patches, offload_device=torch.device('cpu')):
    from backend.patcher import lora
    from backend import utils, memory_management, operations
    hashes = str(list(lora_patches.keys()))

    # Merge patches

    all_patches = {}

    for (_, _, _, online_mode), patches in lora_patches.items():
        for key, current_patches in patches.items():
            all_patches[(key, online_mode)] = all_patches.get((key, online_mode), []) + current_patches

    # Initialize

    memory_management.signal_empty_cache = True

    parameter_devices = lora.get_parameter_devices(self.model)

    # Restore

    for m in set(self.online_backup):
        del m.forge_online_loras

    self.online_backup = []

    for k, w in self.backup.items():
        if not isinstance(w, torch.nn.Parameter):
            # In very few cases.
            w = torch.nn.Parameter(w, requires_grad=False)

        utils.set_attr_raw(self.model, k, w)

    self.backup = {}

    lora.set_parameter_devices(self.model, parameter_devices=parameter_devices)

    # Patch

    for (key, online_mode), current_patches in all_patches.items():
        try:
            parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
            assert isinstance(weight, torch.nn.Parameter)
        except Exception:
            raise ValueError(f"Wrong LoRA Key: {key}")

        if online_mode:
            if not hasattr(parent_layer, 'forge_online_loras'):
                parent_layer.forge_online_loras = {}

            parent_layer.forge_online_loras[child_key] = current_patches
            self.online_backup.append(parent_layer)
            continue

        if key not in self.backup:
            self.backup[key] = weight.to(device=offload_device)

        bnb_layer = None

        if hasattr(weight, 'bnb_quantized') and operations.bnb_avaliable:  # attribute spelling as provided by the backend
            bnb_layer = parent_layer
            from backend.operations_bnb import functional_dequantize_4bit
            weight = functional_dequantize_4bit(weight)

        gguf_cls = getattr(weight, 'gguf_cls', None)
        gguf_parameter = None

        if gguf_cls is not None:
            gguf_parameter = weight
            from backend.operations_gguf import dequantize_tensor
            weight = dequantize_tensor(weight)

        try:
            weight = lora.merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
        except Exception:
            print('Patching LoRA weights out of memory. Retrying by offloading models.')
            lora.set_parameter_devices(self.model, parameter_devices={k: offload_device for k in parameter_devices.keys()})
            memory_management.soft_empty_cache()
            weight = lora.merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)

        if bnb_layer is not None:
            bnb_layer.reload_weight(weight)
            continue

        if gguf_cls is not None:
            gguf_cls.quantize_pytorch(weight, gguf_parameter)
            continue

        utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))

    # End

    lora.set_parameter_devices(self.model, parameter_devices=parameter_devices)
    self.loaded_hash = hashes
    return

def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    import networks as nets
    from backend.args import dynamic_args
    from modules import sd_models, errors

    global lora_state_dict_cache

    current_sd = sd_models.model_data.get_sd_model()
    if current_sd is None:
        return

    nets.loaded_networks.clear()

    unavailable_networks = []
    for name in names:
        if name.lower() in nets.forbidden_network_aliases and nets.available_networks.get(name) is None:
            unavailable_networks.append(name)
        elif nets.available_network_aliases.get(name) is None:
            unavailable_networks.append(name)

    if unavailable_networks:
        nets.update_available_networks_by_names(unavailable_networks)

    networks_on_disk = [nets.available_networks.get(name, None) if name.lower() in nets.forbidden_network_aliases else nets.available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        nets.list_available_networks()
        networks_on_disk = [nets.available_networks.get(name, None) if name.lower() in nets.forbidden_network_aliases else nets.available_network_aliases.get(name, None) for name in names]

    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        try:
            net = nets.load_network(name, network_on_disk)
        except Exception as e:
            errors.display(e, f"loading network {network_on_disk.filename}")
            continue
        net.mentioned_name = name
        network_on_disk.read_hash()
        nets.loaded_networks.append(net)

    online_mode = dynamic_args.get('online_lora', False)

    if current_sd.forge_objects.unet.model.storage_dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        online_mode = False

    compiled_lora_targets = []
    for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers):
        compiled_lora_targets.append([a.filename, b, c, online_mode])

    compiled_lora_targets_hash = str(compiled_lora_targets)

    if current_sd.current_lora_hash == compiled_lora_targets_hash:
        return

    current_sd.current_lora_hash = compiled_lora_targets_hash
    current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
    current_sd.forge_objects.clip = current_sd.forge_objects_original.clip

    for filename, strength_model, strength_clip, online_mode in compiled_lora_targets:
        lora_sd = nets.load_lora_state_dict(filename)
        current_sd.forge_objects.unet, current_sd.forge_objects.clip = nets.load_lora_for_models(
            current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
            filename=filename, online_mode=online_mode)

    current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
    return