Integration of Extra Networks into the extension

Added functionality to inside the organizer
Ability to "import" extra network json files
Ability to click on the card to add prompt to txt2img or img2img
pull/75/head
Learwin 2024-01-21 18:04:07 +01:00
parent acedea8c6d
commit 31ce74bbac
11 changed files with 397 additions and 48 deletions

View File

@ -345,7 +345,9 @@ function navigateBack() {
return []
}
function navigateDetails(id) {
function navigateDetails(id, event) {
event.stopPropagation();
event.preventDefault();
logMo('Navigate details screen for id: ' + id)
const navObj = {
screen: "details",
@ -380,7 +382,7 @@ function navigateImportExport(filter_state) {
return []
}
function navigateDebug() {
function navigateDebug(event) {
logMo('Navigate debug screen')
const navObj = {
screen: "debug",
@ -391,7 +393,9 @@ function navigateDebug() {
return []
}
function navigateEdit(id) {
function navigateEdit(id, event) {
event.stopPropagation();
event.preventDefault();
logMo('Navigate edit screen for id: ' + id)
const navObj = {
screen: "edit",
@ -403,7 +407,9 @@ function navigateEdit(id) {
return []
}
function navigateEditPrefilled(json_data) {
function navigateEditPrefilled(json_data, event) {
event.stopPropagation();
event.preventDefault();
logMo('Navigate edit screen for prefilled json: ' + json_data)
const navObj = {
screen: "edit",
@ -411,11 +417,34 @@ function navigateEditPrefilled(json_data) {
token: generateUUID(),
backstack: populateBackstack()
};
setTimeout((event) => {
//Get config options
var json_elem = gradioApp().getElementById('settings_json');
if (json_elem == null) return;
var textarea = json_elem.querySelector('textarea');
var jsdata = textarea.value;
opts = JSON.parse(jsdata);
if (opts['mo_autobind_file']) {
//setTimeout((event) => {
var bind = gradioApp().querySelector('#model_organizer_add_bind input');
var modelName = gradioApp().querySelector('#model_organizer_edit_name input');
bind.value = modelName.value;
}
}, 300);
deliverNavObject(navObj)
return []
}
function navigateDownloadRecord(id) {
function navigateDownloadRecord(id, event) {
event.stopPropagation();
event.preventDefault();
logMo('Navigate download screen for id: ' + id)
const navObj = {
screen: "download",
@ -451,7 +480,9 @@ function navigateDownloadGroup(groupName) {
return []
}
function navigateRemove(id) {
function navigateRemove(id, event) {
event.stopPropagation();
event.preventDefault();
logMo('Navigate removal screen for id: ' + id)
const navObj = {
screen: "remove",
@ -573,3 +604,146 @@ onUiLoaded(function () {
installCardsSize(size[0], size[1])
})
})
let organizerTab = null;
let lastTabName = 'txt2img';
const inputEvent = new Event('input', { 'bubbles': true, "composed": true });
const changeEvent = new Event('change', { 'bubbles': true, "composed": true });
// Extra networks tab integration
// Huge thanks to https://github.com/CurtisDS/sd-model-preview-xd/tree/main for how to do this
onUiUpdate(function () {
// get the organizer tab
let tabs = gradioApp().querySelectorAll("#tabs > div:first-of-type button");
if (typeof tabs != "undefined" && tabs != null && tabs.length > 0) {
tabs.forEach(tab => {
if (tab.innerText == "Model Organizer") {
organizerTab = tab;
}
});
}
// Get
let thumbCards = gradioApp().querySelectorAll("#txt2img_extra_tabs .card:not([organizer-hijack]), #img2img_extra_tabs .card:not([organizer-hijack])");
if (typeof thumbCards != "undefined" && thumbCards != null && thumbCards.length > 0) {
thumbCards.forEach(card => {
let buttonRow = card.querySelector('.button-row');
// the name of the model is stored in a span beside the .additional div
//let modelName = card.getAttribute('data-name');
let modelName = card.getAttribute('data-sort-name');
// Button to open organizer
let organizerBtnOpen = document.createElement("div");
organizerBtnOpen.className = "organizer-buttonOpen card-button info";
organizerBtnOpen.title = "Go To Record";
organizerBtnOpen.onclick = function (event) {
addRecordClick(event, modelName);
};
buttonRow.prepend(organizerBtnOpen);
// we are finished so add the hijack attribute so we know not we don't need to do this card again
card.setAttribute("organizer-hijack", true);
});
}
})
// Switch to Organizer Tab
function switchToOrganizerTab(event, name) {
event.stopPropagation();
event.preventDefault();
var tabs = gradioApp().querySelectorAll('#tab_txt2img, #tab_img2img');
if (typeof tabs != "undefined" && tabs != null && tabs.length > 0) {
tabs.forEach(tab => {
styleattr = tab.getAttribute('style');
if (styleattr.includes('block')) {
lastTabName = tab.id.substring(4);
}
});
}
organizerTab.click();
organizerTab.dispatchEvent(inputEvent);
var statebox = gradioApp().querySelector("#mo-home-state-box");
statebox.dispatchEvent(changeEvent);
var accordion = gradioApp().querySelector("#model_organizer_accordion");
var labelWrap = accordion.querySelector('.label-wrap');
if (!labelWrap.classList.contains('open')) {
labelWrap.click();
labelWrap.dispatchEvent(inputEvent);
}
var searchArea = gradioApp().querySelector("#model_organizer_searchbox textarea");
setTimeout((event) => {
searchArea.value = name;
searchArea.dispatchEvent(inputEvent);
}, 150);
}
function addRecordClick(event, name) {
switchToOrganizerTab(event, name);
//Find the card and click add button
setTimeout((event) => {
var recordButtons = gradioApp().querySelectorAll('#organizer_record_table button.mo-btn.mo-btn-success, #organizer_record_card_grid button.mo-btn.mo-btn-success');
if (recordButtons.length == 1) {
recordButtons[0].click();
}
}, 400);
}
function fillPrompt(recordid) {
logMo('Loading record info for id: ' + recordid)
const navObj = {
screen: "record_info",
record_info_id: recordid,
token: generateUUID(),
backstack: populateBackstack()
};
deliverNavObject(navObj)
var timer = setInterval(() => {
record_data = gradioApp().querySelector('#mo_record_info_nav_box textarea');
if (record_data == null) return;
var terminate = false;
var jsdata = record_data.value;
var jsdata = jsdata.replace(/'/g, '"').replace(/"checkpoint": True/mg, '"checkpoint": true').replace(/"checkpoint": False/mg, '"checkpoint": false');
recordInfo = JSON.parse(jsdata);
if (recordInfo.hasOwnProperty("id") && recordInfo["id"] === recordid) {
var pos = "";
var neg = "";
if (recordInfo.hasOwnProperty('positive_prompts')) {
pos = recordInfo['positive_prompts'];
terminate = true;
}
if (recordInfo.hasOwnProperty('negative_prompts')) {
neg = recordInfo['negative_prompts'];
terminate = true;
}
if (recordInfo.hasOwnProperty('checkpoint') && recordInfo['checkpoint']) {
selectCheckpoint(recordInfo['positive_prompts']);
terminate = true;
} else {
if (pos !== "") {
cardClicked(lastTabName, pos, "", false);
}
if (neg !== "") {
cardClicked(lastTabName, "", neg, true);
}
}
}
if (terminate) {
clearInterval(timer);
}
}, 100)
return []
}

View File

@ -5,7 +5,7 @@ from typing import List, Dict
from scripts.mo.data.mapping_utils import create_version_dict
from scripts.mo.environment import env
from scripts.mo.models import ModelSort, Record, ModelType
from scripts.mo.utils import get_model_files_in_dir, find_info_file
from scripts.mo.utils import get_model_files_in_dir, find_info_file, find_info_json_file
def _sort_records(records: List, sort_order: ModelSort, sort_downloaded_first: bool) -> List:
@ -78,7 +78,7 @@ def _create_model_from_info_file(path, info_file_path, model_type):
def _create_model_from_local_file(path, model_type):
filename = os.path.basename(path)
return Record(
record = Record(
id_=None,
name=filename,
model_type=model_type,
@ -87,6 +87,19 @@ def _create_model_from_local_file(path, model_type):
download_filename=filename,
download_path=os.path.dirname(path)
)
jsonFile = find_info_json_file(path)
if jsonFile:
try:
jsontxt = open(jsonFile)
jsonobj = json.load(jsontxt)
if ("activation text" in jsonobj) and (env.prefill_pos_prompt()):
record.positive_prompts = jsonobj["activation text"]
if ("negative text" in jsonobj) and (env.prefill_neg_prompt()):
record.negative_prompts = jsonobj["negative text"]
jsontxt.close()
except Exception as ex:
jsontxt.close()
return record
def _get_model_type_from_file(path):

View File

@ -62,6 +62,9 @@ class Environment:
download_preview: Callable[[], bool]
resize_preview: Callable[[], bool]
nsfw_blur: Callable[[], bool]
prefill_pos_prompt: Callable[[], bool]
prefill_neg_prompt: Callable[[], bool]
autobind_file: Callable[[], bool]
model_path: Callable[[], str]
vae_path: Callable[[], str]
lora_path: Callable[[], str]

View File

@ -337,7 +337,8 @@ def edit_ui_block():
name_widget = gr.Textbox(label='Name:',
value='',
max_lines=1,
info='Model title to display (Required)')
info='Model title to display (Required)',
elem_id='model_organizer_edit_name')
model_type_widget = gr.Dropdown(
[model_type.value for model_type in ModelType],
value='',
@ -382,7 +383,8 @@ def edit_ui_block():
location_bind_widget = gr.Dropdown(label='Bind with local file',
info='Choose a local file to associate this record with.',
interactive=True)
interactive=True,
elem_id='model_organizer_add_bind')
with gr.Accordion(label='Download options', open=False):
download_path_widget = gr.Textbox(label='Download Path:',

View File

@ -138,7 +138,7 @@ def home_ui_block():
import_export_button = gr.Button('Import/Export')
add_button = gr.Button('Add')
with gr.Accordion(label='Display options', open=False):
with gr.Accordion(label='Display options', open=False, elem_id='model_organizer_accordion'):
with gr.Group():
sort_box = gr.Dropdown([model_sort.value for model_sort in ModelSort],
value=sort_order,
@ -150,7 +150,7 @@ def home_ui_block():
with gr.Group():
search_box = gr.Textbox(label='Search by name',
value=initial_state['query'])
value=initial_state['query'], elem_id='model_organizer_searchbox')
model_types_dropdown = gr.Dropdown([model_type.value for model_type in ModelType],
value=initial_state['model_types'],
label='Model types',

View File

@ -10,6 +10,7 @@ from scripts.mo.ui_edit import edit_ui_block
from scripts.mo.ui_home import home_ui_block
from scripts.mo.ui_import_export import import_export_ui_block
from scripts.mo.ui_remove import remove_ui_block
from scripts.mo.utils import get_json_record_data
def on_json_box_change(json_state, home_refresh_token):
@ -33,7 +34,8 @@ def on_json_box_change(json_state, home_refresh_token):
gr.Textbox.update(value=state['edit_data']),
gr.Textbox.update(value=state['remove_record_id']),
gr.Textbox.update(value=state['download_info']),
gr.Textbox.update(value=state['filter_state'])
gr.Textbox.update(value=state['filter_state']),
gr.Textbox.update(value=get_json_record_data(state['details_record_info_id']))
]
@ -73,6 +75,9 @@ def main_ui_block():
else:
gr.Row()
#TODO Write record data json into this
details_data_box = gr.Textbox(value='\{\}', label='mo_record_info_nav_box', elem_id='mo_record_info_nav_box', elem_classes='mo-alert-warning', visible=False)
_json_nav_box.change(on_json_box_change,
inputs=[_json_nav_box, home_refresh_box],
outputs=[home_block,
@ -88,7 +93,8 @@ def main_ui_block():
edit_id_box,
remove_id_box,
download_id_box,
filter_state_box
filter_state_box,
details_data_box
])
return main_block

View File

@ -9,12 +9,13 @@ _REMOVE = 'remove'
_DOWNLOAD = 'download'
_IMPORT_EXPORT = 'import_export'
_DEBUG = 'debug'
_RECORD_INFO = 'record_info'
_NODE_SCREEN = 'screen'
_NODE_RECORD_ID = 'record_id'
_NODE_PREFILLED_JSON = 'prefilled_json'
_NODE_GROUP = 'group'
_NODE_RECORD_INFO_ID = 'record_info_id'
def navigate_home() -> str:
return '{}'
@ -81,7 +82,8 @@ def get_nav_state(json_nav) -> dict:
'edit_data': {},
'remove_record_id': '',
'download_info': '',
'filter_state': {}
'filter_state': {},
'details_record_info_id': ''
}
if nav_dict.get(_NODE_SCREEN) is None:
@ -124,6 +126,9 @@ def get_nav_state(json_nav) -> dict:
state['filter_state'] = nav_dict['filter_state']
elif nav_dict[_NODE_SCREEN] == _DEBUG:
state['is_debug_visible'] = True
elif nav_dict[_NODE_SCREEN] == _RECORD_INFO:
state['details_record_info_id'] = nav_dict[_NODE_RECORD_INFO_ID]
state['is_home_visible'] = True
return state

View File

@ -107,7 +107,7 @@ def _no_preview_image_url() -> str:
def records_table(records: List) -> str:
table_html = '<div class="mo-container">'
table_html = '<div id="organizer_record_table" class="mo-container">'
table_html += '<div class="mo-row mo-row-header">'
table_html += '<div class="mo-col mo-col-preview"><span class="mo-text-header">Preview</span></div>'
table_html += '<div class="mo-col mo-col-type"><span class="mo-text-header">Type</span></div>'
@ -115,7 +115,9 @@ def records_table(records: List) -> str:
table_html += '<div class="mo-col mo-col-description"><span class="mo-text-header">Description</span></div>'
table_html += '<div class="mo-col mo-col-actions"><span class="mo-text-header">Actions</span></div>'
table_html += '</div>'
nsfw_blur = env.nsfw_blur()
for record in records:
contains_nsfw = any('nsfw' in group.lower() for group in record.groups) and nsfw_blur
name = html.escape(record.name)
type_ = record.model_type.value
preview_url = get_best_preview_url(record)
@ -126,9 +128,31 @@ def records_table(records: List) -> str:
# Add preview URL column
table_html += '<div class="mo-col mo-col-preview">'
table_html += f'<img class="mo-preview-image" src="{preview_url}" ' \
###
isLocalFileRecord = record.is_local_file_record()
# Taken from extra networks cards.
# cardStr = f'<div class="mo-card {_model_card_type_css_class(record.model_type)} {"blur" if contains_nsfw else ""}"'
# if not isLocalFileRecord:
# cardStr += f'onclick="fillPrompt({record.id_})"'
# cardStr += '>'
# content += cardStr
###
img = f'<img class="mo-preview-image" src="{preview_url}" ' \
f'alt="Preview image"' \
f' onerror="this.onerror=null; this.src=\'{_no_preview_image_url()}\';"/>'
f' onerror="this.onerror=null; this.src=\'{_no_preview_image_url()}\';"/'
if not isLocalFileRecord:
img += f'onclick="fillPrompt({record.id_})"'
img += '>'
table_html += img
# table_html += f'<img class="mo-preview-image" src="{preview_url}" ' \
# f'alt="Preview image"' \
# f' onerror="this.onerror=null; this.src=\'{_no_preview_image_url()}\';"/' \
# f'onclick="fillPrompt({record.id_})"'
table_html += '</div>'
# Add type column
@ -137,7 +161,7 @@ def records_table(records: List) -> str:
# Add name column
table_html += f'<div class="mo-col mo-col-name">'
table_html += f'<button class="mo-button-name" onclick="navigateDetails(\'{record.id_}\')">{name}</button>'
table_html += f'<button class="mo-button-name" onclick="navigateDetails(\'{record.id_}\', event)">{name}</button>'
table_html += '</div>'
# Add description column
@ -152,23 +176,23 @@ def records_table(records: List) -> str:
json_record = html.escape(json.dumps(map_record_to_dict(record)))
table_html += '<button type="button" class="mo-btn mo-btn-success" ' \
f'onclick="navigateEditPrefilled(\'{json_record}\')">Add</button><br>'
f'onclick="navigateEditPrefilled(\'{json_record}\', event)">Add</button><br>'
table_html += '<button type="button" class="mo-btn mo-btn-danger" ' \
f'onclick="navigateRemove(\'{record.location}\')">Remove</button><br>'
f'onclick="navigateRemove(\'{record.location}\', event)">Remove</button><br>'
else:
table_html += '<button type="button" class="mo-btn mo-btn-success" ' \
f'onclick="navigateDetails(\'{record.id_}\')">Details</button><br>'
f'onclick="navigateDetails(\'{record.id_}\', event)">Details</button><br>'
if record.is_download_possible():
table_html += '<button type="button" class="mo-btn mo-btn-primary" ' \
f'onclick="navigateDownloadRecord(\'{record.id_}\')">Download</button><br>'
f'onclick="navigateDownloadRecord(\'{record.id_}\', event)">Download</button><br>'
table_html += '<button type="button" class="mo-btn mo-btn-warning" ' \
f'onclick="navigateEdit(\'{record.id_}\')">Edit</button><br>'
f'onclick="navigateEdit(\'{record.id_}\', event)">Edit</button><br>'
table_html += '<button type="button" class="mo-btn mo-btn-danger" ' \
f'onclick="navigateRemove(\'{record.id_}\')">Remove</button><br>'
f'onclick="navigateRemove(\'{record.id_}\', event)">Remove</button><br>'
table_html += '</div>'
# Close row
@ -331,13 +355,22 @@ def record_details(record: Record) -> str:
def records_cards(records: List) -> str:
content = '<div class="mo-card-grid">'
content = '<div id="organizer_record_card_grid" class="mo-card-grid">'
nsfw_blur = env.nsfw_blur()
for record in records:
contains_nsfw = any('nsfw' in group.lower() for group in record.groups) and nsfw_blur
content += f'<div class="mo-card {_model_card_type_css_class(record.model_type)} {"blur" if contains_nsfw else ""}">'
isLocalFileRecord = record.is_local_file_record()
# Taken from extra networks cards.
cardStr = f'<div class="mo-card {_model_card_type_css_class(record.model_type)} {"blur" if contains_nsfw else ""}"'
if not isLocalFileRecord:
cardStr += f'onclick="fillPrompt({record.id_})"'
cardStr += '>'
content += cardStr
preview_url = get_best_preview_url(record)
content += f'<img src="{preview_url}" alt="Preview Image" ' \
@ -353,28 +386,28 @@ def records_cards(records: List) -> str:
content += '<div class="mo-card-hover">'
content += '<div class="mo-card-hover-buttons">'
if record.is_local_file_record():
if isLocalFileRecord:
json_record = html.escape(json.dumps(map_record_to_dict(record)))
content += '<button type="button" class="mo-btn mo-btn-success" ' \
f'onclick="navigateEditPrefilled(\'{json_record}\')">Add</button><br>'
f'onclick="navigateEditPrefilled(\'{json_record}\', event)">Add</button><br>'
location = record.location.replace("\\", "\\\\")
content += '<button type="button" class="mo-btn mo-btn-danger" ' \
f'onclick="navigateRemove(\'{location}\')">Remove</button><br>'
f'onclick="navigateRemove(\'{location}\', event)">Remove</button><br>'
else:
content += '<button type="button" class="mo-btn mo-btn-success" ' \
f'onclick="navigateDetails(\'{record.id_}\')">Details</button><br>'
f'onclick="navigateDetails(\'{record.id_}\', event)">Details</button><br>'
if record.is_download_possible():
content += '<button type="button" class="mo-btn mo-btn-primary" ' \
f'onclick="navigateDownloadRecord(\'{record.id_}\')">Download</button><br>'
f'onclick="navigateDownloadRecord(\'{record.id_}\', event)">Download</button><br>'
content += '<button type="button" class="mo-btn mo-btn-warning" ' \
f'onclick="navigateEdit(\'{record.id_}\')">Edit</button><br>'
f'onclick="navigateEdit(\'{record.id_}\', event)">Edit</button><br>'
content += '<button type="button" class="mo-btn mo-btn-danger" ' \
f'onclick="navigateRemove(\'{record.id_}\')">Remove</button><br>'
f'onclick="navigateRemove(\'{record.id_}\', event)">Remove</button><br>'
content += '</div>'
content += '</div>'

View File

@ -4,13 +4,18 @@ import json
import os
import re
import urllib.parse
import sys
sys.path.append('extensions-builtin/Lora')
import networks
from typing import List
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from scripts.mo.environment import env
from scripts.mo.models import Record
from scripts.mo.models import Record, ModelType
from modules import sd_hijack
_HASH_CACHE_FILENAME = 'hash_cache.json'
@ -260,3 +265,81 @@ def get_best_preview_url(record: Record) -> str:
else:
return link_preview(preview_path)
return record.preview_url
def find_info_json_file(model_file_path):
"""
Looks for model info json file.
:param model_file_path: path to model file.
:return: path to model info file if exists, None otherwise.
"""
if model_file_path:
filename_no_ext = get_model_filename_without_extension(model_file_path)
path = os.path.join(os.path.dirname(model_file_path), filename_no_ext)
file = path + ".json"
if os.path.isfile(file):
return file
return None
def get_json_record_data(id):
result = {}
if (id != None) and (isinstance(id, int)) and (id > 0):
record = env.storage.get_record_by_id(id)
pos = '' if record is None else record.positive_prompts
neg = '' if record is None else record.negative_prompts
isCheckPoint = False
if (record.model_type == ModelType.CHECKPOINT):
isCheckPoint = True
pos = record.name
elif(record.model_type == ModelType.LORA or record.model_type == ModelType.LYCORIS):
lora_on_disk = networks.available_networks.get(get_model_filename_without_extension(record.name))
if lora_on_disk is None:
return {}
alias = lora_on_disk.get_alias()
activation_text = record.positive_prompts
preferred_weight = 1.0 #item["user_metadata"].get("preferred weight", 0.0)
pos = f'<lora:{alias}:' + (str(preferred_weight) if preferred_weight else '1') + '>'
if activation_text:
pos += " " + activation_text
negative_prompt = record.negative_prompts
if negative_prompt:
neg = negative_prompt
elif(record.model_type == ModelType.HYPER_NETWORK):
pos = f'<hypernet:{get_model_filename_without_extension(record.name)}:1>'
elif(record.model_type == ModelType.EMBEDDING):
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(get_model_filename_without_extension(record.name))
if embedding is None:
return {}
if pos:
pos = embedding.name
if neg:
neg = embedding.name
elif(record.model_type == ModelType.VAE or record.model_type == ModelType.OTHER):
return {}
result = {
"id": id,
"positive_prompts": pos,
"negative_prompts": neg,
"checkpoint": isCheckPoint
}
return json.loads(json.dumps(result));

View File

@ -110,6 +110,25 @@ env.nsfw_blur = (
else True
)
env.prefill_pos_prompt = (
lambda: shared.opts.mo_prefill_pos_prompt
if hasattr(shared.opts, 'mo_prefill_pos_prompt')
else True
)
env.prefill_neg_prompt = (
lambda: shared.opts.mo_prefill_neg_prompt
if hasattr(shared.opts, 'mo_prefill_neg_prompt')
else True
)
env.autobind_file = (
lambda: shared.opts.mo_autobind_file
if hasattr(shared.opts, 'mo_autobind_file')
else True
)
env.model_path = (
lambda: shared.opts.mo_model_path
if hasattr(shared.opts, 'mo_model_path') and shared.opts.mo_model_path
@ -170,6 +189,9 @@ def on_ui_settings():
'mo_download_preview': OptionInfo(True, 'Download Preview'),
'mo_resize_preview': OptionInfo(True, 'Resize Preview'),
'mo_nsfw_blur': OptionInfo(True, 'Blur NSFW Previews (models with "nsfw" tag)'),
'mo_prefill_pos_prompt': OptionInfo(True, 'When creating a record based on local file, automatically import the added positive prompts'),
'mo_prefill_neg_prompt': OptionInfo(True, 'When creating a record based on local file, automatically import the added negative prompts'),
'mo_autobind_file': OptionInfo(True, 'Automatically bind record to local file'),
}
dir_opts = {

View File

@ -642,3 +642,11 @@
align-items: center;
margin-top: 8px;
}
/* Extra network tab integration*/
.extra-network-cards .card .organizer-buttonOpen::before {
content: "🔎︎"
}
.extra-network-cards .card .organizer-buttonNew::before {
content: "+"
}