diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..aa902a9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +scripts/__pycache__/prompts-filter.cpython-310.pyc +scripts/__pycache__/blocked-words-filter.cpython-310.pyc diff --git a/scripts/blocked-words-filter.py b/scripts/prompts-filter.py similarity index 72% rename from scripts/blocked-words-filter.py rename to scripts/prompts-filter.py index 9bb5c2a..b81dcc8 100644 --- a/scripts/blocked-words-filter.py +++ b/scripts/prompts-filter.py @@ -2,26 +2,11 @@ import re from pathlib import Path from typing import List -from modules import scripts,shared,script_callbacks +from modules import script_callbacks, scripts, shared from modules.paths_internal import data_path + DATA_PATH = Path(data_path) -blocked_prompts_txt_file = str(DATA_PATH.joinpath('blocked_prompts.txt')) -blocked_negative_prompts_txt_file = str(DATA_PATH.joinpath('blocked_negative_prompts.txt')) - -def setVal(): - global blocked_prompts_txt_file - global blocked_negative_prompts_txt_file - global blocked_prompts - global blocked_negative_prompts - - blocked_prompts_txt_file = shared.opts.data.get('blocked_prompts_txt_file',blocked_prompts_txt_file) - blocked_negative_prompts_txt_file = shared.opts.data.get('blocked_negative_prompts_txt_file',blocked_negative_prompts_txt_file) - blocked_prompts=get_prompts_by_file(Path(blocked_prompts_txt_file)) - blocked_negative_prompts=get_prompts_by_file(Path(blocked_negative_prompts_txt_file)) - -splitSign = [',','(',')','[',']','{','}',':','>'] - def get_prompts_by_file(path:Path): if path.exists(): with path.open('r') as f: @@ -30,11 +15,33 @@ def get_prompts_by_file(path:Path): else: return [] +blocked_prompts_txt_file = str(DATA_PATH.joinpath('blocked_prompts.txt')) +blocked_negative_prompts_txt_file = str(DATA_PATH.joinpath('blocked_negative_prompts.txt')) blocked_prompts=get_prompts_by_file(Path(blocked_prompts_txt_file)) blocked_negative_prompts=get_prompts_by_file(Path(blocked_negative_prompts_txt_file)) +enable_blocked_prompts = True + +def setVal(): + global blocked_prompts_txt_file + global blocked_negative_prompts_txt_file + global blocked_prompts + global blocked_negative_prompts + global enable_blocked_prompts + + blocked_prompts_txt_file = shared.opts.data.get('blocked_prompts_txt_file',blocked_prompts_txt_file) + blocked_negative_prompts_txt_file = shared.opts.data.get('blocked_negative_prompts_txt_file',blocked_negative_prompts_txt_file) + blocked_prompts=get_prompts_by_file(Path(blocked_prompts_txt_file)) + blocked_negative_prompts=get_prompts_by_file(Path(blocked_negative_prompts_txt_file)) + + enable_blocked_prompts = shared.opts.data.get('enable_blocked_prompts',enable_blocked_prompts) + + +splitSign = [',','(',')','[',']','{','}',':','>','\n'] lora_pattern = r'^<[^<>:]' +left_symbol = ['[','{','('] +right_symbol = [']','}',')'] # 把字符串处理成tag或符号 def prompts_to_arr(prompts:str): @@ -49,48 +56,54 @@ def prompts_to_arr(prompts:str): if not is_lora: ls.append(word) ls.append(sub) + word = '' elif sub == '>': is_lora = False word+=sub ls.append(word) + word = '' else: word+=sub else: word+=sub - return [] + print(ls) + return ls return [] def get_prompt(input:str): - return input - -left_symbol = ['[','{','('] -right_symbol = [']','}',')'] + return input.strip().lower() # 过滤掉因为删除屏蔽词后留下的空 -def join_prompts(prompts:str,next:str): - item = next - if re.search(r'^(\s*,\s*)$',item) and re.search(r',\s*$',prompts): - prompts = re.sub(r',\s*$','',prompts) - return join_prompts(prompts,next) - elif re.search(r'^(\s*,\s*)$',item) and prompts[-1] in left_symbol: +def filter_empty(prompts:List[str],next:str): + item = get_prompt(next) + if not prompts: return [next] + if get_prompt(item) == ',' and get_prompt(prompts[-1]) == ',': + prompts = prompts[:-1] + return filter_empty(prompts,next) + elif get_prompt(item) == ',' and prompts[-1] in left_symbol: return prompts elif not item.strip(' ') and prompts[-1] in left_symbol: return prompts elif item in right_symbol and prompts[-1] == ',': prompts = prompts[:-1] - return join_prompts(prompts,next) + return filter_empty(prompts,next) elif item in right_symbol and prompts[-1] in left_symbol and right_symbol.index(item) == left_symbol.index(prompts[-1]) : prompts = prompts[:-1] return prompts else: - prompts += item + prompts += next return prompts def filter_prompts_list(input:List[str],blocked:List[str]): - out_prompts = [item for item in input if get_prompt(item) not in blocked] - prompts = '' - for item in out_prompts: - prompts = join_prompts(prompts,item) + out_prompts = [] + for item in input: + if enable_blocked_prompts and get_prompt(item) in blocked: + continue + if enable_blocked_prompts: + out_prompts = filter_empty(out_prompts,item) + continue + out_prompts.append(item) + prompts = ''.join(out_prompts) return prompts def filter_prompts(prompts:str,blocked:List[str]): @@ -112,10 +125,14 @@ class emptyFilter(scripts.Script): p.all_negative_prompts[i] = filter_prompts(p.all_negative_prompts[i],blocked_negative_prompts) def on_ui_settings(): - section = ("filter-blocked-words", "过滤屏蔽词") + section = ("prompts-filter", "prompts filter") + + shared.opts.add_option("enable_blocked_prompts", shared.OptionInfo(enable_blocked_prompts, "启用屏蔽词过滤", section=section)) shared.opts.add_option("blocked_prompts_txt_file", shared.OptionInfo(blocked_prompts_txt_file, "屏蔽词文件路径", section=section)) shared.opts.add_option("blocked_negative_prompts_txt_file", shared.OptionInfo(blocked_negative_prompts_txt_file, "反向tag的屏蔽词文件路径", section=section)) + + shared.opts.onchange('enable_blocked_prompts', setVal) shared.opts.onchange('blocked_prompts_txt_file', setVal) shared.opts.onchange('blocked_negative_prompts_txt_file', setVal)