from typing import List, Dict
import re
import torch
import composable_lora_step
import composable_lycoris
import plot_helper
from modules import extra_networks, shared

# Matches the standalone word "AND" that separates composable sub-prompts.
re_AND = re.compile(r"\bAND\b")


def load_prompt_loras(prompt: str) -> None:
    """Parse *prompt* into per-sub-prompt LoRA multipliers and step controllers.

    The prompt is split on the standalone word ``AND``. For every sub-prompt the
    ``lora`` entries returned by ``extra_networks.parse_prompt`` are collected
    into a ``{name: multiplier}`` dict; the multiplier defaults to 1.0 when only
    the name is given. The step syntax (``[A:B:N]``) is parsed from the whole
    prompt by ``composable_lora_step.parse_step_rendering_syntax``.

    Side effects on module globals:
        * clears prompt_loras, prompt_blocks, lora_controllers, drawing_data,
          full_controllers and drawing_lora_names;
        * sets is_single_block to True when the prompt contains no ``AND``;
        * fills prompt_loras / lora_controllers / prompt_blocks with the
          per-sub-prompt data repeated ``num_batches`` times;
        * concatenates every sub-prompt's controller list into full_controllers;
        * resets first_log_drawing so the next lora_forward() call logs again.
    """
    global is_single_block
    global full_controllers
    global first_log_drawing
    prompt_loras.clear()
    prompt_blocks.clear()
    lora_controllers.clear()
    drawing_data.clear()
    full_controllers.clear()
    drawing_lora_names.clear()
    # Load the "A AND B AND ..." blocks.
    subprompts = re_AND.split(prompt)
    tmp_prompt_loras = []
    tmp_prompt_blocks = []
    for i, subprompt in enumerate(subprompts):
        loras = {}
        _, extra_network_data = extra_networks.parse_prompt(subprompt)
        for params in extra_network_data['lora']:
            name = params.items[0]
            # Optional second item is the weight; 1.0 when omitted.
            multiplier = float(params.items[1]) if len(params.items) > 1 else 1.0
            loras[name] = multiplier

        tmp_prompt_loras.append(loras)
        tmp_prompt_blocks.append(subprompt)
    is_single_block = (len(tmp_prompt_loras) == 1)

    # Load the [A:B:N] step syntax.
    if opt_composable_with_step:
        print("Loading LoRA step controller...")
    tmp_lora_controllers = composable_lora_step.parse_step_rendering_syntax(prompt)

    # Repeat the per-sub-prompt data once per batch item (for batches > 1).
    # NOTE(review): num_batches defaults to 0 in this module, which would leave
    # these lists empty — presumably the caller sets it first; confirm.
    prompt_loras.extend(tmp_prompt_loras * num_batches)
    lora_controllers.extend(tmp_lora_controllers * num_batches)
    prompt_blocks.extend(tmp_prompt_blocks * num_batches)

    # Concatenate the per-block controller lists into one flat list.
    for controller_it in tmp_lora_controllers:
        full_controllers += controller_it
    first_log_drawing = False


def reset_counters() -> None:
    """Rewind the layer counters used by lora_forward and advance step_counter.

    text_model_encoder_counter is set to -1. On its next call, lora_forward
    replaces -1 with ``len(prompt_loras) * num_loras``, the position the code
    calls the "uc head". diffusion_model_counter goes back to 0, step_counter is
    incremented and should_print is re-armed.
    """
    global text_model_encoder_counter
    global diffusion_model_counter
    global step_counter
    global should_print
    # reset counter to uc head
    text_model_encoder_counter = -1
    diffusion_model_counter = 0
    step_counter += 1
    should_print = True


def reset_step_counters() -> None:
    """Set step_counter back to 0 and re-arm should_print."""
    global step_counter
    global should_print
    should_print = True
    step_counter = 0


def add_step_counters() -> None:
    """Advance step_counter by one sampling step.

    Once the counter exceeds num_steps it wraps back to 0. Otherwise, when weight
    plotting is enabled, the weights for the new step are logged via log_lora().
    """
    global step_counter
    global should_print
    should_print = True
    step_counter += 1
    if step_counter > num_steps:
        step_counter = 0
    else:
        if opt_plot_lora_weight:
            log_lora()


def log_lora() -> None:
    """Append one row of LoRA weights for the current step to drawing_data.

    Each row holds one weight per entry of drawing_lora_names. A LoRA name seen
    for the first time is appended to that list, and any gap before its column
    is filled with 0.0. With opt_composable_with_step the weight comes from the
    step controllers (full_controllers) at step_counter. Otherwise it is the
    LoRA's own multiplier. When no LoRA is loaded the row is ``[0.0]`` and a
    placeholder name is registered if none exists yet.
    """
    import lora
    tmp_data : List[float] = []
    if len(lora.loaded_loras) <= 0:
        tmp_data = [0.0]
        if len(drawing_lora_names) <= 0:
            drawing_lora_names.append("LoRA Model Not Found.")
    for m_lora in lora.loaded_loras:
        current_lora = m_lora.name
        multiplier = m_lora.multiplier
        if opt_composable_with_step:
            multiplier = composable_lora_step.check_lora_weight(full_controllers, current_lora, step_counter, num_steps)
        # Column of this LoRA in the plot; register it if new.
        index = -1
        if current_lora in drawing_lora_names:
            index = drawing_lora_names.index(current_lora)
        else:
            index = len(drawing_lora_names)
            drawing_lora_names.append(current_lora)
        if index >= len(tmp_data):
            # Pad missing columns with 0.0 before appending this weight.
            for i in range(len(tmp_data), index):
                tmp_data.append(0.0)
            tmp_data.append(multiplier)
        else:
            tmp_data[index] = multiplier
    drawing_data.append(tmp_data)


def plot_lora():
    """Normalise drawing_data and render it with plot_helper.plot_lora_weight.

    1. If fewer than num_steps rows were logged, repeat the last row (or
       ``[0.0]`` when nothing was logged) until there are num_steps rows.
    2. Prepend drawing_lora_first_index, the row captured on the first
       lora_forward() call.
    3. Right-pad every row with 0.0 to the length of the longest row.

    drawing_data is mutated in place. The function returns whatever
    plot_helper.plot_lora_weight returns.

    NOTE(review): the insert in step 2 runs on every call, so calling this twice
    adds the first row twice. Also, lora_forward sets drawing_lora_first_index to
    drawing_data[0], so the first row may already be present. Confirm the offset
    is intended.
    """
    max_size = -1
    if len(drawing_data) < num_steps:
        item = drawing_data[len(drawing_data) - 1] if len(drawing_data) > 0 else [0.0]
        drawing_data.extend([item]*(num_steps - len(drawing_data)))
    drawing_data.insert(0, drawing_lora_first_index)
    # Longest row decides the padded width.
    for datalist in drawing_data:
        datalist_len = len(datalist)
        if datalist_len > max_size:
            max_size = datalist_len
    for i, datalist in enumerate(drawing_data):
        datalist_len = len(datalist)
        if datalist_len < max_size:
            drawing_data[i].extend([0.0]*(max_size - datalist_len))
    return plot_helper.plot_lora_weight(drawing_data, drawing_lora_names)


def lora_forward(compvis_module, input, res):
    """Add composable-LoRA deltas to the output of a hooked layer.

    Called by lora_Linear_forward / lora_Conv2d_forward with the layer's
    original output.

    Args:
        compvis_module: the layer being run. Its ``lora_layer_name`` attribute, if
            present, is used to look up each loaded LoRA's sub-module.
        input: the input that was passed to the layer's forward.
        res: the layer's original output. It is updated in place (``+=``) and
            returned.

    Routing:
        * Layers named ``transformer_*``: ``text_model_encoder_counter // num_loras``
          selects the prompt_loras row ("c"). Once that index runs past
          prompt_loras, the LoRA's own multiplier is used for the uncond ("uc")
          pass when the uc options allow it.
        * Layers named ``diffusion_model_*``: if the leading dimension of res
          equals ``num_batches * num_prompts + num_batches``, rows are treated as
          all cond rows followed by the uncond rows and weighted row by row.
          Otherwise diffusion_model_counter tracks which chunk of prompt_loras
          this call covers.
        * Any other layer, or when ``enabled`` is False: the LoRA's own
          multiplier is applied to the whole output.

    Each LoRA name is applied at most once per call. Returns res unchanged when
    no LoRA is loaded or the layer has no ``lora_layer_name``.
    """
    global text_model_encoder_counter
    global diffusion_model_counter
    global step_counter
    global should_print
    global first_log_drawing
    global drawing_lora_first_index
    import lora

    # First call after load_prompt_loras(): announce and capture the first plot row.
    if not first_log_drawing:
        first_log_drawing = True
        if enabled:
            print("Composable LoRA load successful.")
        if opt_plot_lora_weight:
            log_lora()
            drawing_lora_first_index = drawing_data[0]

    if len(lora.loaded_loras) == 0:
        return res

    lora_layer_name_loading : str | None = getattr(compvis_module, 'lora_layer_name', None)
    if lora_layer_name_loading is None:
        return res
    # Make sure the name really is a str for the prefix/suffix checks below.
    lora_layer_name : str = str(lora_layer_name_loading)
    del lora_layer_name_loading

    num_loras = len(lora.loaded_loras)
    if text_model_encoder_counter == -1:
        # -1 comes from reset_counters(): start just past the last cond row (uc head).
        text_model_encoder_counter = len(prompt_loras) * num_loras

    tmp_check_loras = [] # names of LoRAs already applied in this call
    tmp_check_loras.clear()

    for m_lora in lora.loaded_loras:
        module = m_lora.modules.get(lora_layer_name, None)
        if module is None:
            # This LoRA has no weights for this layer. Delegate to the LyCORIS
            # helper ("fix the lyCORIS issue"); what it does is defined in
            # composable_lycoris and is not visible here.
            composable_lycoris.check_lycoris_end_layer(lora_layer_name, res, num_loras)
            continue
        current_lora = m_lora.name
        lora_already_used = False
        for check_lora in tmp_check_loras:
            if current_lora == check_lora:
                # Same LoRA name already seen in this call.
                lora_already_used = True
                break
        # Remember this LoRA as applied.
        tmp_check_loras.append(current_lora)
        # Skip a LoRA that was already applied to this layer.
        if lora_already_used == True:
            continue

        # Patch and alpha come from the LyCORIS-aware helpers.
        patch = composable_lycoris.get_lora_patch(module, input, res)
        alpha = composable_lycoris.get_lora_alpha(module, 1.0)
        # NOTE(review): prompt_loras already holds num_batches copies of every
        # sub-prompt (see load_prompt_loras), so this is sub-prompts * num_batches.
        num_prompts = len(prompt_loras)

        # print(f"lora.name={m_lora.name} lora.mul={m_lora.multiplier} alpha={alpha} pat.shape={patch.shape}")

        if enabled:
            if lora_layer_name.startswith("transformer_"):  # "transformer_text_model_encoder_"
                # The counter says which prompt_loras row this encoder pass belongs to.
                if 0 <= text_model_encoder_counter // num_loras < len(prompt_loras):
                    # c
                    loras = prompt_loras[text_model_encoder_counter // num_loras]
                    multiplier = loras.get(m_lora.name, 0.0)
                    if multiplier != 0.0:
                        # print(f"c #{text_model_encoder_counter // num_loras} lora.name={m_lora.name} mul={multiplier} lora_layer_name={lora_layer_name}")
                        res += multiplier * alpha * patch
                else:
                    # uc
                    if (opt_uc_text_model_encoder or (is_single_block and (not opt_single_no_uc))) and m_lora.multiplier != 0.0:
                        # print(f"uc #{text_model_encoder_counter // num_loras} lora.name={m_lora.name} lora.mul={m_lora.multiplier} lora_layer_name={lora_layer_name}")
                        res += m_lora.multiplier * alpha * patch

                if lora_layer_name.endswith("_11_mlp_fc2"):  # last lora_layer_name of text_model_encoder
                    # NOTE(review): incremented once per LoRA that reaches here.
                    # LoRAs skipped above (missing module / duplicate name) do not
                    # advance it; verify check_lycoris_end_layer compensates.
                    text_model_encoder_counter += 1
                    # c1 c1 c2 c2 .. .. uc uc
                    if text_model_encoder_counter == (len(prompt_loras) + num_batches) * num_loras:
                        text_model_encoder_counter = 0

            elif lora_layer_name.startswith("diffusion_model_"):  # "diffusion_model_"

                # NOTE(review): because num_prompts already includes num_batches,
                # this equality only holds when num_batches == 1. Larger batches
                # fall through to the split-shape path below. Confirm intended.
                if res.shape[0] == num_batches * num_prompts + num_batches:
                    # tensor.shape[1] == uncond.shape[1]
                    # Layout assumed here: all cond rows first, then the uncond rows.
                    tensor_off = 0
                    uncond_off = num_batches * num_prompts
                    for b in range(num_batches):
                        # c
                        for p, loras in enumerate(prompt_loras):
                            multiplier = loras.get(m_lora.name, 0.0)
                            if opt_composable_with_step:
                                prompt_block_id = p
                                lora_controller = lora_controllers[prompt_block_id]
                                multiplier = composable_lora_step.check_lora_weight(lora_controller, m_lora.name, step_counter, num_steps)
                            if multiplier != 0.0:
                                # print(f"tensor #{b}.{p} lora.name={m_lora.name} mul={multiplier} lora_layer_name={lora_layer_name}")
                                res[tensor_off] += multiplier * alpha * patch[tensor_off]
                            tensor_off += 1

                        # uc
                        if (opt_uc_diffusion_model or (is_single_block and (not opt_single_no_uc))) and m_lora.multiplier != 0.0:
                            # print(f"uncond lora.name={m_lora.name} lora.mul={m_lora.multiplier} lora_layer_name={lora_layer_name}")
                            multiplier = m_lora.multiplier
                            if is_single_block and opt_composable_with_step:
                                multiplier = composable_lora_step.check_lora_weight(full_controllers, m_lora.name, step_counter, num_steps)
                            res[uncond_off] += multiplier * alpha * patch[uncond_off]
                        uncond_off += 1
                else:
                    # tensor.shape[1] != uncond.shape[1]
                    # Cond and uncond arrive in separate calls. diffusion_model_counter
                    # gives the index of the first prompt_loras row in this chunk.
                    cur_num_prompts = res.shape[0]
                    base = (diffusion_model_counter // cur_num_prompts) // num_loras * cur_num_prompts
                    prompt_len = len(prompt_loras)
                    if 0 <= base < len(prompt_loras):
                        # c
                        for off in range(cur_num_prompts):
                            if base + off < prompt_len:
                                loras = prompt_loras[base + off]
                                multiplier = loras.get(m_lora.name, 0.0)
                                if opt_composable_with_step:
                                    prompt_block_id = base + off
                                    lora_controller = lora_controllers[prompt_block_id]
                                    multiplier = composable_lora_step.check_lora_weight(lora_controller, m_lora.name, step_counter, num_steps)
                                if multiplier != 0.0:
                                    # print(f"c #{base + off} lora.name={m_lora.name} mul={multiplier} lora_layer_name={lora_layer_name}")
                                    res[off] += multiplier * alpha * patch[off]
                    else:
                        # uc
                        if (opt_uc_diffusion_model or (is_single_block and (not opt_single_no_uc))) and m_lora.multiplier != 0.0:
                            # print(f"uc {lora_layer_name} lora.name={m_lora.name} lora.mul={m_lora.multiplier}")
                            multiplier = m_lora.multiplier
                            if is_single_block and opt_composable_with_step:
                                multiplier = composable_lora_step.check_lora_weight(full_controllers, m_lora.name, step_counter, num_steps)
                            res += multiplier * alpha * patch

                    if lora_layer_name.endswith("_11_1_proj_out"):  # last lora_layer_name of diffusion_model
                        diffusion_model_counter += cur_num_prompts
                        # c1 c2 .. uc
                        if diffusion_model_counter >= (len(prompt_loras) + num_batches) * num_loras:
                            diffusion_model_counter = 0
                            # NOTE(review): steps are only advanced on this
                            # split-shape path. The equal-shape path above never
                            # calls add_step_counters(); verify the caller
                            # advances steps in that case.
                            add_step_counters()
            else:
                # default
                if m_lora.multiplier != 0.0:
                    # print(f"default {lora_layer_name} lora.name={m_lora.name} lora.mul={m_lora.multiplier}")
                    res += m_lora.multiplier * alpha * patch
        else:
            # default
            if m_lora.multiplier != 0.0:
                # print(f"DEFAULT {lora_layer_name} lora.name={m_lora.name} lora.mul={m_lora.multiplier}")
                res += m_lora.multiplier * alpha * patch
    return res


def lora_Linear_forward(self, input):
    """Forward for a Linear layer that runs the original forward, then lora_forward.

    If the weight is on the CPU while the input is on CUDA, the weight is first
    moved to the GPU. The original forward is reached through
    ``torch.nn.Linear_forward_before_lora``. That attribute is not defined in
    this file; presumably the saved original method is installed by the caller.
    Confirm.
    """
    if (not self.weight.is_cuda) and input.is_cuda: # weight and input are on different devices (CPU vs GPU)
        self_weight_cuda = self.weight.cuda() # GPU copy of the weight
        to_del = self.weight
        self.weight = None                    # drop the CPU weight
        del to_del
        del self.weight                       # remove the registered attribute so the plain tensor can be assigned (original note: avoids a PyTorch 2.0 exception)
        self.weight = self_weight_cuda        # install the GPU weight as self.weight
    return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))


def lora_Conv2d_forward(self, input):
    """Conv2d counterpart of lora_Linear_forward.

    Applies the same CPU-to-GPU weight fix-up, runs
    ``torch.nn.Conv2d_forward_before_lora`` (not defined in this file), and passes
    its result through lora_forward.
    """
    if (not self.weight.is_cuda) and input.is_cuda:
        self_weight_cuda = self.weight.cuda()
        to_del = self.weight
        self.weight = None
        del to_del
        del self.weight # avoid "cannot assign XXX as parameter YYY (torch.nn.Parameter or None expected)"
        self.weight = self_weight_cuda
    return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))


def should_reload() -> bool:
    """Return True when torch's version is >= 2.0, or when it cannot be parsed.

    Only the leading "major.minor" number of ``torch.__version__`` is read and
    compared as a float. The original note says PyTorch 2.0 "should reload";
    what gets reloaded is decided by the caller.
    """
    # pytorch 2.0 should reload
    match = re.search(r"\d+(\.\d+)?",str(torch.__version__))
    if not match:
        return True
    ver = float(match.group(0))
    return ver >= 2.0


# --- Options; presumably set by the calling script, not assigned elsewhere in this file ---
# When False, lora_forward applies each LoRA's own multiplier everywhere.
enabled = False
# Use the [A:B:N] step controllers to compute per-step weights.
opt_composable_with_step = False
# Apply the LoRA multiplier to the uncond pass of the text encoder / diffusion model.
opt_uc_text_model_encoder = False
opt_uc_diffusion_model = False
# Log per-step weights (log_lora) for plotting (plot_lora).
opt_plot_lora_weight = False
# For single-block prompts, stop applying the LoRA to uncond implicitly.
opt_single_no_uc = False
# NOTE(review): not read in this file.
verbose = True

# --- Plot state ---
# Column order of the plot (one LoRA name per column).
drawing_lora_names : List[str] = []
# One row of weights per logged step.
drawing_data : List[List[float]] = []
# Row captured on the first lora_forward() call; prepended by plot_lora().
drawing_lora_first_index : List[float] = []
# False until lora_forward() has run once after load_prompt_loras().
first_log_drawing : bool = False

# --- Prompt / generation state ---
# True when the prompt contains no AND (exactly one block).
is_single_block : bool = False
# Batch size and step count; presumably set by the caller before generation.
num_batches: int = 0
num_steps: int = 20
# One {lora_name: multiplier} dict per sub-prompt, repeated num_batches times.
prompt_loras: List[Dict[str, float]] = []
# Layer-pass counters lora_forward uses to locate the current cond/uncond chunk.
text_model_encoder_counter: int = -1
diffusion_model_counter: int = 0
step_counter: int = 0
# NOTE(review): set in several places but not read in this file.
should_print : bool = True
# Raw sub-prompt text per block, repeated num_batches times.
prompt_blocks: List[str] = []
# Step controllers per sub-prompt (repeated num_batches times) and all of them concatenated.
lora_controllers: List[List[composable_lora_step.LoRA_Controller_Base]] = []
full_controllers: List[composable_lora_step.LoRA_Controller_Base] = []