Rework controlnet integration after Gradio upgrade, #100. #120

controlnet
AlUlkesh 2023-04-03 14:55:44 +02:00
parent db40a6e925
commit 7646dd813c
2 changed files with 121 additions and 32 deletions

View File

@ -178,14 +178,36 @@ async function image_browser_turnpage(tab_base_tag) {
await image_browser_unlock() await image_browser_unlock()
} }
function image_browser_gototab(tabname, tabsId = "tabs") { async function image_browser_gototab(tabname) {
Array.from( await image_browser_lock("image_browser_gototab")
gradioApp().querySelectorAll(`#${tabsId} > div:first-child button`)
).forEach((button) => { tabNav = gradioApp().querySelector(".tab-nav")
if (button.textContent.trim() === tabname) { const tabNavChildren = tabNav.children
button.click() let tabNavButtonNum
} for (let i = 0; i < tabNavChildren.length; i++) {
}) if (tabNavChildren[i].tagName === "BUTTON" && tabNavChildren[i].textContent.trim() === tabname) {
tabNavButtonNum = i
break
}
}
let tabNavButton = tabNavChildren[tabNavButtonNum]
tabNavButton.click()
// Wait for click-action to complete
const startTime = Date.now()
// 60 seconds in milliseconds
const timeout = 60000
await image_browser_delay(100)
while (!tabNavButton.classList.contains("selected")) {
tabNavButton = tabNavChildren[tabNavButtonNum]
if (Date.now() - startTime > timeout) {
throw new Error("image_browser_gototab: 60 seconds have passed")
}
await image_browser_delay(200)
}
await image_browser_unlock()
} }
async function image_browser_get_image_for_ext(tab_base_tag, image_index) { async function image_browser_get_image_for_ext(tab_base_tag, image_index) {
@ -228,41 +250,102 @@ function image_browser_openoutpaint_send(tab_base_tag, image_index, image_browse
}) })
} }
async function image_browser_controlnet_send(toTab, tab_base_tag, image_index, controlnetNum) { async function image_browser_controlnet_send(toTab, tab_base_tag, image_index, controlnetNum, controlnetType) {
// Logic originally based on github.com/fkunn1326/openpose-editor
const dataURL = await image_browser_get_image_for_ext(tab_base_tag, image_index) const dataURL = await image_browser_get_image_for_ext(tab_base_tag, image_index)
const blob = await (await fetch(dataURL)).blob() const blob = await (await fetch(dataURL)).blob()
const dt = new DataTransfer() const dt = new DataTransfer()
dt.items.add(new File([blob], "ImageBrowser.png", { type: blob.type })) dt.items.add(new File([blob], "ImageBrowser.png", { type: blob.type }))
const container = gradioApp().querySelector( const list = dt.files
toTab === "txt2img" ? "#txt2img_script_container" : "#img2img_script_container"
)
const accordion = container.querySelector("#controlnet .transition")
if (accordion.classList.contains("rotate-90")) accordion.click()
const tab = container.querySelectorAll( await image_browser_gototab(toTab)
"#controlnet > div:nth-child(2) > .tabs > .tabitem, #controlnet > div:nth-child(2) > div:not(.tabs)" const mode = gradioApp().getElementById(toTab + "_controlnet")
)[controlnetNum] let accordion = mode.querySelector("#controlnet > .label-wrap > .icon")
if (tab.classList.contains("tabitem")) if (accordion.style.transform.includes("rotate(90deg)")) {
tab.parentElement.firstElementChild.querySelector(`:nth-child(${Number(controlnetNum) + 1})`).click() accordion.click()
// Wait for click-action to complete
const startTime = Date.now()
// 60 seconds in milliseconds
const timeout = 60000
await image_browser_delay(100)
while (accordion.style.transform.includes("rotate(90deg)")) {
accordion = mode.querySelector("#controlnet > .label-wrap > .icon")
if (Date.now() - startTime > timeout) {
throw new Error("image_browser_controlnet_send/accordion: 60 seconds have passed")
}
await image_browser_delay(200)
}
}
const input = tab.querySelector("input[type='file']") let inputContainerSelector
if (controlnetType == "single") {
inputContainerSelector = toTab + "_controlnet_ControlNet_input_image"
} else {
inputContainerSelector = toTab + "_controlnet_ControlNet-" + controlnetNum + "_input_image"
const tabs = gradioApp().getElementById(toTab + "_controlnet_tabs")
const tab_num = (parseInt(controlnetNum) + 1).toString()
tab_button = tabs.querySelector(".tab-nav button:nth-child(" + tab_num + ")")
tab_button.click()
// Wait for click-action to complete
const startTime = Date.now()
// 60 seconds in milliseconds
const timeout = 60000
await image_browser_delay(100)
while (!tab_button.classList.contains("selected")) {
tab_button = tabs.querySelector(".tab-nav button:nth-child(" + tab_num + ")")
if (Date.now() - startTime > timeout) {
throw new Error("image_browser_controlnet_send/tabs: 60 seconds have passed")
}
await image_browser_delay(200)
}
}
let inputContainer = null
try { try {
input.previousElementSibling.previousElementSibling.querySelector("button[aria-label='Clear']").click() inputContainer = gradioApp().getElementById(inputContainerSelector)
} catch (e) {} } catch (e) {}
const input = inputContainer.querySelector("input[type='file']")
let clear
try {
clear = inputContainer.querySelector("button[aria-label='Clear']")
if (clear) {
clear.click()
}
} catch (e) {
console.error(e)
}
try {
// Wait for click-action to complete
const startTime = Date.now()
// 60 seconds in milliseconds
const timeout = 60000
while (clear) {
clear = inputContainer.querySelector("button[aria-label='Clear']")
if (Date.now() - startTime > timeout) {
throw new Error("image_browser_controlnet_send/clear: 60 seconds have passed")
}
await image_browser_delay(200)
}
} catch (e) {
console.error(e)
}
input.value = "" input.value = ""
input.files = dt.files input.files = list
input.dispatchEvent(new Event("change", { bubbles: true, composed: true })) const event = new Event("change", { "bubbles": true, "composed": true })
input.dispatchEvent(event)
image_browser_gototab(toTab)
} }
function image_browser_controlnet_send_txt2img(tab_base_tag, image_index, controlnetNum) { function image_browser_controlnet_send_txt2img(tab_base_tag, image_index, controlnetNum, controlnetType) {
image_browser_controlnet_send("txt2img", tab_base_tag, image_index, controlnetNum) image_browser_controlnet_send("txt2img", tab_base_tag, image_index, controlnetNum, controlnetType)
} }
function image_browser_controlnet_send_img2img(tab_base_tag, image_index, controlnetNum) { function image_browser_controlnet_send_img2img(tab_base_tag, image_index, controlnetNum, controlnetType) {
image_browser_controlnet_send("img2img", tab_base_tag, image_index, controlnetNum) image_browser_controlnet_send("img2img", tab_base_tag, image_index, controlnetNum, controlnetType)
} }
function image_browser_class_add(tab_base_tag) { function image_browser_class_add(tab_base_tag) {

View File

@ -1045,7 +1045,13 @@ def create_tab(tab: ImageBrowserTab, current_gr_tab: gr.Tab):
sendto_controlnet_txt2img = gr.Button("Send to txt2img ControlNet", visible=controlnet) sendto_controlnet_txt2img = gr.Button("Send to txt2img ControlNet", visible=controlnet)
sendto_controlnet_img2img = gr.Button("Send to img2img ControlNet", visible=controlnet) sendto_controlnet_img2img = gr.Button("Send to img2img ControlNet", visible=controlnet)
controlnet_max = opts.data.get("control_net_max_models_num", 1) controlnet_max = opts.data.get("control_net_max_models_num", 1)
sendto_controlnet_num = gr.Dropdown(list(range(controlnet_max)), label="ControlNet number", value="0", interactive=True, visible=(controlnet and controlnet_max > 1)) sendto_controlnet_num = gr.Dropdown([str(i) for i in range(controlnet_max)], label="ControlNet number", value="0", interactive=True, visible=(controlnet and controlnet_max > 1))
if controlnet_max is None:
sendto_controlnet_type = gr.Textbox(value="none", visible=False)
elif controlnet_max == 1:
sendto_controlnet_type = gr.Textbox(value="single", visible=False)
else:
sendto_controlnet_type = gr.Textbox(value="multiple", visible=False)
with gr.Row(elem_id=f"{tab.base_tag}_image_browser_to_dir_panel", visible=False) as to_dir_panel: with gr.Row(elem_id=f"{tab.base_tag}_image_browser_to_dir_panel", visible=False) as to_dir_panel:
with gr.Box(): with gr.Box():
with gr.Row(): with gr.Row():
@ -1291,13 +1297,13 @@ def create_tab(tab: ImageBrowserTab, current_gr_tab: gr.Tab):
) )
sendto_controlnet_txt2img.click( sendto_controlnet_txt2img.click(
fn=None, fn=None,
inputs=[tab_base_tag_box, image_index, sendto_controlnet_num], inputs=[tab_base_tag_box, image_index, sendto_controlnet_num, sendto_controlnet_type],
outputs=[], outputs=[],
_js="image_browser_controlnet_send_txt2img" _js="image_browser_controlnet_send_txt2img"
) )
sendto_controlnet_img2img.click( sendto_controlnet_img2img.click(
fn=None, fn=None,
inputs=[tab_base_tag_box, image_index, sendto_controlnet_num], inputs=[tab_base_tag_box, image_index, sendto_controlnet_num, sendto_controlnet_type],
outputs=[], outputs=[],
_js="image_browser_controlnet_send_img2img" _js="image_browser_controlnet_send_img2img"
) )