diff --git a/javascript/main.js b/javascript/main.js index 677d70b..8d5f903 100644 --- a/javascript/main.js +++ b/javascript/main.js @@ -126,7 +126,7 @@ function handleRecordSave() { } const textArea = findElem('mo-description-output-widget').querySelector('textarea') - const event = new Event('input', {'bubbles': true, "composed": true}); + const event = new Event('input', { 'bubbles': true, "composed": true }); textArea.value = output findElem('mo-description-output-widget').querySelector('textarea').dispatchEvent(event); logMo('Description content dispatched: ' + output) @@ -345,7 +345,9 @@ function navigateBack() { return [] } -function navigateDetails(id) { +function navigateDetails(id, event) { + event.stopPropagation(); + event.preventDefault(); logMo('Navigate details screen for id: ' + id) const navObj = { screen: "details", @@ -380,7 +382,7 @@ function navigateImportExport(filter_state) { return [] } -function navigateDebug() { +function navigateDebug(event) { logMo('Navigate debug screen') const navObj = { screen: "debug", @@ -391,7 +393,9 @@ function navigateDebug() { return [] } -function navigateEdit(id) { +function navigateEdit(id, event) { + event.stopPropagation(); + event.preventDefault(); logMo('Navigate edit screen for id: ' + id) const navObj = { screen: "edit", @@ -403,7 +407,9 @@ function navigateEdit(id) { return [] } -function navigateEditPrefilled(json_data) { +function navigateEditPrefilled(json_data, event) { + event.stopPropagation(); + event.preventDefault(); logMo('Navigate edit screen for prefilled json: ' + json_data) const navObj = { screen: "edit", @@ -411,11 +417,34 @@ function navigateEditPrefilled(json_data) { token: generateUUID(), backstack: populateBackstack() }; + + + setTimeout((event) => { + //Get config options + var json_elem = gradioApp().getElementById('settings_json'); + if (json_elem == null) return; + + var textarea = json_elem.querySelector('textarea'); + var jsdata = textarea.value; + opts = JSON.parse(jsdata); + + if (opts['mo_autobind_file']) { + //setTimeout((event) => { + var bind = gradioApp().querySelector('#model_organizer_add_bind input'); + var modelName = gradioApp().querySelector('#model_organizer_edit_name input'); + bind.value = modelName.value; + } + }, 300); + deliverNavObject(navObj) + + return [] } -function navigateDownloadRecord(id) { +function navigateDownloadRecord(id, event) { + event.stopPropagation(); + event.preventDefault(); logMo('Navigate download screen for id: ' + id) const navObj = { screen: "download", @@ -451,7 +480,9 @@ function navigateDownloadGroup(groupName) { return [] } -function navigateRemove(id) { +function navigateRemove(id, event) { + event.stopPropagation(); + event.preventDefault(); logMo('Navigate removal screen for id: ' + id) const navObj = { screen: "remove", @@ -466,7 +497,7 @@ function navigateRemove(id) { function deliverNavObject(navObj) { const navJson = JSON.stringify(navObj); const textArea = findElem('mo_json_nav_box').querySelector('textarea') - const event = new Event('input', {'bubbles': true, "composed": true}); + const event = new Event('input', { 'bubbles': true, "composed": true }); textArea.value = navJson findElem('mo_json_nav_box').querySelector('textarea').dispatchEvent(event); logMo('JSON Nav dispatched: ' + navJson) @@ -478,7 +509,7 @@ function invokeHomeInitialStateLoad() { const initialStateTextArea = findElem('mo-initial-state-box').querySelector('textarea') const stateTextArea = findElem('mo-home-state-box').querySelector('textarea') stateTextArea.value = initialStateTextArea.value - const event = new Event('input', {'bubbles': true, "composed": true}); + const event = new Event('input', { 'bubbles': true, "composed": true }); findElem('mo-home-state-box').querySelector('textarea').dispatchEvent(event); isHomeInitialStateInvoked = true logMo('initial home state invoked') @@ -510,15 +541,15 @@ function getTheme() { function getCardsSize() { return new Promise((resolve) => { - fetch(origin + '/mo/display-options') - .then(response => response.json()) - .then(data => { - resolve([data.card_width, data.card_height]) - }) - .catch(_ => { - resolve([250, 350]) - }); - } + fetch(origin + '/mo/display-options') + .then(response => response.json()) + .then(data => { + resolve([data.card_width, data.card_height]) + }) + .catch(_ => { + resolve([250, 350]) + }); + } ) } @@ -573,3 +604,146 @@ onUiLoaded(function () { installCardsSize(size[0], size[1]) }) }) + +let organizerTab = null; +let lastTabName = 'txt2img'; + +const inputEvent = new Event('input', { 'bubbles': true, "composed": true }); +const changeEvent = new Event('change', { 'bubbles': true, "composed": true }); +// Extra networks tab integration +// Huge thanks to https://github.com/CurtisDS/sd-model-preview-xd/tree/main for how to do this +onUiUpdate(function () { + + // get the organizer tab + let tabs = gradioApp().querySelectorAll("#tabs > div:first-of-type button"); + if (typeof tabs != "undefined" && tabs != null && tabs.length > 0) { + tabs.forEach(tab => { + if (tab.innerText == "Model Organizer") { + organizerTab = tab; + } + }); + } + + // Get + let thumbCards = gradioApp().querySelectorAll("#txt2img_extra_tabs .card:not([organizer-hijack]), #img2img_extra_tabs .card:not([organizer-hijack])"); + if (typeof thumbCards != "undefined" && thumbCards != null && thumbCards.length > 0) { + thumbCards.forEach(card => { + let buttonRow = card.querySelector('.button-row'); + // the name of the model is stored in a span beside the .additional div + //let modelName = card.getAttribute('data-name'); + let modelName = card.getAttribute('data-sort-name'); + + // Button to open organizer + let organizerBtnOpen = document.createElement("div"); + organizerBtnOpen.className = "organizer-buttonOpen card-button info"; + organizerBtnOpen.title = "Go To Record"; + organizerBtnOpen.onclick = function (event) { + addRecordClick(event, modelName); + }; + buttonRow.prepend(organizerBtnOpen); + + // we are finished so add the hijack attribute so we know not we don't need to do this card again + card.setAttribute("organizer-hijack", true); + }); + } +}) + +// Switch to Organizer Tab +function switchToOrganizerTab(event, name) { + event.stopPropagation(); + event.preventDefault(); + + var tabs = gradioApp().querySelectorAll('#tab_txt2img, #tab_img2img'); + if (typeof tabs != "undefined" && tabs != null && tabs.length > 0) { + tabs.forEach(tab => { + styleattr = tab.getAttribute('style'); + if (styleattr.includes('block')) { + lastTabName = tab.id.substring(4); + } + }); + } + + + organizerTab.click(); + organizerTab.dispatchEvent(inputEvent); + + var statebox = gradioApp().querySelector("#mo-home-state-box"); + statebox.dispatchEvent(changeEvent); + + var accordion = gradioApp().querySelector("#model_organizer_accordion"); + var labelWrap = accordion.querySelector('.label-wrap'); + + if (!labelWrap.classList.contains('open')) { + labelWrap.click(); + labelWrap.dispatchEvent(inputEvent); + } + + var searchArea = gradioApp().querySelector("#model_organizer_searchbox textarea"); + setTimeout((event) => { + searchArea.value = name; + searchArea.dispatchEvent(inputEvent); + }, 150); +} + +function addRecordClick(event, name) { + switchToOrganizerTab(event, name); + + //Find the card and click add button + + setTimeout((event) => { + var recordButtons = gradioApp().querySelectorAll('#organizer_record_table button.mo-btn.mo-btn-success, #organizer_record_card_grid button.mo-btn.mo-btn-success'); + if (recordButtons.length == 1) { + recordButtons[0].click(); + } + }, 400); + +} + +function fillPrompt(recordid) { + logMo('Loading record info for id: ' + recordid) + const navObj = { + screen: "record_info", + record_info_id: recordid, + token: generateUUID(), + backstack: populateBackstack() + }; + deliverNavObject(navObj) + + var timer = setInterval(() => { + record_data = gradioApp().querySelector('#mo_record_info_nav_box textarea'); + + if (record_data == null) return; + var terminate = false; + var jsdata = record_data.value; + var jsdata = jsdata.replace(/'/g, '"').replace(/"checkpoint": True/mg, '"checkpoint": true').replace(/"checkpoint": False/mg, '"checkpoint": false'); + recordInfo = JSON.parse(jsdata); + if (recordInfo.hasOwnProperty("id") && recordInfo["id"] === recordid) { + var pos = ""; + var neg = ""; + if (recordInfo.hasOwnProperty('positive_prompts')) { + pos = recordInfo['positive_prompts']; + terminate = true; + } + if (recordInfo.hasOwnProperty('negative_prompts')) { + neg = recordInfo['negative_prompts']; + terminate = true; + } + if (recordInfo.hasOwnProperty('checkpoint') && recordInfo['checkpoint']) { + selectCheckpoint(recordInfo['positive_prompts']); + terminate = true; + } else { + if (pos !== "") { + cardClicked(lastTabName, pos, "", false); + } + if (neg !== "") { + cardClicked(lastTabName, "", neg, true); + } + } + } + if (terminate) { + clearInterval(timer); + } + }, 100) + + return [] +} diff --git a/scripts/mo/data/record_utils.py b/scripts/mo/data/record_utils.py index feb22fb..d47440f 100644 --- a/scripts/mo/data/record_utils.py +++ b/scripts/mo/data/record_utils.py @@ -5,7 +5,7 @@ from typing import List, Dict from scripts.mo.data.mapping_utils import create_version_dict from scripts.mo.environment import env from scripts.mo.models import ModelSort, Record, ModelType -from scripts.mo.utils import get_model_files_in_dir, find_info_file +from scripts.mo.utils import get_model_files_in_dir, find_info_file, find_info_json_file def _sort_records(records: List, sort_order: ModelSort, sort_downloaded_first: bool) -> List: @@ -78,7 +78,7 @@ def _create_model_from_info_file(path, info_file_path, model_type): def _create_model_from_local_file(path, model_type): filename = os.path.basename(path) - return Record( + record = Record( id_=None, name=filename, model_type=model_type, @@ -87,6 +87,19 @@ def _create_model_from_local_file(path, model_type): download_filename=filename, download_path=os.path.dirname(path) ) + jsonFile = find_info_json_file(path) + if jsonFile: + try: + jsontxt = open(jsonFile) + jsonobj = json.load(jsontxt) + if ("activation text" in jsonobj) and (env.prefill_pos_prompt()): + record.positive_prompts = jsonobj["activation text"] + if ("negative text" in jsonobj) and (env.prefill_neg_prompt()): + record.negative_prompts = jsonobj["negative text"] + jsontxt.close() + except Exception as ex: + jsontxt.close() + return record def _get_model_type_from_file(path): diff --git a/scripts/mo/environment.py b/scripts/mo/environment.py index c8509be..5d6bfb4 100644 --- a/scripts/mo/environment.py +++ b/scripts/mo/environment.py @@ -62,6 +62,9 @@ class Environment: download_preview: Callable[[], bool] resize_preview: Callable[[], bool] nsfw_blur: Callable[[], bool] + prefill_pos_prompt: Callable[[], bool] + prefill_neg_prompt: Callable[[], bool] + autobind_file: Callable[[], bool] model_path: Callable[[], str] vae_path: Callable[[], str] lora_path: Callable[[], str] diff --git a/scripts/mo/ui_edit.py b/scripts/mo/ui_edit.py index 57c37a4..8d89fef 100644 --- a/scripts/mo/ui_edit.py +++ b/scripts/mo/ui_edit.py @@ -337,7 +337,8 @@ def edit_ui_block(): name_widget = gr.Textbox(label='Name:', value='', max_lines=1, - info='Model title to display (Required)') + info='Model title to display (Required)', + elem_id='model_organizer_edit_name') model_type_widget = gr.Dropdown( [model_type.value for model_type in ModelType], value='', @@ -382,7 +383,8 @@ def edit_ui_block(): location_bind_widget = gr.Dropdown(label='Bind with local file', info='Choose a local file to associate this record with.', - interactive=True) + interactive=True, + elem_id='model_organizer_add_bind') with gr.Accordion(label='Download options', open=False): download_path_widget = gr.Textbox(label='Download Path:', diff --git a/scripts/mo/ui_home.py b/scripts/mo/ui_home.py index 669f434..50b477c 100644 --- a/scripts/mo/ui_home.py +++ b/scripts/mo/ui_home.py @@ -138,7 +138,7 @@ def home_ui_block(): import_export_button = gr.Button('Import/Export') add_button = gr.Button('Add') - with gr.Accordion(label='Display options', open=False): + with gr.Accordion(label='Display options', open=False, elem_id='model_organizer_accordion'): with gr.Group(): sort_box = gr.Dropdown([model_sort.value for model_sort in ModelSort], value=sort_order, @@ -150,7 +150,7 @@ def home_ui_block(): with gr.Group(): search_box = gr.Textbox(label='Search by name', - value=initial_state['query']) + value=initial_state['query'], elem_id='model_organizer_searchbox') model_types_dropdown = gr.Dropdown([model_type.value for model_type in ModelType], value=initial_state['model_types'], label='Model types', diff --git a/scripts/mo/ui_main.py b/scripts/mo/ui_main.py index fb2f5dc..daaa466 100644 --- a/scripts/mo/ui_main.py +++ b/scripts/mo/ui_main.py @@ -10,6 +10,7 @@ from scripts.mo.ui_edit import edit_ui_block from scripts.mo.ui_home import home_ui_block from scripts.mo.ui_import_export import import_export_ui_block from scripts.mo.ui_remove import remove_ui_block +from scripts.mo.utils import get_json_record_data def on_json_box_change(json_state, home_refresh_token): @@ -33,7 +34,8 @@ def on_json_box_change(json_state, home_refresh_token): gr.Textbox.update(value=state['edit_data']), gr.Textbox.update(value=state['remove_record_id']), gr.Textbox.update(value=state['download_info']), - gr.Textbox.update(value=state['filter_state']) + gr.Textbox.update(value=state['filter_state']), + gr.Textbox.update(value=get_json_record_data(state['details_record_info_id'])) ] @@ -73,6 +75,9 @@ def main_ui_block(): else: gr.Row() + #TODO Write record data json into this + details_data_box = gr.Textbox(value='\{\}', label='mo_record_info_nav_box', elem_id='mo_record_info_nav_box', elem_classes='mo-alert-warning', visible=False) + _json_nav_box.change(on_json_box_change, inputs=[_json_nav_box, home_refresh_box], outputs=[home_block, @@ -88,7 +93,8 @@ def main_ui_block(): edit_id_box, remove_id_box, download_id_box, - filter_state_box + filter_state_box, + details_data_box ]) return main_block diff --git a/scripts/mo/ui_navigation.py b/scripts/mo/ui_navigation.py index ec01810..73f3d98 100644 --- a/scripts/mo/ui_navigation.py +++ b/scripts/mo/ui_navigation.py @@ -9,12 +9,13 @@ _REMOVE = 'remove' _DOWNLOAD = 'download' _IMPORT_EXPORT = 'import_export' _DEBUG = 'debug' +_RECORD_INFO = 'record_info' _NODE_SCREEN = 'screen' _NODE_RECORD_ID = 'record_id' _NODE_PREFILLED_JSON = 'prefilled_json' _NODE_GROUP = 'group' - +_NODE_RECORD_INFO_ID = 'record_info_id' def navigate_home() -> str: return '{}' @@ -81,7 +82,8 @@ def get_nav_state(json_nav) -> dict: 'edit_data': {}, 'remove_record_id': '', 'download_info': '', - 'filter_state': {} + 'filter_state': {}, + 'details_record_info_id': '' } if nav_dict.get(_NODE_SCREEN) is None: @@ -124,6 +126,9 @@ def get_nav_state(json_nav) -> dict: state['filter_state'] = nav_dict['filter_state'] elif nav_dict[_NODE_SCREEN] == _DEBUG: state['is_debug_visible'] = True + elif nav_dict[_NODE_SCREEN] == _RECORD_INFO: + state['details_record_info_id'] = nav_dict[_NODE_RECORD_INFO_ID] + state['is_home_visible'] = True return state diff --git a/scripts/mo/ui_styled_html.py b/scripts/mo/ui_styled_html.py index ee7f0cc..99bcda2 100644 --- a/scripts/mo/ui_styled_html.py +++ b/scripts/mo/ui_styled_html.py @@ -107,7 +107,7 @@ def _no_preview_image_url() -> str: def records_table(records: List) -> str: - table_html = '
' + table_html = '
' table_html += '
' table_html += '
Preview
' table_html += '
Type
' @@ -115,7 +115,9 @@ def records_table(records: List) -> str: table_html += '
Description
' table_html += '
Actions
' table_html += '
' + nsfw_blur = env.nsfw_blur() for record in records: + contains_nsfw = any('nsfw' in group.lower() for group in record.groups) and nsfw_blur name = html.escape(record.name) type_ = record.model_type.value preview_url = get_best_preview_url(record) @@ -126,9 +128,31 @@ def records_table(records: List) -> str: # Add preview URL column table_html += '
' - table_html += f'' + f' onerror="this.onerror=null; this.src=\'{_no_preview_image_url()}\';"/' + if not isLocalFileRecord: + img += f'onclick="fillPrompt({record.id_})"' + img += '>' + table_html += img + # table_html += f' str: # Add name column table_html += f'
' - table_html += f'' + table_html += f'' table_html += '
' # Add description column @@ -152,23 +176,23 @@ def records_table(records: List) -> str: json_record = html.escape(json.dumps(map_record_to_dict(record))) table_html += '
' + f'onclick="navigateEditPrefilled(\'{json_record}\', event)">Add
' table_html += '
' + f'onclick="navigateRemove(\'{record.location}\', event)">Remove
' else: table_html += '
' + f'onclick="navigateDetails(\'{record.id_}\', event)">Details
' if record.is_download_possible(): table_html += '
' + f'onclick="navigateDownloadRecord(\'{record.id_}\', event)">Download
' table_html += '
' + f'onclick="navigateEdit(\'{record.id_}\', event)">Edit
' table_html += '
' + f'onclick="navigateRemove(\'{record.id_}\', event)">Remove
' table_html += '
' # Close row @@ -331,13 +355,22 @@ def record_details(record: Record) -> str: def records_cards(records: List) -> str: - content = '
' + content = '
' nsfw_blur = env.nsfw_blur() for record in records: contains_nsfw = any('nsfw' in group.lower() for group in record.groups) and nsfw_blur - content += f'
' + + isLocalFileRecord = record.is_local_file_record() + + # Taken from extra networks cards. + cardStr = f'
str: content += '
' content += '
' - if record.is_local_file_record(): + if isLocalFileRecord: json_record = html.escape(json.dumps(map_record_to_dict(record))) content += '
' + f'onclick="navigateEditPrefilled(\'{json_record}\', event)">Add
' location = record.location.replace("\\", "\\\\") content += '
' + f'onclick="navigateRemove(\'{location}\', event)">Remove
' else: content += '
' + f'onclick="navigateDetails(\'{record.id_}\', event)">Details
' if record.is_download_possible(): content += '
' + f'onclick="navigateDownloadRecord(\'{record.id_}\', event)">Download
' content += '
' + f'onclick="navigateEdit(\'{record.id_}\', event)">Edit
' content += '
' + f'onclick="navigateRemove(\'{record.id_}\', event)">Remove
' content += '
' content += '
' diff --git a/scripts/mo/utils.py b/scripts/mo/utils.py index 0b56970..421e1ae 100644 --- a/scripts/mo/utils.py +++ b/scripts/mo/utils.py @@ -4,13 +4,18 @@ import json import os import re import urllib.parse +import sys +sys.path.append('extensions-builtin/Lora') +import networks + from typing import List from PIL import Image from PIL.PngImagePlugin import PngInfo from scripts.mo.environment import env -from scripts.mo.models import Record +from scripts.mo.models import Record, ModelType +from modules import sd_hijack _HASH_CACHE_FILENAME = 'hash_cache.json' @@ -260,3 +265,81 @@ def get_best_preview_url(record: Record) -> str: else: return link_preview(preview_path) return record.preview_url + +def find_info_json_file(model_file_path): + """ + Looks for model info json file. + :param model_file_path: path to model file. + :return: path to model info file if exists, None otherwise. + """ + if model_file_path: + filename_no_ext = get_model_filename_without_extension(model_file_path) + path = os.path.join(os.path.dirname(model_file_path), filename_no_ext) + + file = path + ".json" + if os.path.isfile(file): + return file + + return None + +def get_json_record_data(id): + result = {} + if (id != None) and (isinstance(id, int)) and (id > 0): + record = env.storage.get_record_by_id(id) + + pos = '' if record is None else record.positive_prompts + neg = '' if record is None else record.negative_prompts + isCheckPoint = False + if (record.model_type == ModelType.CHECKPOINT): + isCheckPoint = True + pos = record.name + + + + + elif(record.model_type == ModelType.LORA or record.model_type == ModelType.LYCORIS): + lora_on_disk = networks.available_networks.get(get_model_filename_without_extension(record.name)) + if lora_on_disk is None: + return {} + alias = lora_on_disk.get_alias() + + activation_text = record.positive_prompts + preferred_weight = 1.0 #item["user_metadata"].get("preferred weight", 0.0) + pos = f'' + + if activation_text: + pos += " " + activation_text + + negative_prompt = record.negative_prompts + if negative_prompt: + neg = negative_prompt + + + + elif(record.model_type == ModelType.HYPER_NETWORK): + pos = f'' + + + + elif(record.model_type == ModelType.EMBEDDING): + embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(get_model_filename_without_extension(record.name)) + if embedding is None: + return {} + if pos: + pos = embedding.name + if neg: + neg = embedding.name + + + elif(record.model_type == ModelType.VAE or record.model_type == ModelType.OTHER): + return {} + + result = { + "id": id, + "positive_prompts": pos, + "negative_prompts": neg, + "checkpoint": isCheckPoint + } + + return json.loads(json.dumps(result)); + diff --git a/scripts/model_organizer.py b/scripts/model_organizer.py index 92238e0..354449c 100644 --- a/scripts/model_organizer.py +++ b/scripts/model_organizer.py @@ -110,6 +110,25 @@ env.nsfw_blur = ( else True ) +env.prefill_pos_prompt = ( + lambda: shared.opts.mo_prefill_pos_prompt + if hasattr(shared.opts, 'mo_prefill_pos_prompt') + else True +) + +env.prefill_neg_prompt = ( + lambda: shared.opts.mo_prefill_neg_prompt + if hasattr(shared.opts, 'mo_prefill_neg_prompt') + else True +) + +env.autobind_file = ( + lambda: shared.opts.mo_autobind_file + if hasattr(shared.opts, 'mo_autobind_file') + else True +) + + env.model_path = ( lambda: shared.opts.mo_model_path if hasattr(shared.opts, 'mo_model_path') and shared.opts.mo_model_path @@ -170,6 +189,9 @@ def on_ui_settings(): 'mo_download_preview': OptionInfo(True, 'Download Preview'), 'mo_resize_preview': OptionInfo(True, 'Resize Preview'), 'mo_nsfw_blur': OptionInfo(True, 'Blur NSFW Previews (models with "nsfw" tag)'), + 'mo_prefill_pos_prompt': OptionInfo(True, 'When creating a record based on local file, automatically import the added positive prompts'), + 'mo_prefill_neg_prompt': OptionInfo(True, 'When creating a record based on local file, automatically import the added negative prompts'), + 'mo_autobind_file': OptionInfo(True, 'Automatically bind record to local file'), } dir_opts = { diff --git a/styles/styles.css b/styles/styles.css index 5e26717..317db54 100644 --- a/styles/styles.css +++ b/styles/styles.css @@ -641,4 +641,12 @@ justify-content: space-between; align-items: center; margin-top: 8px; +} +/* Extra network tab integration*/ +.extra-network-cards .card .organizer-buttonOpen::before { + content: "🔎︎" +} + +.extra-network-cards .card .organizer-buttonNew::before { + content: "+" } \ No newline at end of file