🔧 Disable ControlNet input in img2img inpaint (#1763)

* 🎨 rename functions

* 🔧 Disallow controlnet input in img2img inpaint

*  Update img2img inpaint test
pull/1764/head
Chenlei Hu 2023-07-04 22:24:31 -04:00 committed by GitHub
parent 7f6a6d33f9
commit 386adaa87e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 23 deletions

View File

@ -18,8 +18,15 @@
return children.indexOf(element);
}
function imageInputDisabledAlert() {
alert('Inpaint control type must use a1111 input in img2img mode.');
}
class GradioTab {
constructor(tab) {
this.tab = tab;
this.isImg2Img = tab.querySelector('.cnet-unit-enabled').id.includes('img2img');
this.enabledCheckbox = tab.querySelector('.cnet-unit-enabled input');
this.inputImage = tab.querySelector('.cnet-input-image-group .cnet-image input[type="file"]');
this.controlTypeRadios = tab.querySelectorAll('.controlnet_control_type_filter_group input[type="radio"]');
@ -47,7 +54,7 @@
return undefined;
}
applyActiveState() {
updateActiveState() {
const tabNavButton = this.getTabNavButton();
if (!tabNavButton) return;
@ -61,7 +68,7 @@
/**
* Add the active control type to tab displayed text.
*/
applyActiveControlType() {
updateActiveControlType() {
const tabNavButton = this.getTabNavButton();
if (!tabNavButton) return;
@ -79,16 +86,41 @@
tabNavButton.appendChild(span);
}
/**
* When 'Inpaint' control type is selected in img2img:
* - Make image input disabled
* - Clear existing image input
*/
updateImageInputState() {
if (!this.isImg2Img) return;
const tabNavButton = this.getTabNavButton();
if (!tabNavButton) return;
const controlType = this.getActiveControlType();
if (controlType.toLowerCase() === 'inpaint') {
this.inputImage.disabled = true;
this.inputImage.parentNode.addEventListener('click', imageInputDisabledAlert);
const removeButton = this.tab.querySelector(
'.cnet-input-image-group .cnet-image button[aria-label="Remove Image"]');
if (removeButton) removeButton.click();
} else {
this.inputImage.disabled = false;
this.inputImage.parentNode.removeEventListener('click', imageInputDisabledAlert);
}
}
attachEnabledButtonListener() {
this.enabledCheckbox.addEventListener('change', () => {
this.applyActiveState();
this.updateActiveState();
});
}
attachControlTypeRadioListener() {
for (const radio of this.controlTypeRadios) {
radio.addEventListener('change', () => {
this.applyActiveControlType();
this.updateActiveControlType();
this.updateImageInputState();
});
}
}
@ -102,8 +134,8 @@
new MutationObserver((mutationsList) => {
for (const mutation of mutationsList) {
if (mutation.type === 'childList') {
this.applyActiveState();
this.applyActiveControlType();
this.updateActiveState();
this.updateActiveControlType();
}
}
}).observe(this.tabNav, { childList: true });

View File

@ -90,6 +90,14 @@ class SeleniumTestCase(unittest.TestCase):
if not input_image_group.is_displayed():
controlnet_panel.click()
def enable_controlnet_unit(self):
controlnet_panel = self.gen_type.controlnet_panel(self.driver)
enable_checkbox = controlnet_panel.find_element(
By.CSS_SELECTOR, ".cnet-unit-enabled input[type='checkbox']"
)
if not enable_checkbox.is_selected():
enable_checkbox.click()
def iterate_preprocessor_types(self, ignore_none: bool = True):
dropdown = self.gen_type.controlnet_panel(self.driver).find_element(
By.CSS_SELECTOR,
@ -277,18 +285,14 @@ class SeleniumImg2ImgTest(SeleniumTestCase):
self.generate_image(f"img2img_{control_type}_ski")
class SeleniumInpainTest(SeleniumTestCase):
class SeleniumInpaintTest(SeleniumTestCase):
def setUp(self) -> None:
super().setUp()
self.set_seed(100)
self.set_subseed(1000)
def draw_inpaint_mask(
self, target_canvas
):
def draw_inpaint_mask(self, target_canvas):
size = target_canvas.size
width = size['width']
height = size['height']
width = size["width"]
height = size["height"]
brush_radius = 5
repeat = int(width * 0.1 / brush_radius)
@ -296,7 +300,7 @@ class SeleniumInpainTest(SeleniumTestCase):
(brush_radius, 0),
(0, height * 0.2),
(brush_radius, 0),
(0, - height * 0.2)
(0, -height * 0.2),
] * repeat
actions = ActionChains(self.driver)
@ -325,18 +329,17 @@ class SeleniumInpainTest(SeleniumTestCase):
self.upload_controlnet_input(SKI_IMAGE)
self.draw_cn_mask()
self.set_seed(100)
self.set_subseed(1000)
for option in self.iterate_preprocessor_types():
with self.subTest(option=option):
self.generate_image(f"{option}_txt2img_ski")
def test_img2img_inpaint(self):
self._test_img2img_inpaint(True, True)
self.tearDown()
self.setUp()
self._test_img2img_inpaint(True, False)
self.tearDown()
self.setUp()
self._test_img2img_inpaint(False, True)
# Note: img2img inpaint can only use A1111 mask.
# ControlNet input is disabled in img2img inpaint.
self._test_img2img_inpaint(use_cn_mask=False, use_a1111_mask=True)
def _test_img2img_inpaint(self, use_cn_mask: bool, use_a1111_mask: bool):
self.select_gen_type(GenType.img2img)
@ -354,8 +357,12 @@ class SeleniumInpainTest(SeleniumTestCase):
f"//input[@name='radio-img2img_inpainting_fill' and @value='latent noise']",
).click()
self.set_prompt("(coca-cola:2.0)")
self.enable_controlnet_unit()
self.upload_controlnet_input(SKI_IMAGE)
self.set_seed(100)
self.set_subseed(1000)
prefix = ""
if use_cn_mask:
self.draw_cn_mask()