🔧 Disable ControlNet input in img2img inpaint (#1763)
* 🎨 rename functions * 🔧 Disallow controlnet input in img2img inpaint * ✅ Update img2img inpaint testpull/1764/head
parent
7f6a6d33f9
commit
386adaa87e
|
|
@ -18,8 +18,15 @@
|
|||
return children.indexOf(element);
|
||||
}
|
||||
|
||||
function imageInputDisabledAlert() {
|
||||
alert('Inpaint control type must use a1111 input in img2img mode.');
|
||||
}
|
||||
|
||||
class GradioTab {
|
||||
constructor(tab) {
|
||||
this.tab = tab;
|
||||
this.isImg2Img = tab.querySelector('.cnet-unit-enabled').id.includes('img2img');
|
||||
|
||||
this.enabledCheckbox = tab.querySelector('.cnet-unit-enabled input');
|
||||
this.inputImage = tab.querySelector('.cnet-input-image-group .cnet-image input[type="file"]');
|
||||
this.controlTypeRadios = tab.querySelectorAll('.controlnet_control_type_filter_group input[type="radio"]');
|
||||
|
|
@ -47,7 +54,7 @@
|
|||
return undefined;
|
||||
}
|
||||
|
||||
applyActiveState() {
|
||||
updateActiveState() {
|
||||
const tabNavButton = this.getTabNavButton();
|
||||
if (!tabNavButton) return;
|
||||
|
||||
|
|
@ -61,7 +68,7 @@
|
|||
/**
|
||||
* Add the active control type to tab displayed text.
|
||||
*/
|
||||
applyActiveControlType() {
|
||||
updateActiveControlType() {
|
||||
const tabNavButton = this.getTabNavButton();
|
||||
if (!tabNavButton) return;
|
||||
|
||||
|
|
@ -79,16 +86,41 @@
|
|||
tabNavButton.appendChild(span);
|
||||
}
|
||||
|
||||
/**
|
||||
* When 'Inpaint' control type is selected in img2img:
|
||||
* - Make image input disabled
|
||||
* - Clear existing image input
|
||||
*/
|
||||
updateImageInputState() {
|
||||
if (!this.isImg2Img) return;
|
||||
|
||||
const tabNavButton = this.getTabNavButton();
|
||||
if (!tabNavButton) return;
|
||||
|
||||
const controlType = this.getActiveControlType();
|
||||
if (controlType.toLowerCase() === 'inpaint') {
|
||||
this.inputImage.disabled = true;
|
||||
this.inputImage.parentNode.addEventListener('click', imageInputDisabledAlert);
|
||||
const removeButton = this.tab.querySelector(
|
||||
'.cnet-input-image-group .cnet-image button[aria-label="Remove Image"]');
|
||||
if (removeButton) removeButton.click();
|
||||
} else {
|
||||
this.inputImage.disabled = false;
|
||||
this.inputImage.parentNode.removeEventListener('click', imageInputDisabledAlert);
|
||||
}
|
||||
}
|
||||
|
||||
attachEnabledButtonListener() {
|
||||
this.enabledCheckbox.addEventListener('change', () => {
|
||||
this.applyActiveState();
|
||||
this.updateActiveState();
|
||||
});
|
||||
}
|
||||
|
||||
attachControlTypeRadioListener() {
|
||||
for (const radio of this.controlTypeRadios) {
|
||||
radio.addEventListener('change', () => {
|
||||
this.applyActiveControlType();
|
||||
this.updateActiveControlType();
|
||||
this.updateImageInputState();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -102,8 +134,8 @@
|
|||
new MutationObserver((mutationsList) => {
|
||||
for (const mutation of mutationsList) {
|
||||
if (mutation.type === 'childList') {
|
||||
this.applyActiveState();
|
||||
this.applyActiveControlType();
|
||||
this.updateActiveState();
|
||||
this.updateActiveControlType();
|
||||
}
|
||||
}
|
||||
}).observe(this.tabNav, { childList: true });
|
||||
|
|
|
|||
|
|
@ -90,6 +90,14 @@ class SeleniumTestCase(unittest.TestCase):
|
|||
if not input_image_group.is_displayed():
|
||||
controlnet_panel.click()
|
||||
|
||||
def enable_controlnet_unit(self):
|
||||
controlnet_panel = self.gen_type.controlnet_panel(self.driver)
|
||||
enable_checkbox = controlnet_panel.find_element(
|
||||
By.CSS_SELECTOR, ".cnet-unit-enabled input[type='checkbox']"
|
||||
)
|
||||
if not enable_checkbox.is_selected():
|
||||
enable_checkbox.click()
|
||||
|
||||
def iterate_preprocessor_types(self, ignore_none: bool = True):
|
||||
dropdown = self.gen_type.controlnet_panel(self.driver).find_element(
|
||||
By.CSS_SELECTOR,
|
||||
|
|
@ -277,18 +285,14 @@ class SeleniumImg2ImgTest(SeleniumTestCase):
|
|||
self.generate_image(f"img2img_{control_type}_ski")
|
||||
|
||||
|
||||
class SeleniumInpainTest(SeleniumTestCase):
|
||||
class SeleniumInpaintTest(SeleniumTestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.set_seed(100)
|
||||
self.set_subseed(1000)
|
||||
|
||||
def draw_inpaint_mask(
|
||||
self, target_canvas
|
||||
):
|
||||
def draw_inpaint_mask(self, target_canvas):
|
||||
size = target_canvas.size
|
||||
width = size['width']
|
||||
height = size['height']
|
||||
width = size["width"]
|
||||
height = size["height"]
|
||||
brush_radius = 5
|
||||
repeat = int(width * 0.1 / brush_radius)
|
||||
|
||||
|
|
@ -296,7 +300,7 @@ class SeleniumInpainTest(SeleniumTestCase):
|
|||
(brush_radius, 0),
|
||||
(0, height * 0.2),
|
||||
(brush_radius, 0),
|
||||
(0, - height * 0.2)
|
||||
(0, -height * 0.2),
|
||||
] * repeat
|
||||
|
||||
actions = ActionChains(self.driver)
|
||||
|
|
@ -308,7 +312,7 @@ class SeleniumInpainTest(SeleniumTestCase):
|
|||
actions.release() # release the left mouse button
|
||||
actions.perform() # perform the action chain
|
||||
|
||||
def draw_cn_mask(self):
|
||||
def draw_cn_mask(self):
|
||||
canvas = self.gen_type.controlnet_panel(self.driver).find_element(
|
||||
By.CSS_SELECTOR, ".cnet-input-image-group .cnet-image canvas"
|
||||
)
|
||||
|
|
@ -325,18 +329,17 @@ class SeleniumInpainTest(SeleniumTestCase):
|
|||
self.upload_controlnet_input(SKI_IMAGE)
|
||||
self.draw_cn_mask()
|
||||
|
||||
self.set_seed(100)
|
||||
self.set_subseed(1000)
|
||||
|
||||
for option in self.iterate_preprocessor_types():
|
||||
with self.subTest(option=option):
|
||||
self.generate_image(f"{option}_txt2img_ski")
|
||||
|
||||
def test_img2img_inpaint(self):
|
||||
self._test_img2img_inpaint(True, True)
|
||||
self.tearDown()
|
||||
self.setUp()
|
||||
self._test_img2img_inpaint(True, False)
|
||||
self.tearDown()
|
||||
self.setUp()
|
||||
self._test_img2img_inpaint(False, True)
|
||||
# Note: img2img inpaint can only use A1111 mask.
|
||||
# ControlNet input is disabled in img2img inpaint.
|
||||
self._test_img2img_inpaint(use_cn_mask=False, use_a1111_mask=True)
|
||||
|
||||
def _test_img2img_inpaint(self, use_cn_mask: bool, use_a1111_mask: bool):
|
||||
self.select_gen_type(GenType.img2img)
|
||||
|
|
@ -354,8 +357,12 @@ class SeleniumInpainTest(SeleniumTestCase):
|
|||
f"//input[@name='radio-img2img_inpainting_fill' and @value='latent noise']",
|
||||
).click()
|
||||
self.set_prompt("(coca-cola:2.0)")
|
||||
self.enable_controlnet_unit()
|
||||
self.upload_controlnet_input(SKI_IMAGE)
|
||||
|
||||
self.set_seed(100)
|
||||
self.set_subseed(1000)
|
||||
|
||||
prefix = ""
|
||||
if use_cn_mask:
|
||||
self.draw_cn_mask()
|
||||
|
|
|
|||
Loading…
Reference in New Issue