Rework controlnet integration after Gradio upgrade, #100. #120

controlnet
AlUlkesh 2023-04-03 14:55:44 +02:00
parent db40a6e925
commit 7646dd813c
2 changed files with 121 additions and 32 deletions

View File

@ -178,14 +178,36 @@ async function image_browser_turnpage(tab_base_tag) {
await image_browser_unlock()
}
function image_browser_gototab(tabname, tabsId = "tabs") {
Array.from(
gradioApp().querySelectorAll(`#${tabsId} > div:first-child button`)
).forEach((button) => {
if (button.textContent.trim() === tabname) {
button.click()
}
})
async function image_browser_gototab(tabname) {
await image_browser_lock("image_browser_gototab")
tabNav = gradioApp().querySelector(".tab-nav")
const tabNavChildren = tabNav.children
let tabNavButtonNum
for (let i = 0; i < tabNavChildren.length; i++) {
if (tabNavChildren[i].tagName === "BUTTON" && tabNavChildren[i].textContent.trim() === tabname) {
tabNavButtonNum = i
break
}
}
let tabNavButton = tabNavChildren[tabNavButtonNum]
tabNavButton.click()
// Wait for click-action to complete
const startTime = Date.now()
// 60 seconds in milliseconds
const timeout = 60000
await image_browser_delay(100)
while (!tabNavButton.classList.contains("selected")) {
tabNavButton = tabNavChildren[tabNavButtonNum]
if (Date.now() - startTime > timeout) {
throw new Error("image_browser_gototab: 60 seconds have passed")
}
await image_browser_delay(200)
}
await image_browser_unlock()
}
async function image_browser_get_image_for_ext(tab_base_tag, image_index) {
@ -228,41 +250,102 @@ function image_browser_openoutpaint_send(tab_base_tag, image_index, image_browse
})
}
async function image_browser_controlnet_send(toTab, tab_base_tag, image_index, controlnetNum) {
async function image_browser_controlnet_send(toTab, tab_base_tag, image_index, controlnetNum, controlnetType) {
// Logic originally based on github.com/fkunn1326/openpose-editor
const dataURL = await image_browser_get_image_for_ext(tab_base_tag, image_index)
const blob = await (await fetch(dataURL)).blob()
const dt = new DataTransfer()
dt.items.add(new File([blob], "ImageBrowser.png", { type: blob.type }))
const container = gradioApp().querySelector(
toTab === "txt2img" ? "#txt2img_script_container" : "#img2img_script_container"
)
const accordion = container.querySelector("#controlnet .transition")
if (accordion.classList.contains("rotate-90")) accordion.click()
const list = dt.files
const tab = container.querySelectorAll(
"#controlnet > div:nth-child(2) > .tabs > .tabitem, #controlnet > div:nth-child(2) > div:not(.tabs)"
)[controlnetNum]
if (tab.classList.contains("tabitem"))
tab.parentElement.firstElementChild.querySelector(`:nth-child(${Number(controlnetNum) + 1})`).click()
await image_browser_gototab(toTab)
const mode = gradioApp().getElementById(toTab + "_controlnet")
let accordion = mode.querySelector("#controlnet > .label-wrap > .icon")
if (accordion.style.transform.includes("rotate(90deg)")) {
accordion.click()
// Wait for click-action to complete
const startTime = Date.now()
// 60 seconds in milliseconds
const timeout = 60000
const input = tab.querySelector("input[type='file']")
await image_browser_delay(100)
while (accordion.style.transform.includes("rotate(90deg)")) {
accordion = mode.querySelector("#controlnet > .label-wrap > .icon")
if (Date.now() - startTime > timeout) {
throw new Error("image_browser_controlnet_send/accordion: 60 seconds have passed")
}
await image_browser_delay(200)
}
}
let inputContainerSelector
if (controlnetType == "single") {
inputContainerSelector = toTab + "_controlnet_ControlNet_input_image"
} else {
inputContainerSelector = toTab + "_controlnet_ControlNet-" + controlnetNum + "_input_image"
const tabs = gradioApp().getElementById(toTab + "_controlnet_tabs")
const tab_num = (parseInt(controlnetNum) + 1).toString()
tab_button = tabs.querySelector(".tab-nav button:nth-child(" + tab_num + ")")
tab_button.click()
// Wait for click-action to complete
const startTime = Date.now()
// 60 seconds in milliseconds
const timeout = 60000
await image_browser_delay(100)
while (!tab_button.classList.contains("selected")) {
tab_button = tabs.querySelector(".tab-nav button:nth-child(" + tab_num + ")")
if (Date.now() - startTime > timeout) {
throw new Error("image_browser_controlnet_send/tabs: 60 seconds have passed")
}
await image_browser_delay(200)
}
}
let inputContainer = null
try {
input.previousElementSibling.previousElementSibling.querySelector("button[aria-label='Clear']").click()
inputContainer = gradioApp().getElementById(inputContainerSelector)
} catch (e) {}
const input = inputContainer.querySelector("input[type='file']")
let clear
try {
clear = inputContainer.querySelector("button[aria-label='Clear']")
if (clear) {
clear.click()
}
} catch (e) {
console.error(e)
}
try {
// Wait for click-action to complete
const startTime = Date.now()
// 60 seconds in milliseconds
const timeout = 60000
while (clear) {
clear = inputContainer.querySelector("button[aria-label='Clear']")
if (Date.now() - startTime > timeout) {
throw new Error("image_browser_controlnet_send/clear: 60 seconds have passed")
}
await image_browser_delay(200)
}
} catch (e) {
console.error(e)
}
input.value = ""
input.files = dt.files
input.dispatchEvent(new Event("change", { bubbles: true, composed: true }))
image_browser_gototab(toTab)
input.files = list
const event = new Event("change", { "bubbles": true, "composed": true })
input.dispatchEvent(event)
}
function image_browser_controlnet_send_txt2img(tab_base_tag, image_index, controlnetNum) {
image_browser_controlnet_send("txt2img", tab_base_tag, image_index, controlnetNum)
function image_browser_controlnet_send_txt2img(tab_base_tag, image_index, controlnetNum, controlnetType) {
image_browser_controlnet_send("txt2img", tab_base_tag, image_index, controlnetNum, controlnetType)
}
function image_browser_controlnet_send_img2img(tab_base_tag, image_index, controlnetNum) {
image_browser_controlnet_send("img2img", tab_base_tag, image_index, controlnetNum)
function image_browser_controlnet_send_img2img(tab_base_tag, image_index, controlnetNum, controlnetType) {
image_browser_controlnet_send("img2img", tab_base_tag, image_index, controlnetNum, controlnetType)
}
function image_browser_class_add(tab_base_tag) {

View File

@ -1045,7 +1045,13 @@ def create_tab(tab: ImageBrowserTab, current_gr_tab: gr.Tab):
sendto_controlnet_txt2img = gr.Button("Send to txt2img ControlNet", visible=controlnet)
sendto_controlnet_img2img = gr.Button("Send to img2img ControlNet", visible=controlnet)
controlnet_max = opts.data.get("control_net_max_models_num", 1)
sendto_controlnet_num = gr.Dropdown(list(range(controlnet_max)), label="ControlNet number", value="0", interactive=True, visible=(controlnet and controlnet_max > 1))
sendto_controlnet_num = gr.Dropdown([str(i) for i in range(controlnet_max)], label="ControlNet number", value="0", interactive=True, visible=(controlnet and controlnet_max > 1))
if controlnet_max is None:
sendto_controlnet_type = gr.Textbox(value="none", visible=False)
elif controlnet_max == 1:
sendto_controlnet_type = gr.Textbox(value="single", visible=False)
else:
sendto_controlnet_type = gr.Textbox(value="multiple", visible=False)
with gr.Row(elem_id=f"{tab.base_tag}_image_browser_to_dir_panel", visible=False) as to_dir_panel:
with gr.Box():
with gr.Row():
@ -1291,13 +1297,13 @@ def create_tab(tab: ImageBrowserTab, current_gr_tab: gr.Tab):
)
sendto_controlnet_txt2img.click(
fn=None,
inputs=[tab_base_tag_box, image_index, sendto_controlnet_num],
inputs=[tab_base_tag_box, image_index, sendto_controlnet_num, sendto_controlnet_type],
outputs=[],
_js="image_browser_controlnet_send_txt2img"
)
sendto_controlnet_img2img.click(
fn=None,
inputs=[tab_base_tag_box, image_index, sendto_controlnet_num],
inputs=[tab_base_tag_box, image_index, sendto_controlnet_num, sendto_controlnet_type],
outputs=[],
_js="image_browser_controlnet_send_img2img"
)