From a184065e5b8e3aa608a672c2d93b650645acf28f Mon Sep 17 00:00:00 2001 From: butaixianran Date: Tue, 7 Mar 2023 23:27:19 +0800 Subject: [PATCH] support subfolder, add memory optimised shad256 --- README.md | 10 +- javascript/civitai_helper.js | 49 +++--- scripts/civitai_helper.py | 299 +++++++++++++++++++---------------- 3 files changed, 192 insertions(+), 166 deletions(-) diff --git a/README.md b/README.md index f050778..a06fb68 100644 --- a/README.md +++ b/README.md @@ -62,11 +62,13 @@ If you click Refresh Button of extra network, those additional buttons will be r Enjoy! # Known Issue -* It does not support subfolders for now. Will add this later. * It can not force a model link to civitai by model id for now, will add this later. -* If you did not refresh your extra network cards, but still keep clicking "Refresh Civitai Helper", you'll get more icons. It's a bug, will be fixed. - -All 3 issues above is already in development plan. +# Change Log +## v0.2 +* Support subfolders +* Check if refresh is needed when clicking "Refresh Civitai Helper" +* Add space when adding trigger words +* Add memory optimised shad256 as an option diff --git a/javascript/civitai_helper.js b/javascript/civitai_helper.js index 0a5d0b0..a138364 100644 --- a/javascript/civitai_helper.js +++ b/javascript/civitai_helper.js @@ -25,7 +25,7 @@ function getActiveNegativePrompt() { //button's click function -function open_model_url(event, model_type, model_name){ +function open_model_url(event, model_type, search_term){ console.log("start open_model_url"); //get hidden components of extension @@ -37,7 +37,7 @@ function open_model_url(event, model_type, model_name){ let msg = { "action": "", "model_type": "", - "model_name": "", + "search_term": "", "prompt": "", "neg_prompt": "", } @@ -45,7 +45,7 @@ function open_model_url(event, model_type, model_name){ msg["action"] = "open_url"; msg["model_type"] = model_type; - msg["model_name"] = model_name; + msg["search_term"] = search_term; msg["prompt"] = ""; msg["neg_prompt"] = ""; @@ -64,7 +64,7 @@ function open_model_url(event, model_type, model_name){ } -function add_trigger_words(event, model_type, model_name){ +function add_trigger_words(event, model_type, search_term){ console.log("start add_trigger_words"); //get hidden components of extension @@ -77,14 +77,14 @@ function add_trigger_words(event, model_type, model_name){ let msg = { "action": "", "model_type": "", - "model_name": "", + "search_term": "", "prompt": "", "neg_prompt": "", } msg["action"] = "add_trigger_words"; msg["model_type"] = model_type; - msg["model_name"] = model_name; + msg["search_term"] = search_term; msg["neg_prompt"] = ""; // get active prompt @@ -106,7 +106,7 @@ function add_trigger_words(event, model_type, model_name){ } -function use_preview_prompt(event, model_type, model_name){ +function use_preview_prompt(event, model_type, search_term){ console.log("start use_preview_prompt"); //get hidden components of extension @@ -119,14 +119,14 @@ function use_preview_prompt(event, model_type, model_name){ let msg = { "action": "", "model_type": "", - "model_name": "", + "search_term": "", "prompt": "", "neg_prompt": "", } msg["action"] = "use_preview_prompt"; msg["model_type"] = model_type; - msg["model_name"] = model_name; + msg["search_term"] = search_term; // get active prompt prompt = getActivePrompt(); @@ -178,8 +178,8 @@ onUiLoaded(() => { let addtional_nodes = null; let replace_preview_btn = null; let ul_node = null; - let model_name_node = null; - let model_name = ""; + let search_term_node = null; + let search_term = ""; let model_type = ""; let cards = null; let need_to_add_buttons = false; @@ -233,17 +233,19 @@ onUiLoaded(() => { continue; } - //get model name node - model_name_node = card.querySelector(".actions .name"); - if (!model_name_node){ - console.log("can not find model name node for cards in " + extra_network_id); + + // search_term node + // search_term = subfolder path + model name + ext + search_term_node = card.querySelector(".actions .additional .search_term"); + if (!search_term_node){ + console.log("can not find search_term node for cards in " + extra_network_id); continue; } - // get model name - model_name = model_name_node.innerHTML; - if (!model_name) { - console.log("model name is empty for cards in " + extra_network_id); + // get search_term + search_term = search_term_node.innerHTML; + if (!search_term) { + console.log("search_term is empty for cards in " + extra_network_id); continue; } @@ -257,21 +259,21 @@ onUiLoaded(() => { open_url_node.style.fontSize = "200%"; open_url_node.title = "Open this model's civitai url"; open_url_node.style.margin = "0px 5px"; - open_url_node.setAttribute("onclick","open_model_url(event, '"+model_type+"', '"+model_name+"')"); + open_url_node.setAttribute("onclick","open_model_url(event, '"+model_type+"', '"+search_term+"')"); let add_trigger_words_node = document.createElement("button"); add_trigger_words_node.innerHTML = "💡"; add_trigger_words_node.style.fontSize = "200%"; add_trigger_words_node.title = "Add trigger words to prompt"; add_trigger_words_node.style.margin = "0px 5px"; - add_trigger_words_node.setAttribute("onclick","add_trigger_words(event, '"+model_type+"', '"+model_name+"')"); + add_trigger_words_node.setAttribute("onclick","add_trigger_words(event, '"+model_type+"', '"+search_term+"')"); let use_preview_prompt_node = document.createElement("button"); use_preview_prompt_node.innerHTML = "🏷"; use_preview_prompt_node.style.fontSize = "200%"; use_preview_prompt_node.title = "Use promt from preview image"; use_preview_prompt_node.style.margin = "0px 5px"; - use_preview_prompt_node.setAttribute("onclick","use_preview_prompt(event, '"+model_type+"', '"+model_name+"')"); + use_preview_prompt_node.setAttribute("onclick","use_preview_prompt(event, '"+model_type+"', '"+search_term+"')"); //add to card ul_node.appendChild(open_url_node); @@ -279,9 +281,6 @@ onUiLoaded(() => { ul_node.appendChild(use_preview_prompt_node); - - - } diff --git a/scripts/civitai_helper.py b/scripts/civitai_helper.py index cb501ba..6a65968 100644 --- a/scripts/civitai_helper.py +++ b/scripts/civitai_helper.py @@ -56,108 +56,124 @@ def gen_file_sha256(filname): printD("sha256: " + hash_value) return hash_value +def gen_file_sha256_low_memory(filname): + printD("Using Memory Optimised SHA256") + hash_sha256 = hashlib.sha256() + with open(filname, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_sha256.update(chunk) + + hash_value = hash_sha256.hexdigest() + printD("sha256: " + hash_value) + return hash_value + # scan model to generate SHA256, then use this SHA256 to get model info from civitai -def scan_model(skip_nsfw_preview): +def scan_model(skip_nsfw_preview, low_memory_sha): printD("Start scan_model") for model_type, model_folder in model_folders.items(): folder_path = os.path.join(root_path, model_folder) printD("Scanning path: " + folder_path) - for filename in os.listdir(folder_path): - # check ext - item = os.path.join(folder_path, filename) - base, ext = os.path.splitext(item) - if ext in model_exts: - # find a model - # get preview image - first_preview = base+".png" - sec_preview = base+".preview.png" - # get info file - info_file = base + civitai_info_suffix + model_info_exts - # check info file - if not os.path.isfile(info_file): - # get model's sha256 - printD("Generate SHA256 for model: " + filename) - hash = gen_file_sha256(item) - - if not hash: - printD("failed generate SHA256 for this file.") - return - - # use this sha256 to get model info from civitai - printD("Request model info from civitai") - r = requests.get(civitai_hash_api_url+hash) - if not r.ok: - if r.status_code == 404: - # this is not a civitai model - printD("Civitai does not have this model") - printD("Write empty model info file") - empty_info = {} - with open(info_file, 'w') as f: - data = json.dumps(empty_info) - f.write(data) - # go to next file - continue + for root, dirs, files in os.walk(folder_path): + for filename in files: + # check ext + item = os.path.join(root, filename) + base, ext = os.path.splitext(item) + if ext in model_exts: + # find a model + # get preview image + first_preview = base+".png" + sec_preview = base+".preview.png" + # get info file + info_file = base + civitai_info_suffix + model_info_exts + # check info file + if not os.path.isfile(info_file): + # get model's sha256 + printD("Generate SHA256 for model: " + filename) + hash = "" + if low_memory_sha: + hash = gen_file_sha256_low_memory(item) else: - printD("Get errorcode: " + str(r.status_code)) + hash = gen_file_sha256(item) + + if not hash: + printD("failed generate SHA256 for this file.") + return + + # use this sha256 to get model info from civitai + printD("Request model info from civitai") + r = requests.get(civitai_hash_api_url+hash) + if not r.ok: + if r.status_code == 404: + # this is not a civitai model + printD("Civitai does not have this model") + printD("Write empty model info file") + empty_info = {} + with open(info_file, 'w') as f: + data = json.dumps(empty_info) + f.write(data) + # go to next file + continue + else: + printD("Get errorcode: " + str(r.status_code)) + printD(r.text) + return + + # try to get content + content = None + try: + content = r.json() + except Exception as e: + printD("Parse response json failed") + printD(str(e)) + printD("response:") printD(r.text) return + + if not content: + printD("error, content from civitai is None") + return + + # write model info to file + printD("Write model info to file: " + info_file) + with open(info_file, 'w') as f: + data = json.dumps(content) + f.write(data) - # try to get content - content = None - try: - content = r.json() - except Exception as e: - printD("Parse response json failed") - printD(str(e)) - printD("response:") - printD(r.text) - return - - if not content: - printD("error, content from civitai is None") - return - - # write model info to file - printD("Write model info to file: " + info_file) - with open(info_file, 'w') as f: - data = json.dumps(content) - f.write(data) - - # check preview image - if not os.path.isfile(sec_preview): - # need to download preview image - printD("Need preview image for this model") - if content["images"]: - for img_dict in content["images"]: - if "nsfw" in img_dict.keys(): - if img_dict["nsfw"]: - printD("This image is NSFW") - if skip_nsfw_preview: - printD("Skip NSFW image") - continue - - if "url" in img_dict.keys(): - printD("Sending request for image: " + img_dict["url"]) - # get image - img_r = requests.get(img_dict["url"], stream=True) - if not img_r.ok: - printD("Get errorcode: " + str(r.status_code)) - printD(r.text) - return + # check preview image + if not os.path.isfile(sec_preview): + # need to download preview image + printD("Need preview image for this model") + if content["images"]: + for img_dict in content["images"]: + if "nsfw" in img_dict.keys(): + if img_dict["nsfw"]: + printD("This image is NSFW") + if skip_nsfw_preview: + printD("Skip NSFW image") + continue - # write to file - with open(sec_preview, 'wb') as f: - img_r.raw.decode_content = True - shutil.copyfileobj(img_r.raw, f) + if "url" in img_dict.keys(): + printD("Sending request for image: " + img_dict["url"]) + # get image + img_r = requests.get(img_dict["url"], stream=True) + if not img_r.ok: + printD("Get errorcode: " + str(r.status_code)) + printD(r.text) + return + + # write to file + with open(sec_preview, 'wb') as f: + img_r.raw.decode_content = True + shutil.copyfileobj(img_r.raw, f) - printD("Created Preview image: " + sec_preview) + printD("Created Preview image: " + sec_preview) - # we only need 1 preview image - break + # we only need 1 preview image + break - # for testing, we only check 1 model for each type - # break + # for testing, we only check 1 model for each type + # break printD("End scan_model") @@ -165,7 +181,7 @@ def scan_model(skip_nsfw_preview): # handle request from javascript # parameter: msg - msg from js -# return: (action, model_type, model_name, prompt, neg_prompt) +# return: (action, model_type, search_term, prompt, neg_prompt) def parse_js_msg(msg): printD("Start parse js msg") msg_dict = json.loads(msg) @@ -178,8 +194,8 @@ def parse_js_msg(msg): printD("Can not find model type from js request") return - if "model_name" not in msg_dict.keys(): - printD("Can not find model name from js request") + if "search_term" not in msg_dict.keys(): + printD("Can not find search_term from js request") return if "prompt" not in msg_dict.keys(): @@ -192,7 +208,7 @@ def parse_js_msg(msg): action = msg_dict["action"] model_type = msg_dict["model_type"] - model_name = msg_dict["model_name"] + search_term = msg_dict["search_term"] prompt = msg_dict["prompt"] neg_prompt = msg_dict["neg_prompt"] @@ -204,8 +220,8 @@ def parse_js_msg(msg): printD("model_type from js request is None") return - if not model_name: - printD("model_name from js request is None") + if not search_term: + printD("search_term from js request is None") return @@ -219,21 +235,26 @@ def parse_js_msg(msg): printD("End parse js msg") - return (action, model_type, model_name, prompt, neg_prompt) - - + return (action, model_type, search_term, prompt, neg_prompt) # get model info file's content by model type and model name -# parameter: model_type, model_name +# parameter: model_type, search_term # return: model_info_dict -def get_model_info(model_type, model_name): +def get_model_info(model_type, search_term): if model_type not in model_folders.keys(): printD("unknow model type: " + model_type) return None + + # search_term = subfolderpath + model name + ext. And it always start with a / even there is no sub folder + base, ext = os.path.splitext(search_term) + model_info_base = base + if base[:1] == "/": + model_info_base = base[1:] + model_folder = model_folders[model_type] - model_info_filename = model_name + civitai_info_suffix + model_info_exts + model_info_filename = model_info_base + civitai_info_suffix + model_info_exts model_info_filepath = os.path.join(root_path, model_folder, model_info_filename) if not os.path.isfile(model_info_filepath): @@ -253,7 +274,7 @@ def get_model_info(model_type, model_name): # get civitai's model url and open it in browser -# parameter: model_type, model_name +# parameter: model_type, search_term def open_model_url(msg): printD("Start open_model_url") @@ -261,23 +282,23 @@ def open_model_url(msg): if not result: printD("Parsing js ms failed") return - - action, model_type, model_name, prompt, neg_prompt = result - model_info = get_model_info(model_type, model_name) + action, model_type, search_term, prompt, neg_prompt = result + + model_info = get_model_info(model_type, search_term) if not model_info: - printD(f"Failed to get model info for {model_type} {model_name}") + printD(f"Failed to get model info for {model_type} {search_term}") return if "modelId" not in model_info.keys(): - printD(f"Failed to get model id from info file for {model_type} {model_name}") + printD(f"Failed to get model id from info file for {model_type} {search_term}") return - + model_id = model_info["modelId"] if not model_id: - printD(f"model id from info file of {model_type} {model_name} is None") + printD(f"model id from info file of {model_type} {search_term} is None") return - + url = "https://civitai.com/models/"+str(model_id) printD("Open Url: " + url) @@ -288,7 +309,7 @@ def open_model_url(msg): # add trigger words to prompt -# parameter: model_type, model_name, prompt +# parameter: model_type, search_term, prompt # return: [new_prompt, new_prompt] - new prompt with trigger words, return twice for txt2img and img2img def add_trigger_words(msg): printD("Start add_trigger_words") @@ -298,33 +319,33 @@ def add_trigger_words(msg): printD("Parsing js ms failed") return - action, model_type, model_name, prompt, neg_prompt = result + action, model_type, search_term, prompt, neg_prompt = result - model_info = get_model_info(model_type, model_name) + model_info = get_model_info(model_type, search_term) if not model_info: - printD(f"Failed to get model info for {model_type} {model_name}") + printD(f"Failed to get model info for {model_type} {search_term}") return [prompt, prompt] if "trainedWords" not in model_info.keys(): - printD(f"Failed to get trainedWords from info file for {model_type} {model_name}") + printD(f"Failed to get trainedWords from info file for {model_type} {search_term}") return [prompt, prompt] trainedWords = model_info["trainedWords"] if not trainedWords: - printD(f"No trainedWords from info file for {model_type} {model_name}") + printD(f"No trainedWords from info file for {model_type} {search_term}") return [prompt, prompt] if len(trainedWords) == 0: - printD(f"trainedWords from info file for {model_type} {model_name} is empty") + printD(f"trainedWords from info file for {model_type} {search_term} is empty") return [prompt, prompt] # get ful trigger words trigger_words = "" for word in trainedWords: - trigger_words = trigger_words + word + trigger_words = trigger_words + word + ", " - new_prompt = prompt + trigger_words + new_prompt = prompt + " " + trigger_words printD("trigger_words: " + trigger_words) printD("prompt: " + prompt) printD("new_prompt: " + new_prompt) @@ -347,25 +368,25 @@ def use_preview_image_prompt(msg): printD("Parsing js ms failed") return - action, model_type, model_name, prompt, neg_prompt = result + action, model_type, search_term, prompt, neg_prompt = result - model_info = get_model_info(model_type, model_name) + model_info = get_model_info(model_type, search_term) if not model_info: - printD(f"Failed to get model info for {model_type} {model_name}") + printD(f"Failed to get model info for {model_type} {search_term}") return [prompt, neg_prompt, prompt, neg_prompt] if "images" not in model_info.keys(): - printD(f"Failed to get images from info file for {model_type} {model_name}") + printD(f"Failed to get images from info file for {model_type} {search_term}") return [prompt, neg_prompt, prompt, neg_prompt] images = model_info["images"] if not images: - printD(f"No images from info file for {model_type} {model_name}") + printD(f"No images from info file for {model_type} {search_term}") return [prompt, neg_prompt, prompt, neg_prompt] if len(images) == 0: - printD(f"images from info file for {model_type} {model_name} is empty") + printD(f"images from info file for {model_type} {search_term} is empty") return [prompt, neg_prompt, prompt, neg_prompt] # get prompt from preview images' meta data @@ -373,20 +394,21 @@ def use_preview_image_prompt(msg): preview_neg_prompt = "" for img in images: if "meta" in img.keys(): - if "prompt" in img["meta"].keys(): - if img["meta"]["prompt"]: - preview_prompt = img["meta"]["prompt"] - - if "negativePrompt" in img["meta"].keys(): - if img["meta"]["negativePrompt"]: - preview_neg_prompt = img["meta"]["negativePrompt"] + if img["meta"]: + if "prompt" in img["meta"].keys(): + if img["meta"]["prompt"]: + preview_prompt = img["meta"]["prompt"] + + if "negativePrompt" in img["meta"].keys(): + if img["meta"]["negativePrompt"]: + preview_neg_prompt = img["meta"]["negativePrompt"] - # we only need 1 prompt - if preview_prompt: - break + # we only need 1 prompt + if preview_prompt: + break if not preview_prompt: - printD(f"There is no prompt of {model_type} {model_name} in its preview image") + printD(f"There is no prompt of {model_type} {search_term} in its preview image") return [prompt, neg_prompt, prompt, neg_prompt] printD("End use_preview_image_prompt") @@ -412,8 +434,11 @@ def on_ui_tabs(): with gr.Blocks(analytics_enabled=False) as civitai_helper: # info gr.Markdown("Civitai Helper's extension tab") + + with gr.Row(): + skip_nsfw_preview_ckb = gr.Checkbox(label="SKip NSFW Preview images", value=False, elem_id="ch_skip_nsfw_preview_ckb") + low_memory_sha_ckb = gr.Checkbox(label="Memory Optimised SHA256", value=False, elem_id="ch_low_memory_sha_ckb") - skip_nsfw_preview_ckb = gr.Checkbox(label="SKip NSFW Preview images", value=False, elem_id="ch_skip_nsfw_preview_ckb") scan_model_btn = gr.Button(value="Scan model", elem_id="ch_scan_model_btn") # hidden component for js @@ -423,7 +448,7 @@ def on_ui_tabs(): js_use_preview_prompt_btn = gr.Button(value="Use Prompt from Preview Image", visible=False, elem_id="ch_js_use_preview_prompt_btn") # ====events==== - scan_model_btn.click(scan_model, inputs=[skip_nsfw_preview_ckb]) + scan_model_btn.click(scan_model, inputs=[skip_nsfw_preview_ckb, low_memory_sha_ckb]) js_open_url_btn.click(open_model_url, inputs=[js_msg_txtbox]) js_add_trigger_words_btn.click(add_trigger_words, inputs=[js_msg_txtbox], outputs=[txt2img_prompt, img2img_prompt]) js_use_preview_prompt_btn.click(use_preview_image_prompt, inputs=[js_msg_txtbox], outputs=[txt2img_prompt, txt2img_neg_prompt, img2img_prompt, img2img_neg_prompt])