# stable-diffusion-webui-comp.../composable_lora_step.py
#
# 632 lines
# 26 KiB
# Python

from typing import List, Union
import re
import ast
import copy
import json
import math
import sys
import traceback
import random
from modules import extra_networks
# Splits a prompt into its AND-separated sub-prompts (used by parse_step_rendering_syntax).
# \b keeps words that merely contain "AND" intact; matching is case-sensitive.
re_AND = re.compile(r"\bAND\b")
class Runable:
    """Compile a code snippet and run it like exec(), but return the value of a trailing expression.

    Adapted from https://stackoverflow.com/a/52361938/5862977: every statement but the
    last is executed with exec; the last one is evaluated with eval when it is an
    expression (so its value can be returned) and executed with exec otherwise.

    NOTE(review): this executes arbitrary Python code; only feed it trusted snippets.
    """

    def __init__(self, code: str, code_name: str = "<prompt>"):
        self.code = code
        self.code_name = code_name
        self.compiled = False
        try:
            self.compile_self()
        except Exception:
            # Deferred on purpose: run() retries compile_self() and lets the error surface there.
            pass

    def compile_self(self):
        """Parse and compile the snippet.

        Raises:
            SyntaxError: the snippet is not valid Python.
            ValueError: the snippet contains no statements.
        """
        self.code_ast = ast.parse(self.code, self.code_name)
        if not self.code_ast.body:
            raise ValueError(f"{self.code_name}: nothing to run (empty command)")
        self.init_ast = copy.deepcopy(self.code_ast)
        self.init_ast.body = self.code_ast.body[:-1]
        self.last_ast = copy.deepcopy(self.code_ast)
        self.last_ast.body = self.code_ast.body[-1:]
        self.full_bin = compile(self.code_ast, self.code_name, "exec")
        self.start_bin = compile(self.init_ast, self.code_name, "exec")
        if isinstance(self.last_ast.body[0], ast.Expr):
            self.run_bin = compile(self.convertExpr2Expression(self.last_ast.body[0]), self.code_name, "eval")
        else:
            self.end_bin = compile(self.last_ast, self.code_name, "exec")
        self.compiled = True

    def convertExpr2Expression(self, expr: ast.Expr):
        """Wrap an expression statement's value in an ast.Expression so it compiles in "eval" mode."""
        # ast.Expression is a `mod` node without position attributes; passing lineno /
        # col_offset to its constructor is deprecated since Python 3.13.
        return ast.Expression(body=expr.value)

    def run(self, module):
        """Run the snippet with module.__dict__ as globals.

        Returns the value of the last statement when it is an expression, else None.
        """
        if not self.compiled:
            self.compile_self()
        if self.init_ast.body:
            exec(self.start_bin, module.__dict__)
        if isinstance(self.last_ast.body[0], ast.Expr):
            return eval(self.run_bin, module.__dict__)
        exec(self.end_bin, module.__dict__)
        return None
class LoRA_data:
    """A LoRA reference taken from a prompt: its name (e.g. ``lora:foo``) and multiplier."""

    def __init__(self, name: str, weight: float):
        self.name = name
        self.weight = weight

    def __repr__(self):
        return "LoRA:{}:{}".format(self.name, self.weight)

    # str() and repr() render the same text.
    __str__ = __repr__
class LoRA_Weight_CMD:
    """Default weight controller: the weight passes through unchanged."""

    def getWeight(self, weight: float, progress: float, step: int, all_step: int, custom_scope):
        return weight


class LoRA_Weight_decrement(LoRA_Weight_CMD):
    """Scale the weight by (1 - progress): full at progress 0, zero at progress 1."""

    def getWeight(self, weight: float, progress: float, step: int, all_step: int, custom_scope):
        remaining = 1 - progress
        return weight * remaining


class LoRA_Weight_increment(LoRA_Weight_CMD):
    """Scale the weight by progress: zero at progress 0, full at progress 1."""

    def getWeight(self, weight: float, progress: float, step: int, all_step: int, custom_scope):
        scaled = weight * progress
        return scaled
def raise_(ex):
    """Raise *ex*; lets an expression (such as a lambda) raise an exception."""
    raise ex


def not_allow(name):
    """Return a replacement for builtin *name* that raises whenever it is called.

    The replacement accepts and ignores any arguments, so a call such as ``eval("1")``
    reports the intended message instead of a TypeError about the lambda's arity.
    """
    return lambda *args, **kwargs: raise_(Exception(f'function {name} is not allow in LoRA Controller'))


# Names available to #cmd(...) weight snippets (LoRA_Weight_eval copies this dict into
# the snippet's namespace). sin/cos/tan/abssin/abscos take a phase where 1.0 is one full
# period; asin/acos/atan are custom rescalings rather than the math inverses; the
# "r"-suffixed names are the plain radian functions from math.
# NOTE(review): the not_allow overrides are NOT a sandbox. Snippets also receive this
# module's globals (including __builtins__ and the imported modules), and `import`
# statements do not go through the overridden __import__ name. Treat #cmd(...) prompt
# text as trusted code.
LoRA_Weight_eval_scope = {
    "abs": abs,
    "ceil": math.ceil, "floor": math.floor, "trunc": math.trunc,
    "fmod": math.fmod,
    "gcd": math.gcd, "lcm": math.lcm,
    "perm": math.perm, "comb": math.comb, "gamma": math.gamma,
    # Real cube root: pow(x, 1/3) returns a complex number for negative x.
    "sqrt": math.sqrt, "cbrt": lambda x: math.copysign(abs(x) ** (1.0 / 3.0), x),
    "exp": math.exp, "pow": math.pow,
    "log": math.log, "ln": math.log, "log2": math.log2, "log10": math.log10,
    "clamp": lambda x: 1.0 if x > 1 else (0.0 if x < 0 else x),
    "asin": lambda x: (math.acos(1.0 - x * 2.0) + 2.0 * math.pi) / (2.0 * math.pi),
    "acos": lambda x: (math.acos(x * 2.0 - 1.0) + 2.0 * math.pi) / (2.0 * math.pi),
    "atan": lambda x: (math.atan(x) + math.pi) / (2.0 * math.pi),
    "sin": lambda x: (math.sin(x * 2.0 * math.pi - (math.pi / 2.0)) + 1.0) / 2.0,
    "cos": lambda x: (math.sin(x * 2.0 * math.pi + (math.pi / 2.0)) + 1.0) / 2.0,
    "tan": lambda x: math.tan(x * 2.0 * math.pi),
    "sinr": math.sin, "cosr": math.cos, "tanr": math.tan,
    "asinr": math.asin, "acosr": math.acos, "atanr": math.atan,
    "sinh": math.sinh, "cosh": math.cosh, "tanh": math.tanh,
    "asinh": math.asinh, "acosh": math.acosh, "atanh": math.atanh,
    "abssin": lambda x: abs(math.sin(x * 2 * math.pi)),
    "abscos": lambda x: abs(math.cos(x * 2 * math.pi)),
    "random": random.random,
    "pi": math.pi, "nan": math.nan, "inf": math.inf,
    # functions that are not allowed inside snippets
    "eval": not_allow("eval"),
    "exec": not_allow("exec"),
    "compile": not_allow("compile"),
    "breakpoint": not_allow("breakpoint"),
    "__import__": not_allow("__import__"),
}
class LoRA_Weight_eval(LoRA_Weight_CMD):
    """Weight controller driven by a user-written Python snippet (the ``#cmd(...)`` syntax).

    The snippet is compiled once with Runable and run on every call inside a private
    module whose namespace holds this file's globals, LoRA_Weight_eval_scope and the
    caller's custom_scope. Its numeric value is multiplied by the base weight.

    NOTE(review): the snippet is executed with exec/eval and can reach builtins and this
    module's globals, so #cmd(...) prompt text must be treated as trusted code.
    """

    def __init__(self, command: str, code_name: str = "<prompt>"):
        self.command = command
        # Set after the first failure so the traceback is printed only once per command.
        self.is_error = False
        from types import ModuleType
        self.module = ModuleType("module_in_prompt")
        self.module.__dict__.update(globals())
        self.module.__dict__.update(LoRA_Weight_eval_scope)
        self.bin = Runable(self.command, code_name)

    def getWeight(self, weight: float, progress: float, step: int, all_step: int, custom_scope):
        """Return the snippet's value times *weight*; *weight* itself on any error.

        At the prepare step (``step == -1``) *weight* is returned unless the snippet sets
        ``enable_prepare_step = True``.
        """
        # Callers pass a negative progress at the prepare step; snippets see life == 0
        # there, and warmup()/cooldown() use that same value so they stay consistent.
        life = progress if step != -1 else 0
        # Per-call names are written into the shared module-level scope dict, then
        # copied into this instance's namespace together with the caller's scope.
        LoRA_Weight_eval_scope["enable_prepare_step"] = False
        LoRA_Weight_eval_scope["weight"] = weight
        LoRA_Weight_eval_scope["life"] = life
        LoRA_Weight_eval_scope["step"] = step
        LoRA_Weight_eval_scope["steps"] = all_step
        LoRA_Weight_eval_scope["warmup"] = lambda x: life / x if life < x else 1.0
        LoRA_Weight_eval_scope["cooldown"] = lambda x: (1 - life) / (1 - x) if life > x else 1.0
        self.module.__dict__.update(globals())
        self.module.__dict__.update(LoRA_Weight_eval_scope)
        self.module.__dict__.update(custom_scope)
        try:
            raw = self.bin.run(self.module)
            try:
                result = float(raw) * weight
            except (TypeError, ValueError, OverflowError):
                raise Exception(
                    f"LoRA Controller command result must be a numble, but got {type(raw)}")
            if math.isnan(result):
                raise Exception("Can not apply a NaN weight to LoRA.")
            if math.isinf(result):
                raise Exception("Can not apply a infinity weight to LoRA.")
        except (Exception, SystemExit):
            # Best effort: report the first failure and fall back to the base weight.
            # SystemExit is caught so a stray exit() in a prompt cannot stop the host
            # process; KeyboardInterrupt is deliberately left to propagate.
            if not self.is_error:
                print(f"CommandError: {self.command}")
                traceback.print_exception(*sys.exc_info())
                self.is_error = True
            return weight
        if step == -1 and not self.module.__dict__["enable_prepare_step"]:
            return weight
        return result

    def __repr__(self):
        return f"LoRA_Weight_eval:{self.command}"

    def __str__(self):
        return f"LoRA_Weight_eval:{self.command}"
class LoRA_Controller_Base:
    """Common base for the per-LoRA step controllers.

    Subclasses override test(). getWeight() passes the weight through the attached
    Weight_Controller, an object exposing getWeight() (LoRA_Weight_CMD by default).
    """

    def __init__(self):
        self.base_weight = 1.0
        self.Weight_Controller = LoRA_Weight_CMD()

    def getWeight(self, weight: float, progress: float, step: int, all_step: int, custom_scope):
        """Return the controller-shaped weight.

        At step -1 only a LoRA_Weight_eval controller may change the weight; with any
        other controller the unmodified *weight* is returned.
        """
        shaped = self.Weight_Controller.getWeight(weight, progress, step, all_step, custom_scope)
        keep_unmodified = step == -1 and not isinstance(self.Weight_Controller, LoRA_Weight_eval)
        return weight if keep_unmodified else shaped

    def test(self, test_lora: str, step: int, all_step: int, custom_scope):
        """Default: report base_weight regardless of LoRA name and step."""
        return self.base_weight
# normal lora: written directly in the prompt, considered on every step
class LoRA_Controller(LoRA_Controller_Base):
    """Controller for a LoRA outside any step-rendering bracket."""

    def __init__(self, name: str, weight: float):
        super().__init__()
        self.name = name
        self.weight = float(weight)

    def test(self, test_lora: str, step: int, all_step: int, custom_scope):
        """Return this LoRA's weight for *step*, or 0.0 when *test_lora* is another LoRA."""
        if test_lora != self.name:
            return 0.0
        progress = float(step) / float(all_step)
        return self.getWeight(self.weight, progress, step, all_step, custom_scope)

    def __repr__(self):
        return f"LoRA_Controller:{self.name}[weight={self.weight}]"

    __str__ = __repr__
# lora with start and end
class LoRA_StartEnd_Controller(LoRA_Controller_Base):
    """Controller for a LoRA that is active only inside a step window.

    ``start``/``end`` values below 1 are fractions of the total step count; a negative
    ``end`` (e.g. -1) extends the window to the last step.
    """

    def __init__(self, name: str, weight: float, start: Union[float, int], end: Union[float, int]):
        super().__init__()
        self.name = name
        self.weight = float(weight)
        self.start = float(start)
        self.end = float(end)

    def test(self, test_lora: str, step: int, all_step: int, custom_scope):
        """Return this LoRA's weight for *step*; 0.0 outside the window or for another LoRA."""
        if test_lora != self.name:
            return 0.0
        if step == -1:
            # prepare step: the weight controller decides, with progress -1
            return self.getWeight(self.weight, -1, step, all_step, custom_scope)
        start = self.start * all_step if self.start < 1 else self.start
        end = self.end * all_step if self.end < 1 else self.end
        if end < 0:
            end = all_step
        if start <= step <= end:
            span = end - start
            # A zero-length window (start == end, e.g. "[A:B:0]" at step 0) has no
            # meaningful progress: report 0.0 instead of dividing by zero.
            progress = float(step - start) / span if span != 0 else 0.0
            return self.getWeight(self.weight, progress, step, all_step, custom_scope)
        return 0.0

    def __repr__(self):
        return f"LoRA_StartEnd_Controller:{self.name}[weight={self.weight},start at={self.start},end at={self.end}]"

    def __str__(self):
        return f"LoRA_StartEnd_Controller:{self.name}[weight={self.weight},start at={self.start},end at={self.end}]"
# switch lora
class LoRA_Switcher_Controller(LoRA_Controller_Base):
    """Alternate between several LoRAs, one per sampling step, inside a step window.

    On step ``s`` only ``lora_dist[s % len(lora_dist)]`` is active. ``start``/``end``
    follow LoRA_StartEnd_Controller: values below 1 are fractions of the total step
    count and a negative ``end`` means "until the last step".
    """

    def __init__(self, lora_dist: List[LoRA_data], start: Union[float, int], end: Union[float, int]):
        super().__init__()
        self.lora_dist = lora_dist
        self.lora_list: List[str] = [lora_item.name for lora_item in lora_dist]
        self.start = float(start)
        self.end = float(end)

    def test(self, test_lora: str, step: int, all_step: int, custom_scope):
        """Return the weight of *test_lora* for *step*, or 0.0 when it is not the active one."""
        lora_count = len(self.lora_dist)
        if step == -1 and test_lora in self.lora_list:
            # prepare step: every LoRA of the set is reported, progress -1
            prepared = self.lora_dist[self.lora_list.index(test_lora)]
            return self.getWeight(prepared.weight, -1, step, all_step, custom_scope)
        if lora_count == 0:
            return 0.0
        slot = step % lora_count
        if test_lora != self.lora_list[slot]:
            return 0.0
        start = self.start * all_step if self.start < 1 else self.start
        end = self.end * all_step if self.end < 1 else self.end
        if end < 0:
            end = all_step
        if start <= step <= end:
            span = end - start
            # Zero-length window: report progress 0.0 instead of dividing by zero.
            progress = float(step - start) / span if span != 0 else 0.0
            return self.getWeight(self.lora_dist[slot].weight, progress, step, all_step, custom_scope)
        return 0.0

    def __repr__(self):
        return f"LoRA_Switcher_Controller:{self.lora_dist}[start at={self.start},end at={self.end}]"

    def __str__(self):
        return f"LoRA_Switcher_Controller:{self.lora_dist}[start at={self.start},end at={self.end}]"
def parse_step_rendering_syntax(prompt: str):
    """Build the LoRA controllers for every AND-separated sub-prompt of *prompt*.

    Returns one list per sub-prompt: first the controllers produced by its
    step-rendering ``[...]`` items, then a plain LoRA_Controller for each LoRA left
    outside those brackets.
    """
    per_subprompt: List[List[LoRA_Controller_Base]] = []
    for subprompt in re_AND.split(escape_prompt(prompt)):
        rendering_items, remaining_text = get_all_step_rendering_in_prompt(subprompt)
        controllers: List[LoRA_Controller_Base] = []
        for item in rendering_items:
            controllers.extend(get_LoRA_Controllers(item))
        controllers.extend(
            LoRA_Controller(lora.name, lora.weight) for lora in get_lora_list(remaining_text)
        )
        per_subprompt.append(controllers)
    return per_subprompt
def check_lora_weight(controllers: "List[LoRA_Controller_Base]", test_lora: str, step: int, all_step: int, custom_scope):
    """Return the weight with the largest magnitude reported by *controllers* for *test_lora*.

    Starts from 0.0; on equal magnitudes the earliest controller wins.
    """
    strongest = 0.0
    reported = (c.test(test_lora, step, all_step, custom_scope) for c in controllers)
    for candidate in reported:
        if abs(strongest) < abs(candidate):
            strongest = candidate
    return strongest
def get_lora_list(prompt: str):
    """Return the ``lora``/``lyco`` extra networks in *prompt* as LoRA_data.

    Names are prefixed with their network type (``lora:name``); the multiplier defaults
    to 1.0. When none are found, a single placeholder ``LoRA_data("", 0.0)`` is returned.
    """
    _, networks = extra_networks.parse_prompt(prompt)
    found: List[LoRA_data] = [
        LoRA_data(
            f"{kind}:{params.items[0]}",
            float(params.items[1]) if len(params.items) > 1 else 1.0,
        )
        for kind in ("lora", "lyco")
        if kind in networks
        for params in networks[kind]
    ]
    return found or [LoRA_data("", 0.0)]
def get_or_list(prompt: str):
    """Return the ``|``-separated alternatives in *prompt* (always at least one item)."""
    separator = "|"
    alternatives = prompt.split(separator)
    return alternatives
# [[loras::END]:START] -> group 1 = loras, group 2 = END, group 3 = START
re_start_end = re.compile(r"\[\s*\[\s*([^\:\]]+)\:\s*\:([^\]]+)\]\s*\:\s*([^\]]+)\]")
# [loras:START] with a numeric START -> group 1 = loras, group 2 = START (name sic)
re_strat_at = re.compile(r"\[\s*([^\:\]]+)\:\s*([0-9\.]+)\s*\]")
# the contents of a [...] span (group 1)
re_bucket_inside = re.compile(r"\[([^\]]+)\]")
# an extra-network tag such as <lora:name:0.8>
re_extra_net = re.compile(r"<([^>]+):([^>]+)>")
# masking placeholders: $$PYTHON_OBJ$$<n>^ ; the _x variant also matches $$PYTHON_OBJX$$<n>^
re_python_escape = re.compile(r"\$\$PYTHON_OBJ\$\$(\d+)\^")
re_python_escape_x = re.compile(r"\$\$PYTHON_OBJX?\$\$(\d+)\^")
# an innermost [...] span (contains no other brackets)
re_sd_step_render = re.compile(r"\[[^\[\]]+\]")
# a super command: '#' or the literal text \u0023, then the command text (group 2)
re_super_cmd = re.compile(r"(\\u0023|#)([^:#\[\]]+)")
# a backslash-escaped '[', ']', ':' or '\' (group 1 = the escaped character)
re_escape_char = re.compile(r"\\([\[\]\:\\])")

# Replacement for each escapable character: the literal six-character text \uXXXX.
_PROMPT_ESCAPES = {"[": "\\u005B", "]": "\\u005D", ":": "\\u003A", "\\": "\\u005C"}


def escape_prompt(prompt: str):
    r"""Hide backslash-escaped ``[``, ``]``, ``:`` and ``\`` from the bracket parsers.

    Each ``\<c>`` pair becomes the literal text ``\uXXXX`` for ``<c>`` (a form that
    unescape_string decodes back to the character).
    """
    return re_escape_char.sub(lambda m: _PROMPT_ESCAPES.get(m.group(1), m.group(0)), prompt)


class MySearchResult:
    """Minimal match result: ``group[i]`` holds the text of regex group *i* (0 = whole match)."""

    def __init__(self):
        self.group: List[str] = []


def _mask_extra_nets(text: str):
    """Mask placeholder-like text and extra-network tags so they survive splitting/searching.

    Returns ``(masked, extra_nets, escaped)``: text that already looks like
    ``$$PYTHON_OBJ[X]$$n^`` is stored in *escaped* and replaced by ``$$PYTHON_OBJX$$n^``;
    every ``<...:...>`` tag is stored in *extra_nets* and replaced by ``$$PYTHON_OBJ$$n^``.
    """
    escaped: List[str] = []
    extra_nets: List[str] = []

    def stash(store: List[str], tag: str):
        def replace(match: re.Match):
            store.append(match.group(0))
            return f"$$PYTHON_{tag}$${len(store) - 1}^"
        return replace

    masked = re_python_escape_x.sub(stash(escaped, "OBJX"), text)
    masked = re_extra_net.sub(stash(extra_nets, "OBJ"), masked)
    return masked, extra_nets, escaped


def _unmask_extra_nets(text: str, extra_nets: List[str], escaped: List[str]) -> str:
    """Undo _mask_extra_nets on *text* (any fragment of the masked string)."""
    def restore(store: List[str]):
        def replace(match: re.Match):
            try:
                return store[int(match.group(1))]
            except (IndexError, ValueError):
                return match.group(0)
        return replace

    # Tags first: their stored text may itself contain OBJX placeholders.
    text = re_python_escape.sub(restore(extra_nets), text)
    return re_python_escape_x.sub(restore(escaped), text)


def extra_net_split(input_str: str, pattern: str):
    """Split *input_str* on the literal separator *pattern*, never inside an extra-network tag.

    ``extra_net_split("<lora:a:0.5>:b", ":")`` returns ``["<lora:a:0.5>", "b"]``.
    """
    masked, extra_nets, escaped = _mask_extra_nets(input_str)
    pieces = [_unmask_extra_nets(piece, extra_nets, escaped) for piece in masked.split(pattern)]
    return pieces or [input_str]


def extra_net_re_search(pattern: Union[str, re.Pattern[str]], input_str: str):
    """Like re.search, but the pattern never sees inside extra-network tags.

    Returns a MySearchResult whose ``group`` list holds the text of every group (tags
    restored; "" for a group that did not participate in the match), or None when
    nothing matches.
    """
    masked, extra_nets, escaped = _mask_extra_nets(input_str)
    match = re.search(pattern, masked)
    if match is None:
        return None
    result = MySearchResult()
    for index in range(match.re.groups + 1):
        text = match.group(index)
        result.group.append("" if text is None else _unmask_extra_nets(text, extra_nets, escaped))
    return result
def unescape_string(input_string: str):
    r"""Decode backslash escapes in *input_string*.

    ``\uXXXX`` and ``\xHH`` (either letter case) become the corresponding character;
    consecutive ``\u`` escapes are decoded together so UTF-16 surrogate pairs combine.
    The single-character JSON escapes (``\n``, ``\t``, ``\"``, ``\\``, ``\/``, ``\b``,
    ``\f``, ``\r``) are decoded, any other ``\c`` yields ``c``, and a trailing lone
    backslash is dropped.

    Raises:
        ValueError (json.JSONDecodeError): a ``\u``/``\x`` escape is malformed.
    """
    pieces: List[str] = []
    pending = ""  # queued \uXXXX escapes, decoded together by flush()

    def flush():
        nonlocal pending
        if pending:
            pieces.append(json.loads(f'"{pending}"'))
            pending = ""

    length = len(input_string)
    i = 0
    while i < length:
        char = input_string[i]
        if char != "\\":
            flush()
            pieces.append(char)
            i += 1
            continue
        i += 1
        if i >= length:
            break  # trailing lone backslash is dropped
        body = input_string[i]
        kind = body.lower()
        if kind == "u":
            pending += "\\u" + input_string[i + 1:i + 5]
            i += 5
        elif kind == "x":
            # JSON has no \x escape; \xHH is the same code point as \u00HH.
            pending += "\\u00" + input_string[i + 1:i + 3]
            i += 3
        else:
            flush()
            try:
                pieces.append(json.loads(f'"\\{body}"'))
            except ValueError:
                pieces.append(body)  # not a JSON escape: keep the character itself
            i += 1
    flush()
    return "".join(pieces)
def _parse_step_value(text):
    """Return *text* as a float, or None when it is not a number."""
    try:
        return float(text)
    except (TypeError, ValueError):
        return None


def _build_step_controllers(lora_text: str, start: float, end: float) -> list:
    """Build the controllers for *lora_text*, active from step *start* to step *end*.

    Without ``|`` every LoRA in the text gets a LoRA_StartEnd_Controller. With ``|``
    the alternatives are lined up position by position (missing positions padded with
    an empty LoRA) and each position becomes one LoRA_Switcher_Controller.
    """
    alternatives = get_or_list(lora_text)
    if len(alternatives) == 1:
        return [
            LoRA_StartEnd_Controller(lora.name, lora.weight, start, end)
            for lora in get_lora_list(alternatives[0])
        ]
    lora_lists = [get_lora_list(block) for block in alternatives]
    longest = max(len(loras) for loras in lora_lists)
    return [
        LoRA_Switcher_Controller(
            [loras[i] if i < len(loras) else LoRA_data("", 0.0) for loras in lora_lists],
            start,
            end,
        )
        for i in range(longest)
    ]


def _apply_weight_controller(controllers: list, weight_controller) -> list:
    """Give every controller the same weight controller and return the list."""
    for controller in controllers:
        controller.Weight_Controller = weight_controller
    return controllers


def get_LoRA_Controllers(prompt: str):
    """Turn one step-rendering item (a ``[...]`` span) into LoRA controllers.

    Forms, tried in order:
      ``[[loras::END]:START]``  active from START to END
      ``[loras:START]``         active from START to the last step
      ``[a:b:WHEN]``            *a* from step 0 to WHEN, then *b* to the last step;
                                an unsplit ``[a|b]`` or ``[a #cmd]`` acts as ``[a|b::-1]``
    A ``#cmd(...)``, ``#decrease`` or ``#increment`` in the item selects the weight
    controller shared by every returned controller. An item whose step value is not a
    number yields no controllers.

    NOTE(review): ``#cmd(...)`` text is compiled and executed as Python by LoRA_Weight_eval.
    """
    weight_controller = LoRA_Weight_CMD()
    super_cmd = re.search(re_super_cmd, prompt)
    if super_cmd:
        command = unescape_string(super_cmd.group(2)).strip()
        if command.startswith("cmd("):
            # Drop "cmd(" and the closing ")" (tolerating a missing one).
            body = command[4:-1] if command.endswith(")") else command[4:]
            weight_controller = LoRA_Weight_eval(body, f"<prompt>, at {re.sub(re_super_cmd, '', prompt)}")
        elif command.startswith("decrease"):
            weight_controller = LoRA_Weight_decrement()
        elif command.startswith("increment"):
            weight_controller = LoRA_Weight_increment()

    controllers = []

    found = extra_net_re_search(re_start_end, prompt)
    if found:
        start = _parse_step_value(found.group[3])
        end = _parse_step_value(found.group[2])
        if start is not None and end is not None:
            controllers = _build_step_controllers(found.group[1], start, end)
        return _apply_weight_controller(controllers, weight_controller)

    found = extra_net_re_search(re_strat_at, prompt)
    if found:
        start = _parse_step_value(found.group[2])
        if start is not None:
            controllers = _build_step_controllers(found.group[1], start, -1.0)
        return _apply_weight_controller(controllers, weight_controller)

    found = extra_net_re_search(re_bucket_inside, prompt)
    if found:
        bucket_inside = found.group[1]
        parts = extra_net_split(bucket_inside, ":")
        if len(parts) == 1 and (("|" in bucket_inside) or ("#" in bucket_inside)):
            parts += ["", "-1"]
        if len(parts) > 2:
            when = _parse_step_value(parts[2])
            if when is not None:
                controllers = (
                    _build_step_controllers(parts[0], 0.0, when)
                    + _build_step_controllers(parts[1], when, -1.0)
                )
    return _apply_weight_controller(controllers, weight_controller)
def get_all_step_rendering_in_prompt(input_prompt: str):
    """Extract the top-level ``[...]`` step-rendering items from *input_prompt*.

    Returns ``(items, masked_prompt)``. *items* are the outermost bracket spans in order
    of appearance, with nested spans restored. *masked_prompt* is the prompt with each
    outermost span replaced by a ``$$PYTHON_OBJ$$n^`` placeholder; any text that already
    looked like a placeholder is replaced by ``$$PYTHON_OBJX$$n^``.
    """
    bracket_spans: List[str] = []   # every [...] span, innermost first; may hold placeholders of earlier spans
    escaped_tokens: List[str] = []  # placeholder-looking text already present in the input

    def stash_span(match: re.Match):
        bracket_spans.append(match.group(0))
        return f"$$PYTHON_OBJ$${len(bracket_spans) - 1}^"

    def stash_token(match: re.Match):
        escaped_tokens.append(match.group(0))
        return f"$$PYTHON_OBJX$${len(escaped_tokens) - 1}^"

    def restorer(store: List[str]):
        def restore(match: re.Match):
            try:
                return store[int(match.group(1))]
            except (IndexError, ValueError):
                return match.group(0)
        return restore

    restore_span = restorer(bracket_spans)
    restore_token = restorer(escaped_tokens)

    def unmask(text: str) -> str:
        # Span placeholders nest (an outer span holds placeholders of inner ones), so
        # expand until nothing changes; a span only refers to earlier spans, so this ends.
        while True:
            expanded = re_python_escape.sub(restore_span, text)
            if expanded == text:
                break
            text = expanded
        # Escaped tokens are verbatim input text: expand them exactly once. Re-expanding
        # could cycle forever when two restored tokens refer to each other.
        return re_python_escape_x.sub(restore_token, text)

    txt = re_python_escape_x.sub(stash_token, input_prompt)
    # Collapse innermost [...] spans into placeholders, one nesting level per pass;
    # placeholders contain no brackets, so the loop reaches a fixed point.
    while True:
        collapsed = re_sd_step_render.sub(stash_span, txt)
        if collapsed == txt:
            break
        txt = collapsed
    # The span placeholders left in txt are the outermost spans, in order of appearance.
    rendering_item_list = [unmask(m.group(0)) for m in re_python_escape.finditer(txt)]
    return rendering_item_list, txt