diff --git a/javascript/active_units.js b/javascript/active_units.js index d83ca19..0aef28d 100644 --- a/javascript/active_units.js +++ b/javascript/active_units.js @@ -18,8 +18,15 @@ return children.indexOf(element); } + function imageInputDisabledAlert() { + alert('Inpaint control type must use a1111 input in img2img mode.'); + } + class GradioTab { constructor(tab) { + this.tab = tab; + this.isImg2Img = tab.querySelector('.cnet-unit-enabled').id.includes('img2img'); + this.enabledCheckbox = tab.querySelector('.cnet-unit-enabled input'); this.inputImage = tab.querySelector('.cnet-input-image-group .cnet-image input[type="file"]'); this.controlTypeRadios = tab.querySelectorAll('.controlnet_control_type_filter_group input[type="radio"]'); @@ -47,7 +54,7 @@ return undefined; } - applyActiveState() { + updateActiveState() { const tabNavButton = this.getTabNavButton(); if (!tabNavButton) return; @@ -61,7 +68,7 @@ /** * Add the active control type to tab displayed text. */ - applyActiveControlType() { + updateActiveControlType() { const tabNavButton = this.getTabNavButton(); if (!tabNavButton) return; @@ -79,16 +86,41 @@ tabNavButton.appendChild(span); } + /** + * When 'Inpaint' control type is selected in img2img: + * - Make image input disabled + * - Clear existing image input + */ + updateImageInputState() { + if (!this.isImg2Img) return; + + const tabNavButton = this.getTabNavButton(); + if (!tabNavButton) return; + + const controlType = this.getActiveControlType(); + if (controlType.toLowerCase() === 'inpaint') { + this.inputImage.disabled = true; + this.inputImage.parentNode.addEventListener('click', imageInputDisabledAlert); + const removeButton = this.tab.querySelector( + '.cnet-input-image-group .cnet-image button[aria-label="Remove Image"]'); + if (removeButton) removeButton.click(); + } else { + this.inputImage.disabled = false; + this.inputImage.parentNode.removeEventListener('click', imageInputDisabledAlert); + } + } + attachEnabledButtonListener() { this.enabledCheckbox.addEventListener('change', () => { - this.applyActiveState(); + this.updateActiveState(); }); } attachControlTypeRadioListener() { for (const radio of this.controlTypeRadios) { radio.addEventListener('change', () => { - this.applyActiveControlType(); + this.updateActiveControlType(); + this.updateImageInputState(); }); } } @@ -102,8 +134,8 @@ new MutationObserver((mutationsList) => { for (const mutation of mutationsList) { if (mutation.type === 'childList') { - this.applyActiveState(); - this.applyActiveControlType(); + this.updateActiveState(); + this.updateActiveControlType(); } } }).observe(this.tabNav, { childList: true }); diff --git a/web_tests/main.py b/web_tests/main.py index 9afc24e..ebc24c1 100644 --- a/web_tests/main.py +++ b/web_tests/main.py @@ -90,6 +90,14 @@ class SeleniumTestCase(unittest.TestCase): if not input_image_group.is_displayed(): controlnet_panel.click() + def enable_controlnet_unit(self): + controlnet_panel = self.gen_type.controlnet_panel(self.driver) + enable_checkbox = controlnet_panel.find_element( + By.CSS_SELECTOR, ".cnet-unit-enabled input[type='checkbox']" + ) + if not enable_checkbox.is_selected(): + enable_checkbox.click() + def iterate_preprocessor_types(self, ignore_none: bool = True): dropdown = self.gen_type.controlnet_panel(self.driver).find_element( By.CSS_SELECTOR, @@ -277,18 +285,14 @@ class SeleniumImg2ImgTest(SeleniumTestCase): self.generate_image(f"img2img_{control_type}_ski") -class SeleniumInpainTest(SeleniumTestCase): +class SeleniumInpaintTest(SeleniumTestCase): def setUp(self) -> None: super().setUp() - self.set_seed(100) - self.set_subseed(1000) - def draw_inpaint_mask( - self, target_canvas - ): + def draw_inpaint_mask(self, target_canvas): size = target_canvas.size - width = size['width'] - height = size['height'] + width = size["width"] + height = size["height"] brush_radius = 5 repeat = int(width * 0.1 / brush_radius) @@ -296,7 +300,7 @@ class SeleniumInpainTest(SeleniumTestCase): (brush_radius, 0), (0, height * 0.2), (brush_radius, 0), - (0, - height * 0.2) + (0, -height * 0.2), ] * repeat actions = ActionChains(self.driver) @@ -308,7 +312,7 @@ class SeleniumInpainTest(SeleniumTestCase): actions.release() # release the left mouse button actions.perform() # perform the action chain - def draw_cn_mask(self): + def draw_cn_mask(self): canvas = self.gen_type.controlnet_panel(self.driver).find_element( By.CSS_SELECTOR, ".cnet-input-image-group .cnet-image canvas" ) @@ -325,18 +329,17 @@ class SeleniumInpainTest(SeleniumTestCase): self.upload_controlnet_input(SKI_IMAGE) self.draw_cn_mask() + self.set_seed(100) + self.set_subseed(1000) + for option in self.iterate_preprocessor_types(): with self.subTest(option=option): self.generate_image(f"{option}_txt2img_ski") def test_img2img_inpaint(self): - self._test_img2img_inpaint(True, True) - self.tearDown() - self.setUp() - self._test_img2img_inpaint(True, False) - self.tearDown() - self.setUp() - self._test_img2img_inpaint(False, True) + # Note: img2img inpaint can only use A1111 mask. + # ControlNet input is disabled in img2img inpaint. + self._test_img2img_inpaint(use_cn_mask=False, use_a1111_mask=True) def _test_img2img_inpaint(self, use_cn_mask: bool, use_a1111_mask: bool): self.select_gen_type(GenType.img2img) @@ -354,8 +357,12 @@ class SeleniumInpainTest(SeleniumTestCase): f"//input[@name='radio-img2img_inpainting_fill' and @value='latent noise']", ).click() self.set_prompt("(coca-cola:2.0)") + self.enable_controlnet_unit() self.upload_controlnet_input(SKI_IMAGE) + self.set_seed(100) + self.set_subseed(1000) + prefix = "" if use_cn_mask: self.draw_cn_mask()