diff --git a/javascript/state.core.js b/javascript/state.core.js index ac2e7ef..aab7af8 100644 --- a/javascript/state.core.js +++ b/javascript/state.core.js @@ -31,6 +31,10 @@ state.core = (function () { 'resize_mode': 'resize_mode', }; + const MULTI_SELECTS = { + 'styles': 'styles' + }; + let store = null; function hasSetting(id, tab) { @@ -75,6 +79,14 @@ state.core = (function () { }); } + for (const [settingId, element] of Object.entries(MULTI_SELECTS)) { + TABS.forEach(tab => { + if (config.hasSetting(settingId, tab)) { + handleSavedMultiSelects(`${tab}_${element}`); + } + }); + } + handleExtensions(config); } @@ -177,6 +189,52 @@ state.core = (function () { }); } + function handleSavedMultiSelects(id) { + + const select = gradioApp().querySelector(`#${id} .items-center.relative`); + + try { + let value = store.get(id); + + if (value) { + + value = value.split(','); + + if (value.length) { + + let input = select.querySelector('input'); + + let selectOption = function () { + if (! value.length) { + state.utils.triggerMouseEvent(input, 'blur'); + return; + } + let option = value.pop(); + state.utils.triggerMouseEvent(input); + setTimeout(() => { + let items = Array.from(select.parentNode.querySelectorAll('ul li')); + items.forEach(li => { + if (li.lastChild.wholeText.trim() === option) { + state.utils.triggerMouseEvent(li, 'mousedown'); + return false; + } + }); + setTimeout(selectOption, 100); + }, 100); + } + selectOption(); + } + } + } catch (error) { + console.error('[state]: Error:', error); + } + + state.utils.onContentChange(select, function (el) { + const selected = Array.from(el.querySelectorAll('.token > span')).map(item => item.textContent); + store.set(id, selected); + }); + } + function handleExtensions(config) { if (config['state_extensions']) { config['state_extensions'].forEach(function (ext) { diff --git a/javascript/state.utils.js b/javascript/state.utils.js index e2a1f8f..7ce5756 100644 --- a/javascript/state.utils.js +++ b/javascript/state.utils.js @@ -9,6 +9,18 @@ state.utils = { element.dispatchEvent(new Event(event.trim())); return element; }, + triggerMouseEvent: function triggerMouseEvent(element, event) { + if (! element) { + return; + } + event = event || 'click'; + element.dispatchEvent(new MouseEvent(event, { + view: window, + bubbles: true, + cancelable: true, + })); + return element; + }, setValue: function setValue(element, value, event) { switch (element.type) { case 'checkbox': @@ -28,6 +40,20 @@ state.utils = { this.triggerEvent(element, event); } }, + onContentChange: function onContentChange(targetNode, func) { + const observer = new MutationObserver((mutationsList, observer) => { + for (const mutation of mutationsList) { + if (mutation.type === 'childList') { + func(targetNode); + } + } + }); + observer.observe(targetNode, { + childList: true, + characterData: true, + subtree: true + }); + }, txtToId: function txtToId(txt) { return txt.split(' ').join('-').toLowerCase(); }, diff --git a/scripts/state_settings.py b/scripts/state_settings.py index e333e62..9460b46 100644 --- a/scripts/state_settings.py +++ b/scripts/state_settings.py @@ -17,6 +17,7 @@ def on_ui_settings(): "choices": [ "prompt", "negative_prompt", + "styles", "sampling", "sampling_steps", "width", @@ -41,6 +42,7 @@ def on_ui_settings(): "choices": [ "prompt", "negative_prompt", + "styles", "sampling", "resize_mode", "sampling_steps",