From ff5ef8977df45ab9d9d234b8b3e435d79dc72177 Mon Sep 17 00:00:00 2001 From: butaixianran Date: Wed, 8 Mar 2023 19:18:46 +0800 Subject: [PATCH] start refactoring --- javascript/civitai_helper.js | 4 +- scripts/civitai_helper.py | 28 ---- scripts/lib/civitai.py | 89 +++++++++++++ scripts/lib/model.py | 66 ++++++++++ scripts/lib/msg.py | 90 +++++++++++++ scripts/lib/operator.py | 247 +++++++++++++++++++++++++++++++++++ scripts/lib/setting.py | 74 +++++++++++ scripts/lib/util.py | 43 ++++++ 8 files changed, 611 insertions(+), 30 deletions(-) create mode 100644 scripts/lib/civitai.py create mode 100644 scripts/lib/model.py create mode 100644 scripts/lib/msg.py create mode 100644 scripts/lib/operator.py create mode 100644 scripts/lib/setting.py create mode 100644 scripts/lib/util.py diff --git a/javascript/civitai_helper.js b/javascript/civitai_helper.js index 1789a0c..ec4c467 100644 --- a/javascript/civitai_helper.js +++ b/javascript/civitai_helper.js @@ -315,8 +315,6 @@ onUiLoaded(() => { } - //run it once - update_card_for_civitai(); let tab_id = "" let extra_tab = null; @@ -346,6 +344,8 @@ onUiLoaded(() => { } + //run it once + update_card_for_civitai(); }); diff --git a/scripts/civitai_helper.py b/scripts/civitai_helper.py index a240234..23247d9 100644 --- a/scripts/civitai_helper.py +++ b/scripts/civitai_helper.py @@ -27,34 +27,6 @@ from modules import shared def printD(msg): print(f"Civitai Helper: {msg}") -# printD("Current Model folder:") -# if shared.cmd_opts.embeddings_dir: -# printD("ti: " + shared.cmd_opts.embeddings_dir) -# else: -# printD("shared.cmd_opts.embeddings_dir is None") - -# if shared.cmd_opts.hypernetwork_dir: -# printD("hypernetwork_dir: " + shared.cmd_opts.hypernetwork_dir) -# else: -# printD("shared.cmd_opts.embeddings_dir is None") - - -# if shared.cmd_opts.ckpt_dir: -# printD("ckpt_dir: " + shared.cmd_opts.ckpt_dir) -# else: -# printD("shared.cmd_opts.ckpt_dir is None") - - -# if shared.cmd_opts.lora_dir: -# printD("lora_dir: " + shared.cmd_opts.lora_dir) -# else: -# printD("shared.cmd_opts.lora_dir is None") - - - - -# init -config_file_name = os.path.join(scripts.basedir(), "setting.json") # this is the default root path root_path = os.getcwd() diff --git a/scripts/lib/civitai.py b/scripts/lib/civitai.py new file mode 100644 index 0000000..4d19ddf --- /dev/null +++ b/scripts/lib/civitai.py @@ -0,0 +1,89 @@ +# -*- coding: UTF-8 -*- +# handle msg between js and python side +import os +import json +import re +import requests +from . import util +from . import model + + +suffix = ".civitai" + +url_dict = { + "modelPage":"https://civitai.com/models/", + "modelId": "https://civitai.com/api/v1/models/", + "modelVersionId": "https://civitai.com/api/v1/model-versions/", + "hash": "https://civitai.com/api/v1/model-versions/by-hash/" +} + +# get image with full size +# width is in number, not string +# return: url str +def get_full_size_image_url(image_url, width): + return re.sub('/width=\d+/', '/width=' + str(width) + '/', image_url) + + +# use this sha256 to get model info from civitai +# return: model info dict +def get_model_info_by_hash(hash:str): + util.printD("Request model info from civitai") + + if not hash: + util.printD("hash is empty") + return + + r = requests.get(url_dict["hash"]+hash) + if not r.ok: + if r.status_code == 404: + # this is not a civitai model + util.printD("Civitai does not have this model") + return {} + else: + util.printD("Get error code: " + str(r.status_code)) + util.printD(r.text) + return + + # try to get content + content = None + try: + content = r.json() + except Exception as e: + util.printD("Parse response json failed") + util.printD(str(e)) + util.printD("response:") + util.printD(r.text) + return + + if not content: + util.printD("error, content from civitai is None") + return + + return content + + + +# get model info file's content by model type and search_term +# parameter: model_type, search_term +# return: model_info +def load_model_info_by_search_term(model_type, search_term): + util.printD(f"Load model info of {search_term} in {model_type}") + if model_type not in model.folders.keys(): + util.printD("unknow model type: " + model_type) + return + + # search_term = subfolderpath + model name + ext. And it always start with a / even there is no sub folder + base, ext = os.path.splitext(search_term) + model_info_base = base + if base[:1] == "/": + model_info_base = base[1:] + + model_folder = model.folders[model_type] + model_info_filename = model_info_base + suffix + model.info_ext + model_info_filepath = os.path.join(model_folder, model_info_filename) + + if not os.path.isfile(model_info_filepath): + util.printD("Can not find model info file: " + model_info_filepath) + return + + return model.load_model_info(model_info_filepath) \ No newline at end of file diff --git a/scripts/lib/model.py b/scripts/lib/model.py new file mode 100644 index 0000000..9594de4 --- /dev/null +++ b/scripts/lib/model.py @@ -0,0 +1,66 @@ +# -*- coding: UTF-8 -*- +# handle msg between js and python side +import os +import json +from . import util +from modules import shared + + +# this is the default root path +root_path = os.getcwd() + +# if command line arguement is used to change model folder, +# then model folder is in absolute path, not based on this root path anymore. +# so to make extension work with those absolute model folder paths, model folder also need to be in absolute path +folders = { + "ti": os.path.join(root_path, "embeddings"), + "hyper": os.path.join(root_path, "models", "hypernetworks"), + "ckp": os.path.join(root_path, "models", "Stable-diffusion"), + "lora": os.path.join(root_path, "models", "Lora"), +} + +exts = (".bin", ".pt", ".safetensors", ".ckpt") +info_ext = ".info" + + + +# get cusomter model path +def get_custom_model_folder(): + global folders + + if shared.cmd_opts.embeddings_dir and os.path.isdir(shared.cmd_opts.embeddings_dir): + folders["ti"] = shared.cmd_opts.embeddings_dir + + if shared.cmd_opts.hypernetwork_dir and os.path.isdir(shared.cmd_opts.hypernetwork_dir): + folders["hyper"] = shared.cmd_opts.hypernetwork_dir + + if shared.cmd_opts.ckpt_dir and os.path.isdir(shared.cmd_opts.ckpt_dir): + folders["ckp"] = shared.cmd_opts.ckpt_dir + + if shared.cmd_opts.lora_dir and os.path.isdir(shared.cmd_opts.lora_dir): + folders["lora"] = shared.cmd_opts.lora_dir + +get_custom_model_folder() + + + +# write model info to file +def write_model_info(path, model_info): + util.printD("Write model info to file: " + path) + with open(path, 'w') as f: + f.write(json.dumps(model_info, indent=4)) + + +def load_model_info(path): + util.printD("Load model info from file: " + path) + model_info = None + with open(path, 'r') as f: + try: + model_info = json.load(f) + except Exception as e: + util.printD("Selected file is not json: " + path) + util.printD(e) + return + + return model_info + diff --git a/scripts/lib/msg.py b/scripts/lib/msg.py new file mode 100644 index 0000000..79a6e54 --- /dev/null +++ b/scripts/lib/msg.py @@ -0,0 +1,90 @@ +# -*- coding: UTF-8 -*- +# handle msg between js and python side +import json +from . import util + +# action list +js_actions = ("open_url", "add_trigger_words", "use_preview_prompt") +py_actions = ("open_url", "scan_log", "model_new_version") + + +# handle request from javascript +# parameter: msg - msg from js as string in a hidden textbox +# return: (action, model_type, search_term, prompt, neg_prompt) +def parse_js_msg(msg): + util.printD("Start parse js msg") + msg_dict = json.loads(msg) + + if "action" not in msg_dict.keys(): + util.printD("Can not find action from js request") + return + + if "model_type" not in msg_dict.keys(): + util.printD("Can not find model type from js request") + return + + if "search_term" not in msg_dict.keys(): + util.printD("Can not find search_term from js request") + return + + if "prompt" not in msg_dict.keys(): + util.printD("Can not find prompt from js request") + return + + if "neg_prompt" not in msg_dict.keys(): + util.printD("Can not find neg_prompt from js request") + return + + action = msg_dict["action"] + model_type = msg_dict["model_type"] + search_term = msg_dict["search_term"] + prompt = msg_dict["prompt"] + neg_prompt = msg_dict["neg_prompt"] + + if not action: + util.printD("Action from js request is None") + return + + if not model_type: + util.printD("model_type from js request is None") + return + + if not search_term: + util.printD("search_term from js request is None") + return + + + if action not in js_actions: + util.printD("Unknow action: " + action) + return + + util.printD("End parse js msg") + + return (action, model_type, search_term, prompt, neg_prompt) + + +# build python side msg for sending to js +# parameter: content dict +# return: msg as string, to fill into a hidden textbox +def build_py_msg(action:str, content:dict): + util.printD("Start build_msg") + if not content: + util.printD("Content is None") + return + + if not action: + util.printD("Action is None") + return + + if action not in py_actions: + util.printD("Unknow action: " + action) + return + + msg = { + "action" : action, + "content": content + } + + + util.printD("End build_msg") + return json.dumps(msg) \ No newline at end of file diff --git a/scripts/lib/operator.py b/scripts/lib/operator.py new file mode 100644 index 0000000..54e62f8 --- /dev/null +++ b/scripts/lib/operator.py @@ -0,0 +1,247 @@ +# -*- coding: UTF-8 -*- +# handle msg between js and python side +import os +import json +import requests +import shutil +import webbrowser +from . import util +from . import model +from . import civitai +from . import msg + + + +# scan model to generate SHA256, then use this SHA256 to get model info from civitai +def scan_model(low_memory_sha, max_size_preview, readable_model_info, skip_nsfw_preview): + util.printD("Start scan_model") + + model_count = 0 + image_count = 0 + scan_log = "" + for model_type, model_folder in model.folders.items(): + util.printD("Scanning path: " + model_folder) + for root, dirs, files in os.walk(model_folder): + for filename in files: + # check ext + item = os.path.join(root, filename) + base, ext = os.path.splitext(item) + if ext in model.exts: + # find a model + # set a Progress log + scan_log = "Scanned: " + str(model_count) + ", Scanning: "+ filename + # try to update to UI Here + # Is still trying to find a way + + # get preview image + first_preview = base+".png" + sec_preview = base+".preview.png" + # get info file + info_file = base + civitai.suffix + model.info_ext + # check info file + if not os.path.isfile(info_file): + # get model's sha256 + util.printD("Generate SHA256 for model: " + filename) + hash = util.gen_file_sha256(item, low_memory_sha) + + if not hash: + util.printD("failed generate SHA256 for this file.") + return + + # use this sha256 to get model info from civitai + model_info = civitai.get_model_info_by_hash(hash) + if model_info is None: + util.printD("Fail to get model_info") + return + + # write model info to file + model.write_model_info(info_file, model_info) + + # set model_count + model_count = model_count+1 + + # check preview image + if not os.path.isfile(sec_preview): + # need to download preview image + util.printD("Need preview image for this model") + # load model_info file + if os.path.isfile(info_file): + model_info = model.load_model_info(info_file) + if not model_info: + util.printD("Model Info is empty") + continue + + if "images" in model_info.keys(): + if model_info["images"]: + for img_dict in model_info["images"]: + if "nsfw" in img_dict.keys(): + if img_dict["nsfw"]: + util.printD("This image is NSFW") + if skip_nsfw_preview: + util.printD("Skip NSFW image") + continue + + if "url" in img_dict.keys(): + img_url = img_dict["url"] + if max_size_preview: + # use max width + if "width" in img_dict.keys(): + if img_dict["width"]: + img_url = civitai.get_full_size_image_url(img_url, img_dict["width"]) + + util.download_file(img_url, sec_preview) + image_count = image_count + 1 + # we only need 1 preview image + break + + + scan_log = "Done" + + util.printD("End scan_model") + + + +# get civitai's model url and open it in browser +# parameter: model_type, search_term +def open_model_url(msg): + util.printD("Start open_model_url") + + result = msg.parse_js_msg(msg) + if not result: + util.printD("Parsing js ms failed") + return + + action, model_type, search_term, prompt, neg_prompt = result + + model_info = civitai.load_model_info_by_search_term(model_type, search_term) + if not model_info: + util.printD(f"Failed to get model info for {model_type} {search_term}") + return + + if "modelId" not in model_info.keys(): + util.printD(f"Failed to get model id from info file for {model_type} {search_term}") + return + + model_id = model_info["modelId"] + if not model_id: + util.printD(f"model id from info file of {model_type} {search_term} is None") + return + + url = civitai.url_dict["modelPage"]+str(model_id) + + util.printD("Open Url: " + url) + # open url + webbrowser.open_new_tab(url) + + util.printD("End open_model_url") + + + +# add trigger words to prompt +# parameter: model_type, search_term, prompt +# return: [new_prompt, new_prompt] - new prompt with trigger words, return twice for txt2img and img2img +def add_trigger_words(msg): + util.printD("Start add_trigger_words") + + result = msg.parse_js_msg(msg) + if not result: + util.printD("Parsing js ms failed") + return + + action, model_type, search_term, prompt, neg_prompt = result + + + model_info = civitai.load_model_info_by_search_term(model_type, search_term) + if not model_info: + util.printD(f"Failed to get model info for {model_type} {search_term}") + return [prompt, prompt] + + if "trainedWords" not in model_info.keys(): + util.printD(f"Failed to get trainedWords from info file for {model_type} {search_term}") + return [prompt, prompt] + + trainedWords = model_info["trainedWords"] + if not trainedWords: + util.printD(f"No trainedWords from info file for {model_type} {search_term}") + return [prompt, prompt] + + if len(trainedWords) == 0: + util.printD(f"trainedWords from info file for {model_type} {search_term} is empty") + return [prompt, prompt] + + # get ful trigger words + trigger_words = "" + for word in trainedWords: + trigger_words = trigger_words + word + ", " + + new_prompt = prompt + " " + trigger_words + util.printD("trigger_words: " + trigger_words) + util.printD("prompt: " + prompt) + util.printD("new_prompt: " + new_prompt) + + util.printD("End add_trigger_words") + + # add to prompt + return [new_prompt, new_prompt] + + + +# use preview image's prompt as prompt +# parameter: model_type, model_name, prompt, neg_prompt +# return: [new_prompt, new_neg_prompt, new_prompt, new_neg_prompt,] - return twice for txt2img and img2img +def use_preview_image_prompt(msg): + util.printD("Start use_preview_image_prompt") + + result = msg.parse_js_msg(msg) + if not result: + util.printD("Parsing js ms failed") + return + + action, model_type, search_term, prompt, neg_prompt = result + + + model_info = civitai.load_model_info_by_search_term(model_type, search_term) + if not model_info: + util.printD(f"Failed to get model info for {model_type} {search_term}") + return [prompt, neg_prompt, prompt, neg_prompt] + + if "images" not in model_info.keys(): + util.printD(f"Failed to get images from info file for {model_type} {search_term}") + return [prompt, neg_prompt, prompt, neg_prompt] + + images = model_info["images"] + if not images: + util.printD(f"No images from info file for {model_type} {search_term}") + return [prompt, neg_prompt, prompt, neg_prompt] + + if len(images) == 0: + util.printD(f"images from info file for {model_type} {search_term} is empty") + return [prompt, neg_prompt, prompt, neg_prompt] + + # get prompt from preview images' meta data + preview_prompt = "" + preview_neg_prompt = "" + for img in images: + if "meta" in img.keys(): + if img["meta"]: + if "prompt" in img["meta"].keys(): + if img["meta"]["prompt"]: + preview_prompt = img["meta"]["prompt"] + + if "negativePrompt" in img["meta"].keys(): + if img["meta"]["negativePrompt"]: + preview_neg_prompt = img["meta"]["negativePrompt"] + + # we only need 1 prompt + if preview_prompt: + break + + if not preview_prompt: + util.printD(f"There is no prompt of {model_type} {search_term} in its preview image") + return [prompt, neg_prompt, prompt, neg_prompt] + + util.printD("End use_preview_image_prompt") + + return [preview_prompt, preview_neg_prompt, preview_prompt, preview_neg_prompt] + + diff --git a/scripts/lib/setting.py b/scripts/lib/setting.py new file mode 100644 index 0000000..d6d8a66 --- /dev/null +++ b/scripts/lib/setting.py @@ -0,0 +1,74 @@ +# -*- coding: UTF-8 -*- +# collecting settings to here +import json +import os +import modules.scripts as scripts +from . import util + + +name = "setting.json" +path = os.path.join(scripts.basedir(), name) + +data = { + "model":{ + "low_memory_sha": True, + "max_size_preview": True, + "readable_model_info": True, + "skip_nsfw_preview": False + }, + "general":{ + "open_url_with_js": False, + "check_model_version_at_startup": False, + }, + "tool":{ + } +} + + + +# save setting +def save(): + print("Saving tranlation service setting...") + # write data into globel trans_setting + global trans_setting + + + + # to json + json_data = json.dumps(data) + + #write to file + try: + with open(path, 'w') as f: + f.write(json_data) + except Exception as e: + util.printD("Error when writing file:"+path) + util.printD(str(e)) + return + + util.printD("Setting saved to: " + path) + + +# load setting to global data +def load(): + # load data into globel data + global data + + util.printD("Load setting from: " + path) + + if not os.path.isfile(path): + util.printD("No setting file, use default") + return + + json_data = None + with open(path, 'r') as f: + json_data = json.load(f) + + # check error + if not json_data: + util.printD("load setting file failed") + return + + data = json_data + + return \ No newline at end of file diff --git a/scripts/lib/util.py b/scripts/lib/util.py new file mode 100644 index 0000000..3e31e59 --- /dev/null +++ b/scripts/lib/util.py @@ -0,0 +1,43 @@ +# -*- coding: UTF-8 -*- +import hashlib +import requests +import shutil + +# print for debugging +def printD(msg): + print(f"Civitai Helper: {msg}") + + +def gen_file_sha256(filname, is_low_memory=True): + printD("Calculate SHA256") + hash_sha256 = hashlib.sha256() + with open(filname, "rb") as f: + if is_low_memory: + printD("Using Memory Optimised SHA256") + for chunk in iter(lambda: f.read(4096), b""): + hash_sha256.update(chunk) + else: + hash_sha256.update(f.read()) + + + hash_value = hash_sha256.hexdigest() + printD("sha256: " + hash_value) + return hash_value + + +# get preview image +def download_file(url, path): + printD("Download file from: " + url) + # get file + r = requests.get(url, stream=True) + if not r.ok: + printD("Get error code: " + str(r.status_code)) + printD(r.text) + return + + # write to file + with open(path, 'wb') as f: + r.raw.decode_content = True + shutil.copyfileobj(r.raw, f) + + printD("File downloaded to: " + path)