from typing import Callable, List, Optional, Tuple, Union
import re
import ast
import copy
import json
import math
import sys
import traceback
import random
from types import ModuleType

from modules import extra_networks

# Splits a prompt into the sub-prompts of "A AND B" composition.
re_AND = re.compile(r"\bAND\b")


class Runable:
    """Like ``exec()`` but can return the value of a trailing expression.

    The code is split into "every statement but the last" (run with ``exec``)
    and the last statement, which is run with ``eval`` when it is a bare
    expression so that its value can be returned.
    Based on https://stackoverflow.com/a/52361938/5862977
    """

    def __init__(self, code: str, code_name: str = ""):
        self.code = code
        self.code_name = code_name
        self.compiled = False
        try:
            self.compile_self()
        except Exception:
            # Deliberately deferred: run() calls compile_self() again and lets
            # the error propagate to its caller there.
            pass

    def _ends_with_expression(self) -> bool:
        """Return True when the last top-level statement is a bare expression."""
        return bool(self.last_ast.body) and isinstance(self.last_ast.body[0], ast.Expr)

    def compile_self(self):
        """Parse the code and compile its prefix and its final statement separately."""
        self.code_ast = ast.parse(self.code, self.code_name)
        self.init_ast = copy.deepcopy(self.code_ast)
        self.init_ast.body = self.code_ast.body[:-1]
        self.last_ast = copy.deepcopy(self.code_ast)
        self.last_ast.body = self.code_ast.body[-1:]
        self.full_bin = compile(self.code_ast, self.code_name, "exec")
        self.start_bin = compile(self.init_ast, self.code_name, "exec")
        if self._ends_with_expression():
            self.run_bin = compile(self.convertExpr2Expression(self.last_ast.body[0]), self.code_name, "eval")
        else:
            # Also covers empty code: an empty module compiles to a no-op.
            self.end_bin = compile(self.last_ast, self.code_name, "exec")
        self.compiled = True

    def convertExpr2Expression(self, expr: ast.Expr):
        """Wrap the value of an ``ast.Expr`` statement in an ``ast.Expression`` for eval mode."""
        expr.lineno = 0
        expr.col_offset = 0
        # ast.Expression has no location attributes; passing lineno/col_offset
        # to its constructor is deprecated since Python 3.13.
        return ast.Expression(body=expr.value)

    def run(self, module):
        """Run the code in ``module``'s namespace.

        Returns the value of the final statement when it is an expression,
        otherwise None.
        """
        if not self.compiled:
            self.compile_self()
        namespace = module.__dict__
        if self.init_ast.body:
            exec(self.start_bin, namespace)
        if self._ends_with_expression():
            return eval(self.run_bin, namespace)
        exec(self.end_bin, namespace)
        return None


class LoRA_data:
    """A LoRA reference: its name (``"<type>:<name>"``) and its multiplier."""

    def __init__(self, name: str, weight: float):
        self.name = name
        self.weight = weight

    def __repr__(self):
        return f"LoRA:{self.name}:{self.weight}"

    def __str__(self):
        return f"LoRA:{self.name}:{self.weight}"


class LoRA_Weight_CMD:
    """Base weight controller: returns the weight unchanged."""

    def getWeight(self, weight: float, progress: float, step: int, all_step: int, custom_scope):
        return weight


class LoRA_Weight_decrement(LoRA_Weight_CMD):
    """Scale the weight linearly from full (progress 0) down to 0 (progress 1)."""

    def getWeight(self, weight: float, progress: float, step: int, all_step: int, custom_scope):
        return weight * (1 - progress)


class LoRA_Weight_increment(LoRA_Weight_CMD):
    """Scale the weight linearly from 0 (progress 0) up to full (progress 1)."""

    def getWeight(self, weight: float, progress: float, step: int, all_step: int, custom_scope):
        return weight * progress


def raise_(ex):
    """Raise ``ex``; lets a lambda raise an exception."""
    raise ex


def not_allow(name):
    """Return a stand-in for the builtin ``name`` that always raises.

    The stand-in accepts any arguments so that a call like ``eval("1")``
    reports this message instead of a TypeError about the lambda's arity.
    """
    return lambda *args, **kwargs: raise_(Exception(f'function {name} is not allow in LoRA Controller'))


# Names available to #cmd(...) weight commands.
LoRA_Weight_eval_scope = {
    "abs": abs,
    "ceil": math.ceil,
    "floor": math.floor,
    "trunc": math.trunc,
    "fmod": math.fmod,
    "gcd": math.gcd,
    "lcm": math.lcm,
    "perm": math.perm,
    "comb": math.comb,
    "gamma": math.gamma,
    "sqrt": math.sqrt,
    # Real cube root: pow(x, 1/3) would return a complex number for x < 0.
    "cbrt": lambda x: math.copysign(abs(x) ** (1.0 / 3.0), x),
    "exp": math.exp,
    "pow": math.pow,
    "log": math.log,
    "ln": math.log,
    "log2": math.log2,
    "log10": math.log10,
    # Clamp into [0, 1].
    "clamp": lambda x: 1.0 if x > 1 else (0.0 if x < 0 else x),
    "asin": lambda x: (math.acos(1.0 - x * 2.0) + 2.0 * math.pi) / (2.0 * math.pi),
    "acos": lambda x: (math.acos(x * 2.0 - 1.0) + 2.0 * math.pi) / (2.0 * math.pi),
    "atan": lambda x: (math.atan(x) + math.pi) / (2.0 * math.pi),
    # sin/cos/tan take a phase where 1.0 is one full period; sin and cos are
    # additionally rescaled into [0, 1].
    "sin": lambda x: (math.sin(x * 2.0 * math.pi - (math.pi / 2.0)) + 1.0) / 2.0,
    "cos": lambda x: (math.sin(x * 2.0 * math.pi + (math.pi / 2.0)) + 1.0) / 2.0,
    "tan": lambda x: math.tan(x * 2.0 * math.pi),
    # Plain radian-based functions.
    "sinr": math.sin,
    "cosr": math.cos,
    "tanr": math.tan,
    "asinr": math.asin,
    "acosr": math.acos,
    "atanr": math.atan,
    "sinh": math.sinh,
    "cosh": math.cosh,
    "tanh": math.tanh,
    "asinh": math.asinh,
    "acosh": math.acosh,
    "atanh": math.atanh,
    "abssin": lambda x: abs(math.sin(x * 2 * math.pi)),
    "abscos": lambda x: abs(math.cos(x * 2 * math.pi)),
    "random": random.random,
    "pi": math.pi,
    "nan": math.nan,
    "inf": math.inf,
    # Not allowed functions. NOTE(security): shadowing these names is NOT a
    # sandbox (e.g. the import statement does not consult a global __import__),
    # so commands must be treated as trusted code.
    "eval": not_allow("eval"),
    "exec": not_allow("exec"),
    "compile": not_allow("compile"),
    "breakpoint": not_allow("breakpoint"),
    "__import__": not_allow("__import__"),
}


def _eval_step_scope(weight: float, progress: float, step: int, all_step: int) -> dict:
    """Return the per-call variables visible to a #cmd(...) command."""
    return {
        # A command sets this to True to have its result applied when step == -1 too.
        "enable_prepare_step": False,
        "weight": weight,
        "life": progress if step != -1 else 0,
        "step": step,
        "steps": all_step,
        "warmup": lambda x: progress / x if progress < x else 1.0,
        "cooldown": lambda x: (1 - progress) / (1 - x) if progress > x else 1.0,
    }


class LoRA_Weight_eval(LoRA_Weight_CMD):
    """Weight controller driven by user Python code from ``#cmd(...)``.

    The command's final expression is converted to float and multiplied by the
    LoRA weight. On any error the traceback is printed once per instance and
    the unmodified weight is used.
    """

    def __init__(self, command: str, code_name: str = ""):
        self.command = command
        self.is_error = False
        self.module = ModuleType("module_in_prompt")
        self.module.__dict__.update(globals())
        self.module.__dict__.update(LoRA_Weight_eval_scope)
        self.bin = Runable(self.command, code_name)

    def getWeight(self, weight: float, progress: float, step: int, all_step: int, custom_scope):
        namespace = self.module.__dict__
        # Re-apply every provided name on each call so a previous run of the
        # command that reassigned one of them (e.g. ``abs = 1``) is undone.
        # Per-call values go into a fresh dict instead of mutating the
        # module-level LoRA_Weight_eval_scope.
        namespace.update(globals())
        namespace.update(LoRA_Weight_eval_scope)
        namespace.update(_eval_step_scope(weight, progress, step, all_step))
        namespace.update(custom_scope)
        try:
            value = self.bin.run(self.module)
            try:
                result = float(value) * weight
            except Exception:
                raise Exception(
                    f"LoRA Controller command result must be a numble, but got {type(value)}")
            if math.isnan(result):
                raise Exception("Can not apply a NaN weight to LoRA.")
            if math.isinf(result):
                raise Exception("Can not apply a infinity weight to LoRA.")
        except Exception:
            if not self.is_error:
                print(f"CommandError: {self.command}")
                traceback.print_exc()
                self.is_error = True
            return weight
        if step == -1 and not namespace.get("enable_prepare_step", False):
            return weight
        return result

    def __repr__(self):
        return f"LoRA_Weight_eval:{self.command}"

    def __str__(self):
        return f"LoRA_Weight_eval:{self.command}"


def _resolve_step_window(start: float, end: float, all_step: int) -> Tuple[float, float]:
    """Convert start/end settings into absolute step numbers.

    Values below 1 are fractions of ``all_step``; an end that is still negative
    after scaling means "until ``all_step``".
    """
    if start < 1:
        start = start * all_step
    if end < 1:
        end = end * all_step
    if end < 0:
        end = all_step
    return start, end


def _window_progress(step: int, start: float, end: float) -> float:
    """Return how far ``step`` is through [start, end]; 0.0 for a zero-length window."""
    if end == start:
        return 0.0
    return float(step - start) / float(end - start)


class LoRA_Controller_Base:
    """Base class for per-step LoRA weight controllers."""

    def __init__(self):
        self.base_weight = 1.0
        self.Weight_Controller = LoRA_Weight_CMD()

    def getWeight(self, weight: float, progress: float, step: int, all_step: int, custom_scope):
        """Apply the attached weight controller.

        For step == -1 only a LoRA_Weight_eval controller may change the
        weight; other controllers' results are ignored.
        """
        result = self.Weight_Controller.getWeight(weight, progress, step, all_step, custom_scope)
        if step == -1 and not isinstance(self.Weight_Controller, LoRA_Weight_eval):
            return weight
        return result

    def test(self, test_lora: str, step: int, all_step: int, custom_scope):
        return self.base_weight


# normal lora
class LoRA_Controller(LoRA_Controller_Base):
    """A LoRA active on every step."""

    def __init__(self, name: str, weight: float):
        super().__init__()
        self.name = name
        self.weight = float(weight)

    def test(self, test_lora: str, step: int, all_step: int, custom_scope):
        """Return this LoRA's weight at ``step`` if ``test_lora`` is it, else 0.0."""
        if test_lora != self.name:
            return 0.0
        # Guard against a zero step count instead of raising ZeroDivisionError.
        progress = float(step) / float(all_step) if all_step else 0.0
        return self.getWeight(self.weight, progress, step, all_step, custom_scope)

    def __repr__(self):
        return f"LoRA_Controller:{self.name}[weight={self.weight}]"

    def __str__(self):
        return f"LoRA_Controller:{self.name}[weight={self.weight}]"


# lora with start and end
class LoRA_StartEnd_Controller(LoRA_Controller_Base):
    """A LoRA active only between ``start`` and ``end`` (see _resolve_step_window)."""

    def __init__(self, name: str, weight: float, start: Union[float, int], end: Union[float, int]):
        super().__init__()
        self.name = name
        self.weight = float(weight)
        self.start = float(start)
        self.end = float(end)

    def test(self, test_lora: str, step: int, all_step: int, custom_scope):
        """Return this LoRA's weight at ``step`` if it is ``test_lora`` and inside its window, else 0.0."""
        if test_lora != self.name:
            return 0.0
        if step == -1:
            return self.getWeight(self.weight, -1, step, all_step, custom_scope)
        start, end = _resolve_step_window(self.start, self.end, all_step)
        if start <= step <= end:
            return self.getWeight(self.weight, _window_progress(step, start, end), step, all_step, custom_scope)
        return 0.0

    def __repr__(self):
        return f"LoRA_StartEnd_Controller:{self.name}[weight={self.weight},start at={self.start},end at={self.end}]"

    def __str__(self):
        return f"LoRA_StartEnd_Controller:{self.name}[weight={self.weight},start at={self.start},end at={self.end}]"


# switch lora
class LoRA_Switcher_Controller(LoRA_Controller_Base):
    """Alternates between the LoRAs in ``lora_dist``, one per step, inside a step window."""

    def __init__(self, lora_dist: List[LoRA_data], start: Union[float, int], end: Union[float, int]):
        super().__init__()
        self.lora_dist = lora_dist
        self.lora_list: List[str] = [lora_item.name for lora_item in self.lora_dist]
        self.start = float(start)
        self.end = float(end)

    def test(self, test_lora: str, step: int, all_step: int, custom_scope):
        """Return ``test_lora``'s weight if it is the LoRA selected for ``step`` (``step % count``), else 0.0."""
        lora_count = len(self.lora_dist)
        if lora_count == 0:
            # Nothing to switch between; also avoids ``step % 0``.
            return 0.0
        if step == -1 and test_lora in self.lora_list:
            index = self.lora_list.index(test_lora)
            return self.getWeight(self.lora_dist[index].weight, -1, step, all_step, custom_scope)
        slot = step % lora_count
        if test_lora != self.lora_list[slot]:
            return 0.0
        start, end = _resolve_step_window(self.start, self.end, all_step)
        if start <= step <= end:
            return self.getWeight(self.lora_dist[slot].weight, _window_progress(step, start, end),
                                  step, all_step, custom_scope)
        return 0.0

    def __repr__(self):
        return f"LoRA_Switcher_Controller:{self.lora_dist}[start at={self.start},end at={self.end}]"

    def __str__(self):
        return f"LoRA_Switcher_Controller:{self.lora_dist}[start at={self.start},end at={self.end}]"


def parse_step_rendering_syntax(prompt: str):
    """Build the LoRA controllers for each AND-separated sub-prompt of ``prompt``.

    Returns one list of controllers per sub-prompt: controllers derived from
    ``[...]`` step-rendering items, followed by plain LoRA_Controllers for the
    LoRAs outside any bucket.
    """
    lora_controllers: List[List[LoRA_Controller_Base]] = []
    for subprompt in re_AND.split(escape_prompt(prompt)):
        step_rendering_list, pure_loratext = get_all_step_rendering_in_prompt(subprompt)
        controllers: List[LoRA_Controller_Base] = []
        for item in step_rendering_list:
            controllers += get_LoRA_Controllers(item)
        controllers += [LoRA_Controller(lora_item.name, lora_item.weight)
                        for lora_item in get_lora_list(pure_loratext)]
        lora_controllers.append(controllers)
    return lora_controllers


def check_lora_weight(controllers: List[LoRA_Controller_Base], test_lora: str, step: int, all_step: int, custom_scope):
    """Return the weight of largest magnitude that any controller gives ``test_lora`` (0.0 if none)."""
    weights = [controller.test(test_lora, step, all_step, custom_scope) for controller in controllers]
    # max() keeps the first of several equal magnitudes, as a strict ">" scan seeded with 0.0 would.
    return max([0.0] + weights, key=abs)


def get_lora_list(prompt: str):
    """Return the lora/lyco networks referenced in ``prompt``.

    Names are prefixed with their type (``"lora:name"``); the multiplier
    defaults to 1.0. Returns ``[LoRA_data("", 0.0)]`` when there are none.
    """
    result: List[LoRA_data] = []
    _, extra_network_data = extra_networks.parse_prompt(prompt)
    for m_type in ('lora', 'lyco'):
        if m_type not in extra_network_data.keys():
            continue
        for params in extra_network_data[m_type]:
            name = params.items[0]
            multiplier = float(params.items[1]) if len(params.items) > 1 else 1.0
            result.append(LoRA_data(f"{m_type}:{name}", multiplier))
    return result or [LoRA_data("", 0.0)]


def get_or_list(prompt: str):
    """Split ``prompt`` into its ``|``-separated alternatives."""
    return prompt.split("|")


# [[A::end]:start] -> group 1: A, group 2: end, group 3: start
re_start_end = re.compile(r"\[\s*\[\s*([^\:\]]+)\:\s*\:([^\]]+)\]\s*\:\s*([^\]]+)\]")
# [A:start] with a numeric start -> group 1: A, group 2: start
re_strat_at = re.compile(r"\[\s*([^\:\]]+)\:\s*([0-9\.]+)\s*\]")
# A [...] bucket; group 1 is its content up to the first ']'.
re_bucket_inside = re.compile(r"\[([^\]]+)\]")
# Extra-network tags such as <lora:name:weight>.
re_extra_net = re.compile(r"<([^>]+):([^>]+)>")
# Placeholder tokens for stashed text: $$PYTHON_OBJ$$n^ (the _x form also matches $$PYTHON_OBJX$$n^).
re_python_escape = re.compile(r"\$\$PYTHON_OBJ\$\$(\d+)\^")
re_python_escape_x = re.compile(r"\$\$PYTHON_OBJX?\$\$(\d+)\^")
# A [...] bucket containing no nested brackets.
re_sd_step_render = re.compile(r"\[[^\[\]]+\]")
# A "#command" (or its escaped form \u0023command); group 2 is the command text.
re_super_cmd = re.compile(r"(\\u0023|#)([^:#\[\]]+)")
# A backslash-escaped '[', ']', ':' or '\'.
re_escape_char = re.compile(r"\\([\[\]\:\\])")

# Replacement text (a literal backslash + "uXXXX") for each escaped prompt character.
_PROMPT_ESCAPES = {
    '[': '\\u005B',
    ']': '\\u005D',
    ':': '\\u003A',
    '\\': '\\u005C',
}


def escape_prompt(prompt: str):
    r"""Replace ``\[``, ``\]``, ``\:`` and ``\\`` with ``\uXXXX`` text so they stop acting as syntax."""
    def preprossing_escape(match_pt: re.Match):
        return _PROMPT_ESCAPES.get(str(match_pt.group(1)), str(match_pt.group(0)))
    return re.sub(re_escape_char, preprossing_escape, prompt)


class MySearchResult:
    """Minimal match result: ``group[i]`` is the text of regex group ``i``."""

    def __init__(self):
        self.group: List[str] = []


def _placeholder_stasher(table: List[str], tag: str) -> Callable[..., str]:
    """Return an ``re.sub`` callback that appends the match to ``table`` and
    replaces it with ``$$<tag>$$<index>^``."""
    def stash(match_pt: re.Match):
        table.append(str(match_pt.group(0)))
        return f"$${tag}$${len(table) - 1}^"
    return stash


def _placeholder_restorer(table: List[str]) -> Callable[..., str]:
    """Return an ``re.sub`` callback mapping a placeholder with index n to ``table[n]``.

    Out-of-range indices leave the placeholder text untouched.
    """
    def restore(match_pt: re.Match):
        index = int(match_pt.group(1))
        if index < len(table):
            return table[index]
        return str(match_pt.group(0))
    return restore


def _protect_extra_nets(input_str: str) -> Tuple[str, Callable[[str], str]]:
    """Hide extra-network tags and pre-existing placeholder text behind tokens.

    Returns the protected text and a function that restores the original text
    inside any fragment of it. Pre-existing ``$$PYTHON_OBJ(X)$$n^`` text is
    stashed first so it cannot be mistaken for the tokens created here.
    """
    extra_net_list: List[str] = []
    escape_obj_list: List[str] = []
    restore_extra_net = _placeholder_restorer(extra_net_list)
    restore_escape = _placeholder_restorer(escape_obj_list)

    def restore(fragment: str) -> str:
        fragment = re.sub(re_python_escape, restore_extra_net, fragment)
        return re.sub(re_python_escape_x, restore_escape, fragment)

    txt = re.sub(re_python_escape_x, _placeholder_stasher(escape_obj_list, "PYTHON_OBJX"), input_str)
    txt = re.sub(re_extra_net, _placeholder_stasher(extra_net_list, "PYTHON_OBJ"), txt)
    return txt, restore


def extra_net_split(input_str: str, pattern: str):
    """Split ``input_str`` on ``pattern`` without splitting inside ``<type:...>`` tags."""
    txt, restore = _protect_extra_nets(input_str)
    result = [restore(str(part)) for part in txt.split(pattern)]
    return result or [input_str]


def extra_net_re_search(pattern: Union[str, re.Pattern[str]], input_str: str) -> Optional[MySearchResult]:
    """``re.search`` that treats ``<type:...>`` tags as opaque.

    Returns a MySearchResult with the text of every group (``"None"`` for a
    group that did not participate), or None if there is no match.
    """
    txt, restore = _protect_extra_nets(input_str)
    match = re.search(pattern, txt)
    if match is None:
        return None
    result = MySearchResult()
    result.group = [restore(str(match.group(i))) for i in range(match.re.groups + 1)]
    return result


def unescape_string(input_string: str):
    r"""Decode backslash escapes in ``input_string``.

    ``\u``/``\x`` sequences are kept during the scan and then decoded by a
    final JSON round-trip; other escapes JSON understands (``\n``, ``\"``, ...)
    are decoded directly, and an unknown escape yields the escaped character.
    If the JSON round-trip fails (e.g. because of ``\x``), the text decoded so
    far is returned instead of raising.
    """
    result = ''
    unicode_list = ['u', 'x']
    i = 0
    while i < len(input_string):
        current_char = input_string[i]
        if current_char == '\\':
            i += 1
            if i >= len(input_string):
                break
            string_body = input_string[i]
            if string_body.lower() in unicode_list:
                result += f"{current_char}{string_body}"
            else:
                char_added = False
                try:
                    unescaped = json.loads(f"\"{current_char}{string_body}\"")
                    if unescaped:
                        result += unescaped
                        char_added = True
                except Exception:
                    pass
                if not char_added:
                    result += string_body
        else:
            result += current_char
        i += 1
    try:
        return str(json.loads(json.dumps(result, indent=4).replace("\\\\", "\\")))
    except ValueError:
        return result


def _parse_weight_controller(prompt: str) -> LoRA_Weight_CMD:
    """Return the weight controller selected by a ``#...`` command in ``prompt``."""
    super_cmd = re.search(re_super_cmd, prompt)
    if not super_cmd:
        return LoRA_Weight_CMD()
    super_cmd_text = unescape_string(super_cmd.group(2)).strip()
    if super_cmd_text.startswith("cmd("):
        # Strip the leading "cmd(" and the trailing character (expected ")").
        return LoRA_Weight_eval(super_cmd_text[4:-1], f", at {re.sub(re_super_cmd, '', prompt)}")
    if super_cmd_text.startswith("decrease"):
        return LoRA_Weight_decrement()
    if super_cmd_text.startswith("increment"):
        return LoRA_Weight_increment()
    return LoRA_Weight_CMD()


def _controllers_for_or_text(or_text: str, start, end) -> List[LoRA_Controller_Base]:
    """Build controllers for the LoRAs in ``or_text``, active from ``start`` to ``end``.

    With one alternative, each LoRA gets a LoRA_StartEnd_Controller. With several
    ``|``-separated alternatives, the i-th LoRA of every alternative forms one
    LoRA_Switcher_Controller (shorter alternatives are padded with an empty
    LoRA). If start/end are not numbers, no controllers are built.
    """
    try:
        start_at = float(start)
        end_at = float(end)
    except (TypeError, ValueError):
        return []
    or_list = get_or_list(or_text)
    if len(or_list) == 1:
        return [LoRA_StartEnd_Controller(lora_item.name, lora_item.weight, start_at, end_at)
                for lora_item in get_lora_list(or_list[0])]
    lora_lists = [get_lora_list(or_block) for or_block in or_list]
    max_len = max(len(lora_list) for lora_list in lora_lists)
    controllers: List[LoRA_Controller_Base] = []
    for i in range(max_len):
        row = [lora_list[i] if i < len(lora_list) else LoRA_data("", 0.0) for lora_list in lora_lists]
        controllers.append(LoRA_Switcher_Controller(row, start_at, end_at))
    return controllers


def get_LoRA_Controllers(prompt: str):
    """Build the controllers for one step-rendering item such as ``[A:B:0.5]``.

    Recognised forms, tried in order: ``[[A::end]:start]``, ``[A:start]``, then
    a generic bucket ``[A:B:when]`` / ``[A|B]`` / ``[A#cmd]``. Every returned
    controller shares the weight controller chosen by a ``#...`` command.
    """
    weight_controller = _parse_weight_controller(prompt)

    def set_Weight_Controller(controller_list: List[LoRA_Controller_Base], the_controller: LoRA_Weight_CMD):
        for controller in controller_list:
            controller.Weight_Controller = the_controller
        return controller_list

    result = extra_net_re_search(re_start_end, prompt)
    if result:
        return set_Weight_Controller(
            _controllers_for_or_text(result.group[1], result.group[3], result.group[2]), weight_controller)

    result = extra_net_re_search(re_strat_at, prompt)
    if result:
        return set_Weight_Controller(
            _controllers_for_or_text(result.group[1], result.group[2], -1.0), weight_controller)

    result_list: List[LoRA_Controller_Base] = []
    result = extra_net_re_search(re_bucket_inside, prompt)
    if result:
        bucket_inside = result.group[1]
        split_by_colon = extra_net_split(bucket_inside, ":")
        if len(split_by_colon) == 1 and ("|" in bucket_inside or "#" in bucket_inside):
            # Treat "[A|B]" / "[A#cmd]" as "[A::-1]": A for the whole generation.
            split_by_colon.append('')
            split_by_colon.append('-1')
        if len(split_by_colon) > 2:
            # "[from:to:when]": "from" until "when", then "to" until the end.
            result_list += _controllers_for_or_text(split_by_colon[0], 0.0, split_by_colon[2])
            result_list += _controllers_for_or_text(split_by_colon[1], split_by_colon[2], -1.0)
    return set_Weight_Controller(result_list, weight_controller)


def get_all_step_rendering_in_prompt(input_prompt: str):
    """Extract the top-level ``[...]`` step-rendering items of ``input_prompt``.

    Returns ``(items, text)``: ``items`` are the outermost bucket texts (with
    nested buckets restored), and ``text`` is the prompt with each top-level
    bucket replaced by a ``$$PYTHON_OBJ$$n^`` token.
    """
    read_rendering_item_list: List[str] = []
    escape_obj_list: List[str] = []
    restore_item = _placeholder_restorer(read_rendering_item_list)
    restore_escape = _placeholder_restorer(escape_obj_list)

    def unstrip_rendering_text(text: str) -> str:
        # Each item only references items stashed before it, so expanding to a
        # fixpoint terminates.
        previous: Optional[str] = None
        while previous != text:
            previous = text
            text = re.sub(re_python_escape, restore_item, text)
        # Stashed literal text is restored exactly once: re-expanding it could
        # loop forever when two literal placeholders refer to each other.
        return re.sub(re_python_escape_x, restore_escape, text)

    txt = re.sub(re_python_escape_x, _placeholder_stasher(escape_obj_list, "PYTHON_OBJX"), input_prompt)
    stash_item = _placeholder_stasher(read_rendering_item_list, "PYTHON_OBJ")
    # Replace innermost buckets repeatedly until no brackets are left to fold.
    previous_txt: Optional[str] = None
    while previous_txt != txt:
        previous_txt = txt
        txt = re.sub(re_sd_step_render, stash_item, txt)

    rendering_item_list = [unstrip_rendering_text(str(match_pt.group(0)))
                           for match_pt in re_python_escape.finditer(txt)]
    return rendering_item_list, txt