Validate ControlNetUnit using pydantic (#2847)

* Add Pydantic ControlNetUnit

Add test config

Add images field

Adjust image field handling

fix various ui

Fix most UI issues

accept greyscale image/mask

Fix infotext

Fix preset

Fix infotext

nit

Move infotext parsing test

Remove preset

Remove unused js code

Adjust test payload

By default disable unit

refresh enum usage

Align resize mode

change test func name

remove unused import

nit

Change default handling

Skip bound check when not enabled

Fix batch

Various batch fix

Disable batch hijack test

adjust test

fix test expectations

Fix unit copy

nit

Fix test failures

* Change script args back to ControlNetUnit for compatibility

* import enum for compatibility

* Fix unit test

* simplify unfold

* Add test coverage

* handle directly set np image

* re-enable batch test

* Add back canvas scribble support

* nit

* Fix batch hijack test
pull/2853/head
Chenlei Hu 2024-05-06 15:40:34 -04:00 committed by GitHub
parent 1b95e476ec
commit e33c046158
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 937 additions and 1102 deletions

View File

@ -94,6 +94,10 @@ jobs:
wait-for-it --service 127.0.0.1:7860 -t 600
python -m pytest -v --junitxml=test/results.xml --cov ./extensions/sd-webui-controlnet --cov-report=xml --verify-base-url ./extensions/sd-webui-controlnet/tests
working-directory: stable-diffusion-webui
- name: Run unit tests
run: |
python -m pytest -v ./unit_tests/
working-directory: stable-diffusion-webui/extensions/sd-webui-controlnet/
- name: Kill test server
if: always()
run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10

View File

443
internal_controlnet/args.py Normal file
View File

@ -0,0 +1,443 @@
from __future__ import annotations
import os
import torch
import numpy as np
from typing import Optional, List, Annotated, ClassVar, Callable, Any, Tuple, Union
from pydantic import BaseModel, validator, root_validator, Field
from PIL import Image
from logging import Logger
from copy import copy
from enum import Enum
from scripts.enums import (
InputMode,
ResizeMode,
ControlMode,
HiResFixOption,
PuLIDMode,
)
def _unimplemented_func(*args, **kwargs):
    """Placeholder for the ControlNetUnit ``cls_*`` hooks.

    The hosting extension is expected to replace these class-level callables
    at startup; calling a hook before it is wired up fails loudly here.
    """
    raise NotImplementedError("Not implemented.")
def field_to_displaytext(fieldname: str) -> str:
    """Turn a snake_case field name into an infotext display label.

    e.g. ``"processor_res"`` -> ``"Processor Res"``.
    """
    words = []
    for word in fieldname.split("_"):
        words.append(word.capitalize())
    return " ".join(words)
def displaytext_to_field(text: str) -> str:
    """Inverse of ``field_to_displaytext``.

    e.g. ``"Processor Res"`` -> ``"processor_res"``.
    """
    lowered = (word.lower() for word in text.split(" "))
    return "_".join(lowered)
def serialize_value(value) -> str:
    """Serialize a single field value for infotext output.

    Enum members serialize to their underlying ``.value``; everything else
    falls back to ``str``.
    """
    return value.value if isinstance(value, Enum) else str(value)
def parse_value(value: str) -> Union[str, float, int, bool]:
    """Parse a raw infotext token into the most specific scalar type.

    Tries bool (exact "True"/"False"), then int, then float; anything that
    fails all casts is returned as the plain string.
    """
    if value == "True":
        return True
    if value == "False":
        return False
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            continue
    return value  # Plain string.
class ControlNetUnit(BaseModel):
    """
    Represents an entire ControlNet processing unit.

    Validation relies on pydantic v1 ``validator`` / ``root_validator``
    hooks; the ``cls_*`` class-level callables are injected by the hosting
    extension at startup (they default to a function that raises
    NotImplementedError).
    """

    class Config:
        arbitrary_types_allowed = True
        extra = "ignore"

    # Hooks injected by the hosting extension at startup. Kept as plain
    # class-level callables so this module has no import-time dependency on
    # the rest of the extension.
    cls_match_module: ClassVar[Callable[[str], bool]] = _unimplemented_func
    cls_match_model: ClassVar[Callable[[str], bool]] = _unimplemented_func
    cls_decode_base64: ClassVar[Callable[[str], np.ndarray]] = _unimplemented_func
    cls_torch_load_base64: ClassVar[Callable[[Any], torch.Tensor]] = _unimplemented_func
    cls_get_preprocessor: ClassVar[Callable[[str], Any]] = _unimplemented_func
    cls_logger: ClassVar[Logger] = Logger("ControlNetUnit")

    # UI only fields.
    is_ui: bool = False
    input_mode: InputMode = InputMode.SIMPLE
    batch_images: Optional[Any] = None
    output_dir: str = ""
    loopback: bool = False

    # General fields.
    enabled: bool = False
    module: str = "none"

    @validator("module", always=True, pre=True)
    def check_module(cls, value: str) -> str:
        """Reject preprocessor module names unknown to the extension."""
        if not ControlNetUnit.cls_match_module(value):
            raise ValueError(f"module({value}) not found in supported modules.")
        return value

    model: str = "None"

    @validator("model", always=True, pre=True)
    def check_model(cls, value: str) -> str:
        """Reject model names unknown to the extension."""
        if not ControlNetUnit.cls_match_model(value):
            raise ValueError(f"model({value}) not found in supported models.")
        return value

    weight: Annotated[float, Field(ge=0.0, le=2.0)] = 1.0

    # The image to be used for this ControlNetUnit.
    image: Optional[Any] = None

    resize_mode: ResizeMode = ResizeMode.INNER_FIT
    low_vram: bool = False
    processor_res: int = -1
    threshold_a: float = -1
    threshold_b: float = -1

    @root_validator
    def bound_check_params(cls, values: dict) -> dict:
        """
        Checks and corrects negative parameters in ControlNetUnit 'unit' in place.
        Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
        their default values if negative.
        """
        enabled = values.get("enabled")
        if not enabled:
            # Disabled units may carry arbitrary slider values; skip the check.
            return values

        module = values.get("module")
        if not module:
            # Module failed its own validation; there is no preprocessor to
            # bound-check against.
            return values

        preprocessor = cls.cls_get_preprocessor(module)
        assert preprocessor is not None
        for unit_param, param in zip(
            ("processor_res", "threshold_a", "threshold_b"),
            ("slider_resolution", "slider_1", "slider_2"),
        ):
            value = values.get(unit_param)
            cfg = getattr(preprocessor, param)
            if value < cfg.minimum or value > cfg.maximum:
                values[unit_param] = cfg.value
                # Only report warning when non-default value is used.
                if value != -1:
                    cls.cls_logger.info(
                        f"[{module}.{unit_param}] Invalid value({value}), using default value {cfg.value}."
                    )
        return values

    guidance_start: Annotated[float, Field(ge=0.0, le=1.0)] = 0.0
    guidance_end: Annotated[float, Field(ge=0.0, le=1.0)] = 1.0

    @root_validator
    def guidance_check(cls, values: dict) -> dict:
        """Reject inverted guidance ranges (start must not exceed end)."""
        start = values.get("guidance_start")
        end = values.get("guidance_end")
        if start > end:
            raise ValueError(f"guidance_start({start}) > guidance_end({end})")
        return values

    pixel_perfect: bool = False
    control_mode: ControlMode = ControlMode.BALANCED
    # Whether to crop input image based on A1111 img2img mask. This flag is only used when `inpaint area`
    # in A1111 is set to `Only masked`. In API, this corresponds to `inpaint_full_res = True`.
    inpaint_crop_input_image: bool = True
    # If hires fix is enabled in A1111, how should this ControlNet unit be applied.
    # The value is ignored if the generation is not using hires fix.
    hr_option: HiResFixOption = HiResFixOption.BOTH
    # Whether save the detected map of this unit. Setting this option to False prevents saving the
    # detected map or sending detected map along with generated images via API.
    # Currently the option is only accessible in API calls.
    save_detected_map: bool = True
    # Weight for each layer of ControlNet params.
    # For ControlNet:
    # - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
    # - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
    # For T2IAdapter
    # - SD1.5: 5 weights (4 encoder block + 1 middle block)
    # - SDXL: 4 weights (3 encoder block + 1 middle block)
    # For IPAdapter
    # - SD15: 16 (6 input blocks + 9 output blocks + 1 middle block)
    # - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
    # Note1: Setting advanced weighting will disable `soft_injection`, i.e.
    # It is recommended to set ControlMode = BALANCED when using `advanced_weighting`.
    # Note2: The field `weight` is still used in some places, e.g. reference_only,
    # even advanced_weighting is set.
    advanced_weighting: Optional[List[float]] = None

    # The effective region mask that unit's effect should be restricted to.
    effective_region_mask: Optional[np.ndarray] = None

    @validator("effective_region_mask", pre=True)
    def parse_effective_region_mask(cls, value) -> Optional[np.ndarray]:
        """Accept either a base64-encoded string or an ndarray mask (or None)."""
        if isinstance(value, str):
            return cls.cls_decode_base64(value)
        assert isinstance(value, np.ndarray) or value is None
        return value

    # The weight mode for PuLID.
    # https://github.com/ToTheBeginning/PuLID
    pulid_mode: PuLIDMode = PuLIDMode.FIDELITY

    # ------- API only fields -------
    # The tensor input for ipadapter. When this field is set in the API,
    # the base64 string will be interpreted by torch.load to reconstruct
    # ipadapter preprocessor output.
    # Currently the option is only accessible in API calls.
    ipadapter_input: Optional[List[torch.Tensor]] = None

    @validator("ipadapter_input", pre=True)
    def parse_ipadapter_input(cls, value) -> Optional[List[torch.Tensor]]:
        """Decode base64-encoded tensor(s); a bare string becomes a 1-element list."""
        if value is None:
            return None
        if isinstance(value, str):
            value = [value]
        result = [cls.cls_torch_load_base64(b) for b in value]
        assert result, "input cannot be empty"
        return result

    # The mask to be used on top of the image.
    mask: Optional[Any] = None

    @property
    def accepts_multiple_inputs(self) -> bool:
        """This unit can accept multiple input images."""
        return self.module in (
            "ip-adapter-auto",
            "ip-adapter_clip_sdxl",
            "ip-adapter_clip_sdxl_plus_vith",
            "ip-adapter_clip_sd15",
            "ip-adapter_face_id",
            "ip-adapter_face_id_plus",
            "ip-adapter_pulid",
            "instant_id_face_embedding",
        )

    @property
    def is_animate_diff_batch(self) -> bool:
        # `animatediff_batch` is an attribute set externally (not a declared
        # field), hence the getattr with a False default.
        return getattr(self, "animatediff_batch", False)

    @property
    def uses_clip(self) -> bool:
        """Whether this unit uses clip preprocessor."""
        return any(
            (
                ("ip-adapter" in self.module and "face_id" not in self.module),
                self.module
                in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"),
            )
        )

    @property
    def is_inpaint(self) -> bool:
        """Whether this unit uses an inpaint preprocessor."""
        return "inpaint" in self.module

    def get_actual_preprocessor(self):
        """Resolve the concrete preprocessor, unwrapping 'ip-adapter-auto'."""
        if self.module == "ip-adapter-auto":
            return ControlNetUnit.cls_get_preprocessor(
                self.module
            ).get_preprocessor_by_model(self.model)
        return ControlNetUnit.cls_get_preprocessor(self.module)

    @classmethod
    def parse_image(cls, image) -> np.ndarray:
        """Normalize an image (ndarray / file path / base64 string) to an
        [H, W, 3] ndarray. Greyscale input is replicated across 3 channels."""
        if isinstance(image, np.ndarray):
            np_image = image
        elif isinstance(image, str):
            # Necessary for batch.
            if os.path.exists(image):
                np_image = np.array(Image.open(image)).astype("uint8")
            else:
                np_image = cls.cls_decode_base64(image)
        else:
            raise ValueError(f"Unrecognized image format {image}.")

        # [H, W] => [H, W, 3]
        if np_image.ndim == 2:
            np_image = np.stack([np_image, np_image, np_image], axis=-1)

        assert np_image.ndim == 3
        assert np_image.shape[2] == 3
        return np_image

    @classmethod
    def combine_image_and_mask(
        cls, np_image: np.ndarray, np_mask: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """RGB + Alpha(Optional) => RGBA"""
        # TODO: Change protocol to use 255 as A channel value.
        # Note: mask is by default zeros, as both inpaint and
        # clip mask does extra work on masked area.
        np_mask = (np.zeros_like(np_image) if np_mask is None else np_mask)[:, :, 0:1]
        if np_image.shape[:2] != np_mask.shape[:2]:
            raise ValueError(
                f"image shape ({np_image.shape[:2]}) not aligned with mask shape ({np_mask.shape[:2]})"
            )
        return np.concatenate([np_image, np_mask], axis=2)  # [H, W, 4]

    @classmethod
    def legacy_field_alias(cls, values: dict) -> dict:
        """Map deprecated field names to their current equivalents in place."""
        ext_compat_keys = {
            "guidance": "guidance_end",
            "lowvram": "low_vram",
            "input_image": "image",
        }
        for alias, key in ext_compat_keys.items():
            if alias in values:
                assert key not in values, f"Conflict of field '{alias}' and '{key}'"
                # Copy the aliased *value*. (Previously this assigned the alias
                # name string itself, corrupting the target field.)
                values[key] = values[alias]
                cls.cls_logger.warning(
                    f"Deprecated alias '{alias}' detected. This field will be removed on 2024-06-01. "
                    f"Please use '{key}' instead."
                )
        return values

    @classmethod
    def mask_alias(cls, values: dict) -> dict:
        """
        Field "mask_image" is the alias of field "mask".
        This is for compatibility with SD Forge API.
        """
        mask_image = values.get("mask_image")
        mask = values.get("mask")
        if mask_image is not None:
            if mask is not None:
                raise ValueError("Cannot specify both 'mask' and 'mask_image'!")
            values["mask"] = mask_image
        return values

    def get_input_images_rgba(self) -> Optional[List[np.ndarray]]:
        """
        RGBA images with potentially different size.
        Why we cannot have [B, H, W, C=4] here is that calculation of final
        resolution requires generation target's dimensions.

        Parse image with following formats.
        API
        - image = {"image": base64image, "mask": base64image,}
        - image = [image, mask]
        - image = (image, mask)
        - image = [{"image": ..., "mask": ...}, {"image": ..., "mask": ...}, ...]
        - image = base64image, mask = base64image
        UI:
        - image = {"image": np_image, "mask": np_image,}
        - image = np_image, mask = np_image
        """
        init_image = self.image
        init_mask = self.mask

        if init_image is None:
            assert init_mask is None
            return None

        if isinstance(init_image, (list, tuple)):
            if not init_image:
                raise ValueError(f"{init_image} is not a valid 'image' field value")
            if isinstance(init_image[0], dict):
                # [{"image": ..., "mask": ...}, {"image": ..., "mask": ...}, ...]
                images = init_image
            else:
                assert len(init_image) == 2
                # [image, mask]
                # (image, mask)
                images = [
                    {
                        "image": init_image[0],
                        "mask": init_image[1],
                    }
                ]
        elif isinstance(init_image, dict):
            # {"image": ..., "mask": ...}
            images = [init_image]
        elif isinstance(init_image, (str, np.ndarray)):
            # image = base64image, mask = base64image
            images = [
                {
                    "image": init_image,
                    "mask": init_mask,
                }
            ]
        else:
            raise ValueError(f"Unrecognized image field {init_image}")

        np_images = []
        for image_dict in images:
            assert isinstance(image_dict, dict)
            image = image_dict.get("image")
            mask = image_dict.get("mask")
            assert image is not None
            np_image = self.parse_image(image)
            np_mask = self.parse_image(mask) if mask is not None else None
            np_images.append(self.combine_image_and_mask(np_image, np_mask))  # [H, W, 4]

        return np_images

    @classmethod
    def from_dict(cls, values: dict) -> ControlNetUnit:
        """Construct a unit from a dict, accepting legacy/alias field names."""
        values = copy(values)
        values = cls.legacy_field_alias(values)
        values = cls.mask_alias(values)
        return ControlNetUnit(**values)

    @classmethod
    def from_infotext_args(cls, *args) -> ControlNetUnit:
        """Construct a unit from positional args ordered as `infotext_fields`."""
        assert len(args) == len(ControlNetUnit.infotext_fields())
        return cls.from_dict(
            {k: v for k, v in zip(ControlNetUnit.infotext_fields(), args)}
        )

    @staticmethod
    def infotext_fields() -> Tuple[str, ...]:
        """Fields that should be included in infotext.

        You should define a Gradio element with exact same name in ControlNetUiGroup
        as well, so that infotext can wire the value to correct field when pasting
        infotext.
        """
        return (
            "module",
            "model",
            "weight",
            "resize_mode",
            "processor_res",
            "threshold_a",
            "threshold_b",
            "guidance_start",
            "guidance_end",
            "pixel_perfect",
            "control_mode",
        )

    def serialize(self) -> str:
        """Serialize the unit for infotext."""
        infotext_dict = {
            field_to_displaytext(field): serialize_value(getattr(self, field))
            for field in ControlNetUnit.infotext_fields()
        }
        # ", " and ": " are the infotext delimiters (see `parse`); a value
        # containing either token would make the infotext unparsable.
        if not all(
            "," not in str(v) and ":" not in str(v) for v in infotext_dict.values()
        ):
            self.cls_logger.error(f"Unexpected tokens encountered:\n{infotext_dict}")
            return ""
        return ", ".join(f"{field}: {value}" for field, value in infotext_dict.items())

    @classmethod
    def parse(cls, text: str) -> ControlNetUnit:
        """Inverse of `serialize`: build an (enabled) unit from infotext."""
        return ControlNetUnit(
            enabled=True,
            **{
                displaytext_to_field(key): parse_value(value)
                for item in text.split(",")
                for (key, value) in (item.strip().split(": "),)
            },
        )

View File

@ -1,23 +1,30 @@
import base64
import io
from dataclasses import dataclass
from copy import copy
from typing import List, Any, Optional, Union, Tuple, Dict
import torch
import numpy as np
from modules import scripts, processing, shared
from modules.safe import unsafe_torch_load
from modules.api import api
from .args import ControlNetUnit
from scripts import global_state
from scripts.logging import logger
from scripts.enums import HiResFixOption, PuLIDMode, ControlMode, ResizeMode
from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter
from scripts.enums import (
ResizeMode,
BatchOption, # noqa: F401
ControlMode, # noqa: F401
)
from scripts.supported_preprocessor import (
Preprocessor,
PreprocessorParameter, # noqa: F401
)
from modules.api import api
import torch
import base64
import io
from modules.safe import unsafe_torch_load
def get_api_version() -> int:
return 2
return 3
resize_mode_aliases = {
@ -47,15 +54,6 @@ def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode:
return value
def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode:
if isinstance(value, str):
return ControlMode(value)
elif isinstance(value, int):
return [e for e in ControlMode][value]
else:
return value
def visualize_inpaint_mask(img):
if img.ndim == 3 and img.shape[2] == 4:
result = img.copy()
@ -117,153 +115,7 @@ def pixel_perfect_resolution(
return int(np.round(estimation))
InputImage = Union[np.ndarray, str]
InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputImage]
@dataclass
class ControlNetUnit:
"""
Represents an entire ControlNet processing unit.
"""
enabled: bool = True
module: str = "none"
model: str = "None"
weight: float = 1.0
image: Optional[Union[InputImage, List[InputImage]]] = None
resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
low_vram: bool = False
processor_res: int = -1
threshold_a: float = -1
threshold_b: float = -1
guidance_start: float = 0.0
guidance_end: float = 1.0
pixel_perfect: bool = False
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
# Whether to crop input image based on A1111 img2img mask. This flag is only used when `inpaint area`
# in A1111 is set to `Only masked`. In API, this correspond to `inpaint_full_res = True`.
inpaint_crop_input_image: bool = True
# If hires fix is enabled in A1111, how should this ControlNet unit be applied.
# The value is ignored if the generation is not using hires fix.
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
# Whether save the detected map of this unit. Setting this option to False prevents saving the
# detected map or sending detected map along with generated images via API.
# Currently the option is only accessible in API calls.
save_detected_map: bool = True
# Weight for each layer of ControlNet params.
# For ControlNet:
# - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
# - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
# For T2IAdapter
# - SD1.5: 5 weights (4 encoder block + 1 middle block)
# - SDXL: 4 weights (3 encoder block + 1 middle block)
# For IPAdapter
# - SD15: 16 (6 input blocks + 9 output blocks + 1 middle block)
# - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
# Note1: Setting advanced weighting will disable `soft_injection`, i.e.
# It is recommended to set ControlMode = BALANCED when using `advanced_weighting`.
# Note2: The field `weight` is still used in some places, e.g. reference_only,
# even advanced_weighting is set.
advanced_weighting: Optional[List[float]] = None
# The effective region mask that unit's effect should be restricted to.
effective_region_mask: Optional[np.ndarray] = None
# The weight mode for PuLID.
# https://github.com/ToTheBeginning/PuLID
pulid_mode: PuLIDMode = PuLIDMode.FIDELITY
# The tensor input for ipadapter. When this field is set in the API,
# the base64string will be interpret by torch.load to reconstruct ipadapter
# preprocessor output.
# Currently the option is only accessible in API calls.
ipadapter_input: Optional[List[Any]] = None
def __eq__(self, other):
if not isinstance(other, ControlNetUnit):
return False
return vars(self) == vars(other)
def accepts_multiple_inputs(self) -> bool:
"""This unit can accept multiple input images."""
return self.module in (
"ip-adapter-auto",
"ip-adapter_clip_sdxl",
"ip-adapter_clip_sdxl_plus_vith",
"ip-adapter_clip_sd15",
"ip-adapter_face_id",
"ip-adapter_face_id_plus",
"ip-adapter_pulid",
"instant_id_face_embedding",
)
@staticmethod
def infotext_excluded_fields() -> List[str]:
return [
"image",
"enabled",
# API-only fields.
"advanced_weighting",
"ipadapter_input",
# End of API-only fields.
# Note: "inpaint_crop_image" is img2img inpaint only flag, which does not
# provide much information when restoring the unit.
"inpaint_crop_input_image",
"effective_region_mask",
"pulid_mode",
]
@property
def is_animate_diff_batch(self) -> bool:
return getattr(self, "animatediff_batch", False)
@property
def uses_clip(self) -> bool:
"""Whether this unit uses clip preprocessor."""
return any(
(
("ip-adapter" in self.module and "face_id" not in self.module),
self.module
in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"),
)
)
@property
def is_inpaint(self) -> bool:
return "inpaint" in self.module
def bound_check_params(self) -> None:
"""
Checks and corrects negative parameters in ControlNetUnit 'unit' in place.
Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
their default values if negative.
"""
preprocessor = Preprocessor.get_preprocessor(self.module)
for unit_param, param in zip(
("processor_res", "threshold_a", "threshold_b"),
("slider_resolution", "slider_1", "slider_2"),
):
value = getattr(self, unit_param)
cfg: PreprocessorParameter = getattr(preprocessor, param)
if value < 0:
setattr(self, unit_param, cfg.value)
logger.info(
f"[{self.module}.{unit_param}] Invalid value({value}), using default value {cfg.value}."
)
def get_actual_preprocessor(self) -> Preprocessor:
if self.module == "ip-adapter-auto":
return Preprocessor.get_preprocessor(self.module).get_preprocessor_by_model(
self.model
)
return Preprocessor.get_preprocessor(self.module)
def to_base64_nparray(encoding: str):
def to_base64_nparray(encoding: str) -> np.ndarray:
"""
Convert a base64 image into the image type the extension uses
"""
@ -368,73 +220,14 @@ def get_max_models_num():
return max_models_num
def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit:
def to_processing_unit(unit: Union[Dict, ControlNetUnit]) -> ControlNetUnit:
"""
Convert different types to processing unit.
If `unit` is a dict, alternative keys are supported. See `ext_compat_keys` in implementation for details.
"""
ext_compat_keys = {
"guessmode": "guess_mode",
"guidance": "guidance_end",
"lowvram": "low_vram",
"input_image": "image",
}
if isinstance(unit, dict):
unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()}
return ControlNetUnit.from_dict(unit)
# Handle mask
mask = None
if "mask" in unit:
mask = unit["mask"]
del unit["mask"]
if "mask_image" in unit:
mask = unit["mask_image"]
del unit["mask_image"]
if "image" in unit and not isinstance(unit["image"], dict):
unit["image"] = (
{"image": unit["image"], "mask": mask}
if mask is not None
else unit["image"] if unit["image"] else None
)
# Parse ipadapter_input
if "ipadapter_input" in unit and unit["ipadapter_input"] is not None:
def decode_base64(b: str) -> torch.Tensor:
decoded_bytes = base64.b64decode(b)
return unsafe_torch_load(io.BytesIO(decoded_bytes))
if isinstance(unit["ipadapter_input"], str):
unit["ipadapter_input"] = [unit["ipadapter_input"]]
unit["ipadapter_input"] = [
decode_base64(b) for b in unit["ipadapter_input"]
]
if unit.get("effective_region_mask", None) is not None:
base64img = unit["effective_region_mask"]
assert isinstance(base64img, str)
unit["effective_region_mask"] = to_base64_nparray(base64img)
if "guess_mode" in unit:
logger.warning(
"Guess Mode is removed since 1.1.136. Please use Control Mode instead."
)
for k in unit.keys():
if k not in vars(ControlNetUnit):
logger.warn(f"Received unrecognized key '{k}' in API.")
unit = ControlNetUnit(
**{k: v for k, v in unit.items() if k in vars(ControlNetUnit).keys()}
)
# temporary, check #602
# assert isinstance(unit, ControlNetUnit), f'bad argument to controlnet extension: {unit}\nexpected Union[dict[str, Any], ControlNetUnit]'
assert isinstance(unit, ControlNetUnit)
return unit
@ -621,3 +414,23 @@ def is_cn_script(script: scripts.Script) -> bool:
"""
return script.title().lower() == "controlnet"
# TODO: Add model constraint
ControlNetUnit.cls_match_model = lambda model: True
ControlNetUnit.cls_match_module = (
lambda module: Preprocessor.get_preprocessor(module) is not None
)
ControlNetUnit.cls_get_preprocessor = Preprocessor.get_preprocessor
ControlNetUnit.cls_decode_base64 = to_base64_nparray
def decode_base64(b: str) -> torch.Tensor:
decoded_bytes = base64.b64decode(b)
return unsafe_torch_load(io.BytesIO(decoded_bytes))
ControlNetUnit.cls_torch_load_base64 = decode_base64
ControlNetUnit.cls_logger = logger
logger.debug("ControlNetUnit initialized")

View File

@ -85,7 +85,6 @@
this.attachImageUploadListener();
this.attachImageStateChangeObserver();
this.attachA1111SendInfoObserver();
this.attachPresetDropdownObserver();
}
getTabNavButton() {
@ -303,26 +302,6 @@
});
}
}
attachPresetDropdownObserver() {
const presetDropDown = this.tab.querySelector('.cnet-preset-dropdown');
new MutationObserver((mutationsList) => {
for (const mutation of mutationsList) {
if (mutation.removedNodes.length > 0) {
setTimeout(() => {
this.updateActiveState();
this.updateActiveUnitCount();
this.updateActiveControlType();
}, 1000);
return;
}
}
}).observe(presetDropDown, {
childList: true,
subtree: true,
});
}
}
gradioApp().querySelectorAll('#controlnet').forEach(accordion => {

View File

@ -137,12 +137,12 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
)
unit = ControlNetUnit(
enabled=True,
module=preprocessor.label,
processor_res=controlnet_processor_res,
threshold_a=controlnet_threshold_a,
threshold_b=controlnet_threshold_b,
)
unit.bound_check_params()
tensors = []
images = []

View File

@ -1,10 +1,10 @@
import os
from copy import copy
from typing import Tuple, List
from modules import img2img, processing, shared, script_callbacks
from scripts import external_code
from scripts.enums import InputMode
from scripts.logging import logger
class BatchHijack:
def __init__(self):
@ -194,7 +194,7 @@ def unhijack_function(module, name, new_name):
def get_cn_batches(p: processing.StableDiffusionProcessing) -> Tuple[bool, List[List[str]], str, List[str]]:
units = external_code.get_all_units_in_processing(p)
units = [copy(unit) for unit in units if getattr(unit, 'enabled', False)]
units = [unit.copy() for unit in units if getattr(unit, 'enabled', False)]
any_unit_is_batch = False
output_dir = ''
input_file_names = []
@ -222,6 +222,8 @@ def get_cn_batches(p: processing.StableDiffusionProcessing) -> Tuple[bool, List[
else:
batches[i].append(unit.image)
if any_unit_is_batch:
logger.info(f"Batch enabled ({len(batches)})")
return any_unit_is_batch, batches, output_dir, input_file_names

View File

@ -4,10 +4,9 @@ import os
import logging
from collections import OrderedDict
from copy import copy, deepcopy
from typing import Dict, Optional, Tuple, List, Union
from typing import Dict, Optional, Tuple, List
import modules.scripts as scripts
from modules import shared, devices, script_callbacks, processing, masking, images
from modules.api.api import decode_base64_to_image
import gradio as gr
import time
@ -26,6 +25,7 @@ from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent
from scripts.hook import ControlParams, UnetHook, HackedImageRNG
from scripts.enums import (
ControlModelType,
InputMode,
StableDiffusionVersion,
HiResFixOption,
PuLIDMode,
@ -33,7 +33,7 @@ from scripts.enums import (
BatchOption,
ResizeMode,
)
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
from scripts.controlnet_ui.photopea import Photopea
from scripts.logging import logger
from scripts.supported_preprocessor import Preprocessor
@ -101,44 +101,6 @@ def swap_img2img_pipeline(p: processing.StableDiffusionProcessingImg2Img):
global_state.update_cn_models()
def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
if image is None:
return None
if isinstance(image, (tuple, list)):
image = {'image': image[0], 'mask': image[1]}
elif not isinstance(image, dict):
image = {'image': image, 'mask': None}
else: # type(image) is dict
# copy to enable modifying the dict and prevent response serialization error
image = dict(image)
if isinstance(image['image'], str):
if os.path.exists(image['image']):
image['image'] = np.array(Image.open(image['image'])).astype('uint8')
elif image['image']:
image['image'] = external_code.to_base64_nparray(image['image'])
else:
image['image'] = None
# If there is no image, return image with None image and None mask
if image['image'] is None:
image['mask'] = None
return image
if 'mask' not in image or image['mask'] is None:
image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
elif isinstance(image['mask'], str):
if os.path.exists(image['mask']):
image['mask'] = np.array(Image.open(image['mask']).convert("RGB")).astype('uint8')
elif image['mask']:
image['mask'] = external_code.to_base64_nparray(image['mask'])
else:
image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
return image
def prepare_mask(
mask: Image.Image, p: processing.StableDiffusionProcessing
) -> Image.Image:
@ -242,7 +204,7 @@ def get_control(
h, w, hr_y, hr_x = Script.get_target_dimensions(p)
input_image, resize_mode = Script.choose_input_image(p, unit, idx)
if isinstance(input_image, list):
assert unit.accepts_multiple_inputs() or unit.is_animate_diff_batch
assert unit.accepts_multiple_inputs or unit.is_animate_diff_batch
input_images = input_image
else: # Following operations are only for single input image.
input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode)
@ -355,21 +317,8 @@ class Script(scripts.Script, metaclass=(
def show(self, is_img2img):
return scripts.AlwaysVisible
@staticmethod
def get_default_ui_unit(is_ui=True):
cls = UiControlNetUnit if is_ui else ControlNetUnit
return cls(
enabled=False,
module="none",
model="None"
)
def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str, photopea: Optional[Photopea]) -> Tuple[ControlNetUiGroup, gr.State]:
group = ControlNetUiGroup(
is_img2img,
Script.get_default_ui_unit(),
photopea,
)
group = ControlNetUiGroup(is_img2img, photopea)
return group, group.render(tabname, elem_id_tabname)
def ui_batch_options(self, is_img2img: bool, elem_id_tabname: str):
@ -665,10 +614,31 @@ class Script(scripts.Script, metaclass=(
@staticmethod
def get_enabled_units(p):
def unfold_merged(unit: ControlNetUnit) -> List[ControlNetUnit]:
"""Unfolds a merged unit to multiple units. Keeps the unit merged for
preprocessors that can accept multiple input images.
"""
if unit.input_mode != InputMode.MERGE:
return [unit]
if unit.accepts_multiple_inputs:
unit.input_mode = InputMode.SIMPLE
return [unit]
assert isinstance(unit.image, list)
result = []
for image in unit.image:
u = unit.copy()
u.image = [image]
u.input_mode = InputMode.SIMPLE
u.weight = unit.weight / len(unit.image)
result.append(u)
return result
units = external_code.get_all_units_in_processing(p)
if len(units) == 0:
# fill a null group
remote_unit = Script.parse_remote_call(p, Script.get_default_ui_unit(), 0)
remote_unit = Script.parse_remote_call(p, ControlNetUnit(), 0)
if remote_unit.enabled:
units.append(remote_unit)
@ -677,11 +647,7 @@ class Script(scripts.Script, metaclass=(
local_unit = Script.parse_remote_call(p, unit, idx)
if not local_unit.enabled:
continue
if hasattr(local_unit, "unfold_merged"):
enabled_units.extend(local_unit.unfold_merged())
else:
enabled_units.append(copy(local_unit))
enabled_units.extend(unfold_merged(local_unit))
Infotext.write_infotext(enabled_units, p)
return enabled_units
@ -695,43 +661,31 @@ class Script(scripts.Script, metaclass=(
""" Choose input image from following sources with descending priority:
- p.image_control: [Deprecated] Legacy way to pass image to controlnet.
- p.control_net_input_image: [Deprecated] Legacy way to pass image to controlnet.
- unit.image: ControlNet tab input image.
- p.init_images: A1111 img2img tab input image.
- unit.image: ControlNet unit input image.
- p.init_images: A1111 img2img input image.
Returns:
- The input image in ndarray form.
- The resize mode.
"""
def parse_unit_image(unit: ControlNetUnit) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
unit_has_multiple_images = (
isinstance(unit.image, list) and
len(unit.image) > 0 and
"image" in unit.image[0]
)
if unit_has_multiple_images:
return [
d
for img in unit.image
for d in (image_dict_from_any(img),)
if d is not None
]
return image_dict_from_any(unit.image)
def decode_image(img) -> np.ndarray:
"""Need to check the image for API compatibility."""
if isinstance(img, str):
return np.asarray(decode_base64_to_image(image['image']))
else:
assert isinstance(img, np.ndarray)
return img
def from_rgba_to_input(img: np.ndarray) -> np.ndarray:
    """Convert an RGBA canvas image into the preprocessor input.

    If the alpha channel is effectively blank (all ~0 or all ~255), or
    the user opted to ignore non-inpaint masks, the RGB channels are
    used. Otherwise the alpha channel is treated as a scribble drawn on
    the canvas and becomes the input.
    """
    alpha = img[:, :, 3]
    ignore_mask = shared.opts.data.get("controlnet_ignore_noninpaint_mask", False)
    alpha_is_blank = (alpha <= 5).all() or (alpha >= 250).all()
    if ignore_mask or alpha_is_blank:
        # Take RGB
        return img[:, :, :3]
    logger.info("Canvas scribble mode. Using mask scribble as input.")
    return HWC3(alpha)
# 4 input image sources.
p_image_control = getattr(p, "image_control", None)
p_input_image = Script.get_remote_call(p, "control_net_input_image", None, idx)
image = parse_unit_image(unit)
image = unit.get_input_images_rgba()
a1111_image = getattr(p, "init_images", [None])[0]
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
resize_mode = unit.resize_mode
if batch_hijack.instance.is_batch and p_image_control is not None:
logger.warning("Warn: Using legacy field 'p.image_control'.")
@ -744,42 +698,18 @@ class Script(scripts.Script, metaclass=(
input_image = np.concatenate([color, alpha], axis=2)
else:
input_image = HWC3(np.asarray(p_input_image))
elif image:
if isinstance(image, list):
# Add mask logic if later there is a processor that accepts mask
# on multiple inputs.
input_image = [HWC3(decode_image(img['image'])) for img in image]
if unit.is_animate_diff_batch and len(image) > 0 and 'mask' in image[0] and image[0]['mask'] is not None:
for idx in range(len(input_image)):
while len(image[idx]['mask'].shape) < 3:
image[idx]['mask'] = image[idx]['mask'][..., np.newaxis]
if unit.is_inpaint or unit.uses_clip:
color = HWC3(image[idx]["image"])
alpha = image[idx]['mask'][:, :, 0:1]
input_image[idx] = np.concatenate([color, alpha], axis=2)
elif image is not None:
assert isinstance(image, list)
# Inpaint mask or CLIP mask.
if unit.is_inpaint or unit.uses_clip:
# RGBA
input_image = image
else:
input_image = HWC3(decode_image(image['image']))
if 'mask' in image and image['mask'] is not None:
while len(image['mask'].shape) < 3:
image['mask'] = image['mask'][..., np.newaxis]
if unit.is_inpaint or unit.uses_clip:
logger.info("using mask")
color = HWC3(image['image'])
alpha = image['mask'][:, :, 0:1]
input_image = np.concatenate([color, alpha], axis=2)
elif (
not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and
# There is a weird gradio issue that would produce a mask that is
# not pure color when no scribble is made on canvas.
# See https://github.com/Mikubill/sd-webui-controlnet/issues/1638.
not (
(image['mask'][:, :, 0] <= 5).all() or
(image['mask'][:, :, 0] >= 250).all()
)
):
logger.info("using mask as input")
input_image = HWC3(image['mask'][:, :, 0])
unit.module = 'none' # Always use black bg and white line
# RGB
input_image = [from_rgba_to_input(img) for img in image]
if len(input_image) == 1:
input_image = input_image[0]
elif a1111_image is not None:
input_image = HWC3(np.asarray(a1111_image))
a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
@ -957,7 +887,6 @@ class Script(scripts.Script, metaclass=(
high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)
for idx, unit in enumerate(self.enabled_units):
unit.bound_check_params()
Script.check_sd_version_compatible(unit)
if (
'inpaint_only' == unit.module and
@ -1011,7 +940,7 @@ class Script(scripts.Script, metaclass=(
elif unit.is_animate_diff_batch or control_model_type in [ControlModelType.SparseCtrl]:
cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None)
def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx):
if unit.accepts_multiple_inputs():
if unit.accepts_multiple_inputs:
ip_adapter_image_emb_cond = []
model_net.ipadapter.image_proj_model.to(torch.float32) # noqa
for c in cc:
@ -1038,7 +967,7 @@ class Script(scripts.Script, metaclass=(
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):
logger.info(f"\t{frame_idx}: {frame_path}")
c = SparseCtrl.create_cond_mask(cn_ad_keyframe_idx, c, p.batch_size).cpu()
elif unit.accepts_multiple_inputs():
elif unit.accepts_multiple_inputs:
# ip-adapter should do prompt travel
logger.info("IP-Adapter: control prompts will be traveled in the following way:")
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):
@ -1067,7 +996,7 @@ class Script(scripts.Script, metaclass=(
c_full[cn_ad_keyframe_idx] = c
c = c_full
# handle batch condition and unconditional
if shared.opts.batch_cond_uncond and not unit.accepts_multiple_inputs():
if shared.opts.batch_cond_uncond and not unit.accepts_multiple_inputs:
c = torch.cat([c, c], dim=0)
return c
@ -1090,7 +1019,6 @@ class Script(scripts.Script, metaclass=(
control_model_type.is_controlnet and
model_net.control_model.global_average_pooling
)
control_mode = external_code.control_mode_from_value(unit.control_mode)
forward_param = ControlParams(
control_model=model_net,
preprocessor=preprocessor_dict,
@ -1103,9 +1031,9 @@ class Script(scripts.Script, metaclass=(
control_model_type=control_model_type,
global_average_pooling=global_average_pooling,
hr_hint_cond=hr_control,
hr_option=HiResFixOption.from_value(unit.hr_option) if high_res_fix else HiResFixOption.BOTH,
soft_injection=control_mode != ControlMode.BALANCED,
cfg_injection=control_mode == ControlMode.CONTROL,
hr_option=unit.hr_option if high_res_fix else HiResFixOption.BOTH,
soft_injection=unit.control_mode != ControlMode.BALANCED,
cfg_injection=unit.control_mode == ControlMode.CONTROL,
effective_region_mask=(
get_pytorch_control(unit.effective_region_mask)[:, 0:1, :, :]
if unit.effective_region_mask is not None
@ -1217,16 +1145,10 @@ class Script(scripts.Script, metaclass=(
weight = param.weight
h, w, hr_y, hr_x = Script.get_target_dimensions(p)
# TODO: Fix all enum issue
if unit.pulid_mode == "PuLIDMode.FIDELITY":
pulid_mode = PuLIDMode.FIDELITY
else:
pulid_mode = PuLIDMode(unit.pulid_mode)
if pulid_mode == PuLIDMode.STYLE:
if unit.pulid_mode == PuLIDMode.STYLE:
pulid_attn_setting = PULID_SETTING_STYLE
else:
assert pulid_mode == PuLIDMode.FIDELITY
assert unit.pulid_mode == PuLIDMode.FIDELITY
pulid_attn_setting = PULID_SETTING_FIDELITY
param.control_model.hook(
@ -1377,7 +1299,7 @@ class Script(scripts.Script, metaclass=(
unit.batch_images = iter([batch[unit_i] for batch in batches])
def batch_tab_process_each(self, p, *args, **kwargs):
for unit_i, unit in enumerate(self.enabled_units):
for unit in self.enabled_units:
if getattr(unit, 'loopback', False) and batch_hijack.instance.batch_index > 0:
continue

View File

@ -1,8 +1,8 @@
import json
import gradio as gr
import functools
from copy import copy
from typing import List, Optional, Union, Dict, Tuple, Literal
import itertools
from typing import List, Optional, Union, Dict, Tuple, Literal, Any
from dataclasses import dataclass
import numpy as np
@ -16,7 +16,6 @@ from annotator.util import HWC3
from internal_controlnet.external_code import ControlNetUnit
from scripts.logging import logger
from scripts.controlnet_ui.openpose_editor import OpenposeEditor
from scripts.controlnet_ui.preset import ControlNetPresetUI
from scripts.controlnet_ui.photopea import Photopea
from scripts.controlnet_ui.advanced_weight_control import AdvancedWeightControl
from scripts.enums import (
@ -128,66 +127,39 @@ class A1111Context:
)
class UiControlNetUnit(ControlNetUnit):
"""The data class that stores all states of a ControlNetUnit."""
def create_ui_unit(
input_mode: InputMode = InputMode.SIMPLE,
batch_images: Optional[Any] = None,
output_dir: str = "",
loopback: bool = False,
merge_gallery_files: List[Dict[Union[Literal["name"], Literal["data"]], str]] = [],
use_preview_as_input: bool = False,
generated_image: Optional[np.ndarray] = None,
*args,
) -> ControlNetUnit:
unit_dict = {
k: v
for k, v in zip(
vars(ControlNetUnit()).keys(),
itertools.chain(
[True, input_mode, batch_images, output_dir, loopback], args
),
)
}
def __init__(
self,
input_mode: InputMode = InputMode.SIMPLE,
batch_images: Optional[Union[str, List[external_code.InputImage]]] = None,
output_dir: str = "",
loopback: bool = False,
merge_gallery_files: List[
Dict[Union[Literal["name"], Literal["data"]], str]
] = [],
use_preview_as_input: bool = False,
generated_image: Optional[np.ndarray] = None,
enabled: bool = True,
module: Optional[str] = None,
model: Optional[str] = None,
weight: float = 1.0,
image: Optional[Dict[str, np.ndarray]] = None,
*args,
**kwargs,
):
if use_preview_as_input and generated_image is not None:
input_image = generated_image
module = "none"
else:
input_image = image
if use_preview_as_input and generated_image is not None:
input_image = generated_image
unit_dict["module"] = "none"
else:
input_image = unit_dict["image"]
if merge_gallery_files and input_mode == InputMode.MERGE:
input_image = [
{"image": read_image(file["name"])} for file in merge_gallery_files
]
if merge_gallery_files and input_mode == InputMode.MERGE:
input_image = [
{"image": read_image(file["name"])} for file in merge_gallery_files
]
super().__init__(enabled, module, model, weight, input_image, *args, **kwargs)
self.is_ui = True
self.input_mode = input_mode
self.batch_images = batch_images
self.output_dir = output_dir
self.loopback = loopback
def unfold_merged(self) -> List[ControlNetUnit]:
"""Unfolds a merged unit to multiple units. Keeps the unit merged for
preprocessors that can accept multiple input images.
"""
if self.input_mode != InputMode.MERGE:
return [copy(self)]
if self.accepts_multiple_inputs():
self.input_mode = InputMode.SIMPLE
return [copy(self)]
assert isinstance(self.image, list)
result = []
for image in self.image:
unit = copy(self)
unit.image = image["image"]
unit.input_mode = InputMode.SIMPLE
unit.weight = self.weight / len(self.image)
result.append(unit)
return result
unit_dict["image"] = input_image
return ControlNetUnit.from_dict(unit_dict)
class ControlNetUiGroup(object):
@ -221,7 +193,6 @@ class ControlNetUiGroup(object):
def __init__(
self,
is_img2img: bool,
default_unit: ControlNetUnit,
photopea: Optional[Photopea],
):
# Whether callbacks have been registered.
@ -230,13 +201,13 @@ class ControlNetUiGroup(object):
self.ui_initialized: bool = False
self.is_img2img = is_img2img
self.default_unit = default_unit
self.default_unit = ControlNetUnit()
self.photopea = photopea
self.webcam_enabled = False
self.webcam_mirrored = False
# Note: All gradio elements declared in `render` will be defined as member variable.
# Update counter to trigger a force update of UiControlNetUnit.
# Update counter to trigger a force update of ControlNetUnit.
# This is useful when a field with no event subscriber available changes.
# e.g. gr.Gallery, gr.State, etc.
self.update_unit_counter = None
@ -283,7 +254,6 @@ class ControlNetUiGroup(object):
self.loopback = None
self.use_preview_as_input = None
self.openpose_editor = None
self.preset_panel = None
self.upload_independent_img_in_img2img = None
self.image_upload_panel = None
self.save_detected_map = None
@ -330,11 +300,13 @@ class ControlNetUiGroup(object):
tool="sketch",
elem_id=f"{elem_id_tabname}_{tabname}_input_image",
elem_classes=["cnet-image"],
brush_color=shared.opts.img2img_inpaint_mask_brush_color
if hasattr(
shared.opts, "img2img_inpaint_mask_brush_color"
)
else None,
brush_color=(
shared.opts.img2img_inpaint_mask_brush_color
if hasattr(
shared.opts, "img2img_inpaint_mask_brush_color"
)
else None
),
)
self.image.preprocess = functools.partial(
svg_preprocess, preprocess=self.image.preprocess
@ -515,7 +487,11 @@ class ControlNetUiGroup(object):
)
with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
self.type_filter = (gr.Dropdown if shared.opts.data.get("controlnet_control_type_dropdown", False) else gr.Radio)(
self.type_filter = (
gr.Dropdown
if shared.opts.data.get("controlnet_control_type_dropdown", False)
else gr.Radio
)(
Preprocessor.get_all_preprocessor_tags(),
label="Control Type",
value="All",
@ -645,7 +621,7 @@ class ControlNetUiGroup(object):
self.loopback = gr.Checkbox(
label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation",
value=self.default_unit.loopback,
value=False,
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_automatically_send_generated_images_checkbox",
elem_classes="controlnet_loopback_checkbox",
visible=False,
@ -653,10 +629,6 @@ class ControlNetUiGroup(object):
self.advanced_weight_control.render()
self.preset_panel = ControlNetPresetUI(
id_prefix=f"{elem_id_tabname}_{tabname}_"
)
self.batch_image_dir_state = gr.State("")
self.output_dir_state = gr.State("")
unit_args = (
@ -693,32 +665,13 @@ class ControlNetUiGroup(object):
self.pulid_mode,
)
unit = gr.State(self.default_unit)
for comp in unit_args + (self.update_unit_counter,):
event_subscribers = []
if hasattr(comp, "edit"):
event_subscribers.append(comp.edit)
elif hasattr(comp, "click"):
event_subscribers.append(comp.click)
elif isinstance(comp, gr.Slider) and hasattr(comp, "release"):
event_subscribers.append(comp.release)
elif hasattr(comp, "change"):
event_subscribers.append(comp.change)
if hasattr(comp, "clear"):
event_subscribers.append(comp.clear)
for event_subscriber in event_subscribers:
event_subscriber(
fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit
)
unit = gr.State(ControlNetUnit())
(
ControlNetUiGroup.a1111_context.img2img_submit_button
if self.is_img2img
else ControlNetUiGroup.a1111_context.txt2img_submit_button
).click(
fn=UiControlNetUnit,
fn=create_ui_unit,
inputs=list(unit_args),
outputs=unit,
queue=False,
@ -803,10 +756,12 @@ class ControlNetUiGroup(object):
def register_build_sliders(self):
def build_sliders(module: str, pp: bool):
preprocessor = Preprocessor.get_preprocessor(module)
slider_resolution_kwargs = preprocessor.slider_resolution.gradio_update_kwargs.copy()
slider_resolution_kwargs = (
preprocessor.slider_resolution.gradio_update_kwargs.copy()
)
if pp:
slider_resolution_kwargs['visible'] = False
slider_resolution_kwargs["visible"] = False
grs = [
gr.update(**slider_resolution_kwargs),
@ -852,9 +807,7 @@ class ControlNetUiGroup(object):
gr.Dropdown.update(
value=default_option, choices=filtered_preprocessor_list
),
gr.Dropdown.update(
value=default_model, choices=filtered_model_list
),
gr.Dropdown.update(value=default_model, choices=filtered_model_list),
]
self.type_filter.change(
@ -893,7 +846,9 @@ class ControlNetUiGroup(object):
)
def register_run_annotator(self):
def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm, model: str):
def run_annotator(
image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm, model: str
):
if image is None:
return (
gr.update(value=None, visible=True),
@ -956,16 +911,16 @@ class ControlNetUiGroup(object):
and shared.opts.data.get("controlnet_clip_detector_on_cpu", False)
),
json_pose_callback=(
json_acceptor.accept
if is_openpose(module)
else None
json_acceptor.accept if is_openpose(module) else None
),
model=model,
)
return (
# Update to `generated_image`
gr.update(value=result.display_images[0], visible=True, interactive=False),
gr.update(
value=result.display_images[0], visible=True, interactive=False
),
# preprocessor_preview
gr.update(value=True),
# openpose editor
@ -980,12 +935,16 @@ class ControlNetUiGroup(object):
self.processor_res,
self.threshold_a,
self.threshold_b,
ControlNetUiGroup.a1111_context.img2img_w_slider
if self.is_img2img
else ControlNetUiGroup.a1111_context.txt2img_w_slider,
ControlNetUiGroup.a1111_context.img2img_h_slider
if self.is_img2img
else ControlNetUiGroup.a1111_context.txt2img_h_slider,
(
ControlNetUiGroup.a1111_context.img2img_w_slider
if self.is_img2img
else ControlNetUiGroup.a1111_context.txt2img_w_slider
),
(
ControlNetUiGroup.a1111_context.img2img_h_slider
if self.is_img2img
else ControlNetUiGroup.a1111_context.txt2img_h_slider
),
self.pixel_perfect,
self.resize_mode,
self.model,
@ -1256,14 +1215,6 @@ class ControlNetUiGroup(object):
self.model,
)
assert self.type_filter is not None
self.preset_panel.register_callbacks(
self,
self.type_filter,
*[
getattr(self, key)
for key in vars(ControlNetUnit()).keys()
],
)
self.advanced_weight_control.register_callbacks(
self.weight,
self.advanced_weighting,

View File

@ -1,305 +0,0 @@
import os
import gradio as gr
from typing import Dict, List
from modules import scripts
from modules.ui_components import ToolButton
from internal_controlnet.external_code import ControlNetUnit
from scripts.infotext import parse_unit, serialize_unit
from scripts.logging import logger
from scripts.supported_preprocessor import Preprocessor
save_symbol = "\U0001f4be" # 💾
delete_symbol = "\U0001f5d1\ufe0f" # 🗑️
refresh_symbol = "\U0001f504" # 🔄
reset_symbol = "\U000021A9" # ↩
NEW_PRESET = "New Preset"
def load_presets(preset_dir: str) -> Dict[str, str]:
    """Load all preset infotexts from `preset_dir`.

    Each `<name>.txt` file maps to an entry name -> file content. The
    reserved name `NEW_PRESET` is skipped. Creates the directory (and
    returns an empty mapping) on first run.
    """
    if not os.path.exists(preset_dir):
        os.makedirs(preset_dir)
        return {}

    loaded: Dict[str, str] = {}
    for filename in os.listdir(preset_dir):
        if not filename.endswith(".txt"):
            continue
        name = filename.replace(".txt", "")
        if name == NEW_PRESET:
            continue
        with open(os.path.join(preset_dir, filename), "r") as f:
            loaded[name] = f.read()
    return loaded
def infer_control_type(module: str, model: str) -> str:
    """Infer the UI control type tag from a (module, model) pair.

    A tag matches when any of its filter keywords appears in the model
    name. Exactly one tag must match; otherwise the pairing is
    ambiguous and a ValueError is raised.
    """
    preprocessor = Preprocessor.get_preprocessor(module)
    assert preprocessor is not None
    model_lower = model.lower()
    matched_tags = []
    for tag in preprocessor.tags:
        filters = Preprocessor.tag_to_filters(tag)
        if any(f in model_lower for f in filters):
            matched_tags.append(tag)
    if len(matched_tags) != 1:
        raise ValueError(
            f"Unable to infer control type from module {module} and model {model}"
        )
    return matched_tags[0]
class ControlNetPresetUI(object):
    """UI panel that saves/loads ControlNet unit settings as named presets.

    Presets are stored as infotext files (`<name>.txt`) under
    `preset_directory` and cached in the class-level `presets` dict
    (preset name -> infotext string).
    """

    # Class-level storage shared by all unit tabs.
    preset_directory = os.path.join(scripts.basedir(), "presets")
    presets = load_presets(preset_directory)

    def __init__(self, id_prefix: str):
        """Build the preset row (dropdown + tool buttons) and the hidden
        name-entry dialog. `id_prefix` namespaces element ids per tab."""
        with gr.Row():
            self.dropdown = gr.Dropdown(
                label="Presets",
                show_label=True,
                elem_classes=["cnet-preset-dropdown"],
                choices=ControlNetPresetUI.dropdown_choices(),
                value=NEW_PRESET,
            )
            self.reset_button = ToolButton(
                value=reset_symbol,
                elem_classes=["cnet-preset-reset"],
                tooltip="Reset preset",
                visible=False,
            )
            self.save_button = ToolButton(
                value=save_symbol,
                elem_classes=["cnet-preset-save"],
                tooltip="Save preset",
            )
            self.delete_button = ToolButton(
                value=delete_symbol,
                elem_classes=["cnet-preset-delete"],
                tooltip="Delete preset",
            )
            self.refresh_button = ToolButton(
                value=refresh_symbol,
                elem_classes=["cnet-preset-refresh"],
                tooltip="Refresh preset",
            )

        # Popup dialog shown when saving under a new name.
        with gr.Box(
            elem_classes=["popup-dialog", "cnet-preset-enter-name"],
            elem_id=f"{id_prefix}_cnet_preset_enter_name",
        ) as self.name_dialog:
            with gr.Row():
                self.preset_name = gr.Textbox(
                    label="Preset name",
                    show_label=True,
                    lines=1,
                    elem_classes=["cnet-preset-name"],
                )
                self.confirm_preset_name = ToolButton(
                    value=save_symbol,
                    elem_classes=["cnet-preset-confirm-name"],
                    tooltip="Save preset",
                )

    def register_callbacks(
        self,
        uigroup,
        control_type: gr.Radio,
        *ui_states,
    ):
        """Wire gradio events for applying/saving/deleting/refreshing presets.

        Args:
            uigroup: The owning ControlNetUiGroup.
            control_type: Control type radio/dropdown updated on preset load.
            *ui_states: IOComponents mirroring ControlNetUnit fields, in
                `vars(ControlNetUnit())` order.
        """

        def apply_preset(name: str, control_type: str, *ui_states):
            # Selecting the "New Preset" sentinel is a no-op: hide the
            # delete button and skip every other output.
            if name == NEW_PRESET:
                return (
                    gr.update(visible=False),
                    *(
                        (gr.skip(),)
                        * (len(vars(ControlNetUnit()).keys()) + 1)
                    ),
                )

            assert name in ControlNetPresetUI.presets
            infotext = ControlNetPresetUI.presets[name]
            preset_unit = parse_unit(infotext)
            current_unit = ControlNetUnit(*ui_states)
            # Images are never part of a preset; exclude from comparison.
            preset_unit.image = None
            current_unit.image = None

            # Do not compare module params that are not used in preset
            # (-1 marks a hidden/unused slider value).
            for module_param in ("processor_res", "threshold_a", "threshold_b"):
                if getattr(preset_unit, module_param) == -1:
                    setattr(current_unit, module_param, -1)

            # No update necessary.
            if vars(current_unit) == vars(preset_unit):
                return (
                    gr.update(visible=False),
                    *(
                        (gr.skip(),)
                        * (len(vars(ControlNetUnit()).keys()) + 1)
                    ),
                )

            unit = preset_unit

            try:
                new_control_type = infer_control_type(unit.module, unit.model)
            except ValueError as e:
                logger.error(e)
                new_control_type = control_type

            return (
                gr.update(visible=True),
                gr.update(value=new_control_type),
                *[
                    gr.update(value=value) if value is not None else gr.update()
                    for value in vars(unit).values()
                ],
            )

        # Apply on dropdown selection and on explicit reset; afterwards
        # hide the reset button (state now matches the preset).
        for element, action in (
            (self.dropdown, "change"),
            (self.reset_button, "click"),
        ):
            getattr(element, action)(
                fn=apply_preset,
                inputs=[self.dropdown, control_type, *ui_states],
                outputs=[self.delete_button, control_type, *ui_states],
                show_progress="hidden",
            ).then(
                fn=lambda: gr.update(visible=False),
                inputs=None,
                outputs=[self.reset_button],
            )

        def save_preset(name: str, *ui_states):
            # Saving while "New Preset" is selected opens the name
            # dialog instead of overwriting anything.
            if name == NEW_PRESET:
                return gr.update(visible=True), gr.update(), gr.update()

            ControlNetPresetUI.save_preset(
                name, ControlNetUnit(*ui_states)
            )
            return (
                gr.update(),  # name dialog
                gr.update(choices=ControlNetPresetUI.dropdown_choices(), value=name),
                gr.update(visible=False),  # Reset button
            )

        self.save_button.click(
            fn=save_preset,
            inputs=[self.dropdown, *ui_states],
            outputs=[self.name_dialog, self.dropdown, self.reset_button],
            show_progress="hidden",
        ).then(
            fn=None,
            _js=f"""
            (name) => {{
                if (name === "{NEW_PRESET}")
                    popup(gradioApp().getElementById('{self.name_dialog.elem_id}'));
            }}""",
            inputs=[self.dropdown],
        )

        def delete_preset(name: str):
            ControlNetPresetUI.delete_preset(name)
            return gr.Dropdown.update(
                choices=ControlNetPresetUI.dropdown_choices(),
                value=NEW_PRESET,
            ), gr.update(visible=False)

        self.delete_button.click(
            fn=delete_preset,
            inputs=[self.dropdown],
            outputs=[self.dropdown, self.reset_button],
            show_progress="hidden",
        )

        # Dialog starts hidden; it is shown via the JS popup above.
        self.name_dialog.visible = False

        def save_new_preset(new_name: str, *ui_states):
            if new_name == NEW_PRESET:
                logger.warn(f"Cannot save preset with reserved name '{NEW_PRESET}'")
                return gr.update(visible=False), gr.update()

            ControlNetPresetUI.save_preset(
                new_name, ControlNetUnit(*ui_states)
            )
            return gr.update(visible=False), gr.update(
                choices=ControlNetPresetUI.dropdown_choices(), value=new_name
            )

        self.confirm_preset_name.click(
            fn=save_new_preset,
            inputs=[self.preset_name, *ui_states],
            outputs=[self.name_dialog, self.dropdown],
            show_progress="hidden",
        ).then(fn=None, _js="closePopup")

        self.refresh_button.click(
            fn=ControlNetPresetUI.refresh_preset,
            inputs=None,
            outputs=[self.dropdown],
            show_progress="hidden",
        )

        def update_reset_button(preset_name: str, *ui_states):
            # Show the reset button only when the live UI state has
            # diverged from the selected preset.
            if preset_name == NEW_PRESET:
                return gr.update(visible=False)

            infotext = ControlNetPresetUI.presets[preset_name]
            preset_unit = parse_unit(infotext)
            current_unit = ControlNetUnit(*ui_states)
            preset_unit.image = None
            current_unit.image = None

            # Do not compare module params that are not used in preset.
            for module_param in ("processor_res", "threshold_a", "threshold_b"):
                if getattr(preset_unit, module_param) == -1:
                    setattr(current_unit, module_param, -1)

            return gr.update(visible=vars(current_unit) != vars(preset_unit))

        # Subscribe to every state-changing event of every component so
        # divergence from the preset is detected immediately.
        for ui_state in ui_states:
            if isinstance(ui_state, gr.Image):
                continue
            for action in ("edit", "click", "change", "clear", "release"):
                # "release" only exists meaningfully on sliders.
                if action == "release" and not isinstance(ui_state, gr.Slider):
                    continue
                if hasattr(ui_state, action):
                    getattr(ui_state, action)(
                        fn=update_reset_button,
                        inputs=[self.dropdown, *ui_states],
                        outputs=[self.reset_button],
                    )

    @staticmethod
    def dropdown_choices() -> List[str]:
        """All preset names plus the trailing "New Preset" sentinel."""
        return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET]

    @staticmethod
    def save_preset(name: str, unit: ControlNetUnit):
        """Serialize `unit` to disk and update the in-memory cache."""
        infotext = serialize_unit(unit)
        with open(
            os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w"
        ) as f:
            f.write(infotext)
        ControlNetPresetUI.presets[name] = infotext

    @staticmethod
    def delete_preset(name: str):
        """Remove a preset from the cache and delete its backing file."""
        if name not in ControlNetPresetUI.presets:
            return

        del ControlNetPresetUI.presets[name]

        file = os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt")
        if os.path.exists(file):
            os.unlink(file)

    @staticmethod
    def refresh_preset():
        """Reload presets from disk and refresh the dropdown choices."""
        ControlNetPresetUI.presets = load_presets(ControlNetPresetUI.preset_directory)
        return gr.update(choices=ControlNetPresetUI.dropdown_choices())

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, List, NamedTuple
from typing import List, NamedTuple
from functools import lru_cache
@ -224,19 +224,6 @@ class HiResFixOption(Enum):
LOW_RES_ONLY = "Low res only"
HIGH_RES_ONLY = "High res only"
@staticmethod
def from_value(value: Any) -> "HiResFixOption":
    """Coerce `value` (enum repr string, plain value string, index, or
    enum member) into a HiResFixOption."""
    if isinstance(value, str) and value.startswith("HiResFixOption."):
        # e.g. "HiResFixOption.BOTH" -> member BOTH.
        _, field = value.split(".")
        return getattr(HiResFixOption, field)
    if isinstance(value, str):
        # Plain value string, e.g. "Both".
        return HiResFixOption(value)
    elif isinstance(value, int):
        # Positional index into member declaration order.
        return [x for x in HiResFixOption][value]
    else:
        assert isinstance(value, HiResFixOption)
        return value
class InputMode(Enum):
# Single image to a single ControlNet unit.

View File

@ -1,5 +1,5 @@
from typing import List, Tuple, Union
from typing import List, Tuple
from enum import Enum
import gradio as gr
from modules.processing import StableDiffusionProcessing
@ -8,53 +8,6 @@ from internal_controlnet.external_code import ControlNetUnit
from scripts.logging import logger
def field_to_displaytext(fieldname: str) -> str:
    """Turn a snake_case field name into display text,
    e.g. "control_mode" -> "Control Mode"."""
    words = fieldname.split("_")
    return " ".join(word.capitalize() for word in words)
def displaytext_to_field(text: str) -> str:
    """Inverse of `field_to_displaytext`,
    e.g. "Control Mode" -> "control_mode"."""
    words = text.split(" ")
    return "_".join(word.lower() for word in words)
def parse_value(value: str) -> Union[str, float, int, bool]:
    """Parse an infotext value string into bool, int, float, or str.

    Tried in that order; anything unparseable stays a plain string.
    """
    if value == "True":
        return True
    if value == "False":
        return False
    for converter in (int, float):
        try:
            return converter(value)
        except ValueError:
            pass
    return value  # Plain string.
def serialize_unit(unit: ControlNetUnit) -> str:
    """Serialize a ControlNetUnit into a "Field: value, ..." infotext string.

    Skips excluded fields and hidden slider values (-1). Returns "" if
    any value contains "," or ":" (which would corrupt the format).
    """
    excluded_fields = ControlNetUnit.infotext_excluded_fields()
    log_value = {}
    for field in vars(ControlNetUnit()).keys():
        if field in excluded_fields:
            continue
        value = getattr(unit, field)
        # Note: exclude hidden slider values.
        if value == -1:
            continue
        log_value[field_to_displaytext(field)] = value

    if any("," in str(v) or ":" in str(v) for v in log_value.values()):
        logger.error(f"Unexpected tokens encountered:\n{log_value}")
        return ""

    return ", ".join(f"{field}: {value}" for field, value in log_value.items())
def parse_unit(text: str) -> ControlNetUnit:
    """Parse a "Field: value, ..." infotext string into an enabled unit."""
    kwargs = {}
    for item in text.split(","):
        key, value = item.strip().split(": ")
        kwargs[displaytext_to_field(key)] = parse_value(value)
    return ControlNetUnit(enabled=True, **kwargs)
class Infotext(object):
def __init__(self) -> None:
self.infotext_fields: List[Tuple[gr.components.IOComponent, str]] = []
@ -74,11 +27,7 @@ class Infotext(object):
iocomponents.
"""
unit_prefix = Infotext.unit_prefix(unit_index)
for field in vars(ControlNetUnit()).keys():
# Exclude image for infotext.
if field == "image":
continue
for field in ControlNetUnit.infotext_fields():
# Every field in ControlNetUnit should have a corresponding
# IOComponent in ControlNetUiGroup.
io_component = getattr(uigroup, field)
@ -87,13 +36,11 @@ class Infotext(object):
self.paste_field_names.append(component_locator)
@staticmethod
def write_infotext(
units: List[ControlNetUnit], p: StableDiffusionProcessing
):
def write_infotext(units: List[ControlNetUnit], p: StableDiffusionProcessing):
"""Write infotext to `p`."""
p.extra_generation_params.update(
{
Infotext.unit_prefix(i): serialize_unit(unit)
Infotext.unit_prefix(i): unit.serialize()
for i, unit in enumerate(units)
if unit.enabled
}
@ -109,14 +56,19 @@ class Infotext(object):
assert isinstance(v, str), f"Expect string but got {v}."
try:
for field, value in vars(parse_unit(v)).items():
if field == "image":
for field, value in vars(ControlNetUnit.parse(v)).items():
if field not in ControlNetUnit.infotext_fields():
continue
if value is None:
logger.debug(f"InfoText: Skipping {field} because value is None.")
logger.debug(
f"InfoText: Skipping {field} because value is None."
)
continue
component_locator = f"{k} {field}"
if isinstance(value, Enum):
value = value.value
updates[component_locator] = value
logger.debug(f"InfoText: Setting {component_locator} = {value}")
except Exception as e:

View File

@ -1,3 +1,4 @@
import numpy as np
import unittest.mock
import importlib
from typing import Any
@ -7,13 +8,17 @@ utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'u
from modules import processing, scripts, shared
from internal_controlnet.external_code import ControlNetUnit
from scripts import controlnet, external_code, batch_hijack
from scripts import controlnet, batch_hijack
batch_hijack.instance.undo_hijack()
original_process_images_inner = processing.process_images_inner
def create_unit(**kwargs) -> ControlNetUnit:
return ControlNetUnit(enabled=True, **kwargs)
class TestBatchHijack(unittest.TestCase):
@unittest.mock.patch('modules.script_callbacks.on_script_unloaded')
def setUp(self, on_script_unloaded_mock):
@ -59,9 +64,18 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
is_cn_batch, batches, output_dir, _ = batch_hijack.get_cn_batches(self.p)
batch_hijack.instance.dispatch_callbacks(batch_hijack.instance.process_batch_callbacks, self.p, batches, output_dir)
batch_units = [unit for unit in self.p.script_args if getattr(unit, 'input_mode', batch_hijack.InputMode.SIMPLE) == batch_hijack.InputMode.BATCH]
batch_units = [
unit
for unit in self.p.script_args
if getattr(unit, 'input_mode', batch_hijack.InputMode.SIMPLE) == batch_hijack.InputMode.BATCH
]
# Convert iterator to list to avoid double eval of iterator exhausting
# the iterator in following checks.
for unit in batch_units:
unit.batch_images = list(unit.batch_images)
if batch_units:
self.assertEqual(min(len(list(unit.batch_images)) for unit in batch_units), len(batches))
self.assertEqual(min(len(unit.batch_images) for unit in batch_units), len(batches))
else:
self.assertEqual(1, len(batches))
@ -74,15 +88,15 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
self.assertEqual(is_batch, False)
def test_get_cn_batches__1_simple(self):
self.p.script_args.append(ControlNetUnit(image=get_dummy_image()))
self.p.script_args.append(create_unit(image=get_dummy_image()))
self.assert_get_cn_batches_works([
[self.p.script_args[0].image],
[get_dummy_image()],
])
def test_get_cn_batches__2_simples(self):
self.p.script_args.extend([
ControlNetUnit(image=get_dummy_image(0)),
ControlNetUnit(image=get_dummy_image(1)),
create_unit(image=get_dummy_image(0)),
create_unit(image=get_dummy_image(1)),
])
self.assert_get_cn_batches_works([
[get_dummy_image(0)],
@ -91,7 +105,7 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
def test_get_cn_batches__1_batch(self):
self.p.script_args.extend([
controlnet.UiControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(0),
@ -108,14 +122,14 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
def test_get_cn_batches__2_batches(self):
self.p.script_args.extend([
controlnet.UiControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(0),
get_dummy_image(1),
],
),
controlnet.UiControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(2),
@ -136,8 +150,8 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
def test_get_cn_batches__2_mixed(self):
self.p.script_args.extend([
ControlNetUnit(image=get_dummy_image(0)),
controlnet.UiControlNetUnit(
create_unit(image=get_dummy_image(0)),
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(1),
@ -158,8 +172,8 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
def test_get_cn_batches__3_mixed(self):
self.p.script_args.extend([
ControlNetUnit(image=get_dummy_image(0)),
controlnet.UiControlNetUnit(
create_unit(image=get_dummy_image(0)),
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(1),
@ -167,7 +181,7 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
get_dummy_image(3),
],
),
controlnet.UiControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(4),
@ -243,14 +257,14 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
def test_process_images__only_simple_units__forwards(self):
self.p.script_args = [
ControlNetUnit(image=get_dummy_image()),
ControlNetUnit(image=get_dummy_image()),
create_unit(image=get_dummy_image()),
create_unit(image=get_dummy_image()),
]
self.assert_process_images_hijack_called(batch_count=0)
def test_process_images__1_batch_1_unit__runs_1_batch(self):
self.p.script_args = [
controlnet.UiControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(),
@ -261,7 +275,7 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
def test_process_images__2_batches_1_unit__runs_2_batches(self):
self.p.script_args = [
controlnet.UiControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(0),
@ -274,7 +288,7 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
def test_process_images__8_batches_1_unit__runs_8_batches(self):
batch_count = 8
self.p.script_args = [
controlnet.UiControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[get_dummy_image(i) for i in range(batch_count)]
),
@ -283,11 +297,11 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
def test_process_images__1_batch_2_units__runs_1_batch(self):
self.p.script_args = [
controlnet.UiControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[get_dummy_image(0)]
),
controlnet.UiControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[get_dummy_image(1)]
),
@ -296,14 +310,14 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
def test_process_images__2_batches_2_units__runs_2_batches(self):
self.p.script_args = [
controlnet.UiControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(0),
get_dummy_image(1),
],
),
controlnet.UiControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(2),
@ -315,7 +329,7 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
def test_process_images__3_batches_2_mixed_units__runs_3_batches(self):
self.p.script_args = [
controlnet.UiControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(0),
@ -323,7 +337,7 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
get_dummy_image(2),
],
),
controlnet.UiControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.SIMPLE,
image=get_dummy_image(3),
),

View File

@ -7,7 +7,6 @@ import importlib
utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils")
from scripts import external_code
from scripts.enums import ResizeMode
from scripts.controlnet import prepare_mask, Script, set_numpy_seed
from internal_controlnet.external_code import ControlNetUnit
@ -119,9 +118,7 @@ class TestScript(unittest.TestCase):
"AAAAAAAAAAAAAAAAAAAAAAAA/wZOlAAB5tU+nAAAAABJRU5ErkJggg=="
)
sample_np_image = np.array(
[[100, 200, 50], [150, 75, 225], [30, 120, 180]], dtype=np.uint8
)
sample_np_image = np.zeros(shape=[8, 8, 3], dtype=np.uint8)
def test_choose_input_image(self):
with self.subTest(name="no image"):
@ -139,7 +136,7 @@ class TestScript(unittest.TestCase):
resize_mode=ResizeMode.OUTER_FIT,
),
unit=ControlNetUnit(
image=TestScript.sample_base64_image,
image=TestScript.sample_np_image,
module="none",
resize_mode=ResizeMode.INNER_FIT,
),

View File

@ -1,34 +0,0 @@
import unittest
import importlib
utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils")
from scripts.infotext import parse_unit
from scripts.external_code import ControlNetUnit
class TestInfotext(unittest.TestCase):
def test_parsing(self):
infotext = (
"Module: inpaint_only+lama, Model: control_v11p_sd15_inpaint [ebff9138], Weight: 1, "
"Resize Mode: Resize and Fill, Low Vram: False, Guidance Start: 0, Guidance End: 1, "
"Pixel Perfect: True, Control Mode: Balanced, Hr Option: Both, Save Detected Map: True"
)
self.assertEqual(
vars(
ControlNetUnit(
module="inpaint_only+lama",
model="control_v11p_sd15_inpaint [ebff9138]",
weight=1,
resize_mode="Resize and Fill",
low_vram=False,
guidance_start=0,
guidance_end=1,
pixel_perfect=True,
control_mode="Balanced",
hr_option="Both",
save_detected_map=True,
)
),
vars(parse_unit(infotext)),
)

View File

@ -54,70 +54,6 @@ class TestExternalCodeWorking(unittest.TestCase):
self.assert_update_in_place_ok()
class TestControlNetUnitConversion(unittest.TestCase):
def setUp(self):
self.dummy_image = 'base64...'
self.input = {}
self.expected = ControlNetUnit()
def assert_converts_to_expected(self):
self.assertEqual(vars(external_code.to_processing_unit(self.input)), vars(self.expected))
def test_empty_dict_works(self):
self.assert_converts_to_expected()
def test_image_works(self):
self.input = {
'image': self.dummy_image
}
self.expected = ControlNetUnit(image=self.dummy_image)
self.assert_converts_to_expected()
def test_image_alias_works(self):
self.input = {
'input_image': self.dummy_image
}
self.expected = ControlNetUnit(image=self.dummy_image)
self.assert_converts_to_expected()
def test_masked_image_works(self):
self.input = {
'image': self.dummy_image,
'mask': self.dummy_image,
}
self.expected = ControlNetUnit(image={'image': self.dummy_image, 'mask': self.dummy_image})
self.assert_converts_to_expected()
class TestControlNetUnitImageToDict(unittest.TestCase):
def setUp(self):
self.dummy_image = utils.readImage("test/test_files/img2img_basic.png")
self.input = ControlNetUnit()
self.expected_image = external_code.to_base64_nparray(self.dummy_image)
self.expected_mask = external_code.to_base64_nparray(self.dummy_image)
def assert_dict_is_valid(self):
actual_dict = controlnet.image_dict_from_any(self.input.image)
self.assertEqual(actual_dict['image'].tolist(), self.expected_image.tolist())
self.assertEqual(actual_dict['mask'].tolist(), self.expected_mask.tolist())
def test_none(self):
self.assertEqual(controlnet.image_dict_from_any(self.input.image), None)
def test_image_without_mask(self):
self.input.image = self.dummy_image
self.expected_mask = np.zeros_like(self.expected_image, dtype=np.uint8)
self.assert_dict_is_valid()
def test_masked_image_tuple(self):
self.input.image = (self.dummy_image, self.dummy_image,)
self.assert_dict_is_valid()
def test_masked_image_dict(self):
self.input.image = {'image': self.dummy_image, 'mask': self.dummy_image}
self.assert_dict_is_valid()
class TestPixelPerfectResolution(unittest.TestCase):
def test_outer_fit(self):
image = np.zeros((100, 100, 3))
@ -136,37 +72,5 @@ class TestPixelPerfectResolution(unittest.TestCase):
self.assertEqual(result, expected)
class TestGetAllUnitsFrom(unittest.TestCase):
def test_none(self):
self.assertListEqual(external_code.get_all_units_from([None]), [])
def test_bool(self):
self.assertListEqual(external_code.get_all_units_from([True]), [])
def test_inheritance(self):
class Foo(ControlNetUnit):
def __init__(self):
super().__init__(self)
self.bar = 'a'
foo = Foo()
self.assertListEqual(external_code.get_all_units_from([foo]), [foo])
def test_dict(self):
units = external_code.get_all_units_from([{}])
self.assertGreater(len(units), 0)
self.assertIsInstance(units[0], ControlNetUnit)
def test_unitlike(self):
class Foo(object):
""" bar """
foo = Foo()
for key in vars(ControlNetUnit()).keys():
setattr(foo, key, True)
setattr(foo, 'bar', False)
self.assertListEqual(external_code.get_all_units_from([foo]), [foo])
if __name__ == '__main__':
unittest.main()

View File

@ -1,36 +0,0 @@
import unittest
import importlib
utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'utils')
from scripts import external_code
from scripts.enums import ControlMode
from internal_controlnet.external_code import ControlNetUnit
class TestGetAllUnitsFrom(unittest.TestCase):
def setUp(self):
self.control_unit = {
"module": "none",
"model": utils.get_model("canny"),
"image": utils.readImage("test/test_files/img2img_basic.png"),
"resize_mode": 1,
"low_vram": False,
"processor_res": 64,
"control_mode": ControlMode.BALANCED.value,
}
self.object_unit = ControlNetUnit(**self.control_unit)
def test_empty_converts(self):
script_args = []
units = external_code.get_all_units_from(script_args)
self.assertListEqual(units, [])
def test_object_forwards(self):
script_args = [self.object_unit]
units = external_code.get_all_units_from(script_args)
self.assertListEqual(units, [self.object_unit])
if __name__ == '__main__':
unittest.main()

View File

@ -143,16 +143,16 @@ def test_detect_default_param():
dict(
controlnet_input_images=[realistic_girl_face_img],
controlnet_module="canny", # Canny does not require model download.
controlnet_threshold_a=-1,
controlnet_threshold_b=-1,
controlnet_processor_res=-1,
controlnet_threshold_a=-100,
controlnet_threshold_b=-100,
controlnet_processor_res=-100,
),
"default_param",
)
assert log_context.is_in_console_logs(
[
"[canny.processor_res] Invalid value(-1), using default value 512.",
"[canny.threshold_a] Invalid value(-1.0), using default value 100.",
"[canny.threshold_b] Invalid value(-1.0), using default value 200.",
"[canny.processor_res] Invalid value(-100), using default value 512.",
"[canny.threshold_a] Invalid value(-100.0), using default value 100.",
"[canny.threshold_b] Invalid value(-100.0), using default value 200.",
]
)

View File

@ -169,14 +169,14 @@ def expect_same_image(img1, img2, diff_img_path: str) -> bool:
default_unit = {
"control_mode": 0,
"control_mode": "Balanced",
"enabled": True,
"guidance_end": 1,
"guidance_start": 0,
"low_vram": False,
"pixel_perfect": True,
"processor_res": 512,
"resize_mode": 1,
"resize_mode": "Crop and Resize",
"threshold_a": -1,
"threshold_b": -1,
"weight": 1,

View File

@ -87,12 +87,13 @@ def test_invalid_param(gen_type, param_name):
f"test_invalid_param{(gen_type, param_name)}",
gen_type,
payload_overrides={},
unit_overrides={param_name: -1},
unit_overrides={param_name: -100},
input_image=girl_img,
).exec()
number = "-100" if param_name == "processor_res" else "-100.0"
assert log_context.is_in_console_logs(
[
f"[canny.{param_name}] Invalid value(-1), using default value",
f"[canny.{param_name}] Invalid value({number}), using default value",
]
)
@ -192,7 +193,7 @@ def test_hr_option():
"enable_hr": True,
"denoising_strength": 0.75,
},
unit_overrides={"hr_option": "HiResFixOption.BOTH"},
unit_overrides={"hr_option": "Both"},
input_image=girl_img,
).exec(expected_output_num=3)
@ -203,7 +204,7 @@ def test_hr_option_default():
"test_hr_option_default",
"txt2img",
payload_overrides={"enable_hr": False},
unit_overrides={"hr_option": "HiResFixOption.BOTH"},
unit_overrides={"hr_option": "Both"},
input_image=girl_img,
).exec(expected_output_num=2)

View File

@ -295,13 +295,13 @@ def get_model(model_name: str) -> str:
default_unit = {
"control_mode": 0,
"control_mode": "Balanced",
"enabled": True,
"guidance_end": 1,
"guidance_start": 0,
"pixel_perfect": True,
"processor_res": 512,
"resize_mode": 1,
"resize_mode": "Crop and Resize",
"threshold_a": -1,
"threshold_b": -1,
"weight": 1,

0
unit_tests/__init__.py Normal file
View File

241
unit_tests/args_test.py Normal file
View File

@ -0,0 +1,241 @@
import pytest
import torch
import numpy as np
from dataclasses import dataclass
from internal_controlnet.args import ControlNetUnit
# Shared test fixtures: small 128x128 arrays standing in for real images.
H = W = 128
# Plain RGB images (uint8, H x W x 3) with distinct fill values.
img1 = np.ones(shape=[H, W, 3], dtype=np.uint8)
img2 = np.ones(shape=[H, W, 3], dtype=np.uint8) * 2
# Mask whose spatial size differs from img1/img2; used to trigger the
# image/mask H x W mismatch check.
mask_diff = np.ones(shape=[H - 1, W - 1, 3], dtype=np.uint8) * 2
# Greyscale (2D) mask; units should accept masks without a channel axis.
mask_2d = np.ones(shape=[H, W])
# Malformed inputs: wrong channel count / spurious batch dimension.
# NOTE(review): img_bad_channel/img_bad_dim/ui_img_diff/ui_img appear unused
# in this module's visible tests — confirm before removing.
img_bad_channel = np.ones(shape=[H, W, 2], dtype=np.uint8) * 2
img_bad_dim = np.ones(shape=[1, H, W, 3], dtype=np.uint8) * 2
# RGBA images as produced by the gradio UI canvas.
ui_img_diff = np.ones(shape=[H - 1, W - 1, 4], dtype=np.uint8) * 2
ui_img = np.ones(shape=[H, W, 4], dtype=np.uint8)
# Minimal tensor stand-in for a torch-serialized ip-adapter embedding.
tensor1 = torch.zeros(size=[1, 1], dtype=torch.float16)
@pytest.fixture(scope="module")
def set_cls_funcs():
    """Patch ControlNetUnit's class-level hooks with deterministic stubs.

    The real hooks query the model/module registries and decode base64
    payloads; the tests replace them with small lookup tables so no models
    need to be installed.

    Fix over the original: the patches are now reverted when this module's
    tests finish. Previously the stubs were left on the class forever and
    leaked into every test module run afterwards.
    """
    _MISSING = object()
    hook_names = (
        "cls_match_model",
        "cls_match_module",
        "cls_decode_base64",
        "cls_torch_load_base64",
        "cls_get_preprocessor",
    )
    # Remember what (if anything) the class defined before patching.
    saved = {name: ControlNetUnit.__dict__.get(name, _MISSING) for name in hook_names}

    ControlNetUnit.cls_match_model = lambda s: s in {
        "None",
        "model1",
        "model2",
        "control_v11p_sd15_inpaint [ebff9138]",
    }
    ControlNetUnit.cls_match_module = lambda s: s in {
        "none",
        "module1",
        "inpaint_only+lama",
    }
    ControlNetUnit.cls_decode_base64 = lambda s: {
        "b64img1": img1,
        "b64img2": img2,
        "b64mask_diff": mask_diff,
    }[s]
    ControlNetUnit.cls_torch_load_base64 = lambda s: {
        "b64tensor1": tensor1,
    }[s]
    ControlNetUnit.cls_get_preprocessor = lambda s: {
        "module1": MockPreprocessor(),
        "none": MockPreprocessor(),
        "inpaint_only+lama": MockPreprocessor(),
    }[s]

    yield

    # Restore the original hooks, or drop the stubs if none existed before.
    for name, original in saved.items():
        if original is _MISSING:
            try:
                delattr(ControlNetUnit, name)
            except AttributeError:
                pass
        else:
            setattr(ControlNetUnit, name, original)
def test_module_invalid(set_cls_funcs):
    """An unregistered preprocessor name is rejected at construction time."""
    with pytest.raises(ValueError) as err:
        ControlNetUnit(module="foo")
    assert "module(foo) not found in supported modules." in str(err.value)
def test_module_valid(set_cls_funcs):
    """A registered preprocessor name constructs without raising."""
    _ = ControlNetUnit(module="module1")
def test_model_invalid(set_cls_funcs):
    """An unregistered model name is rejected at construction time."""
    with pytest.raises(ValueError) as err:
        ControlNetUnit(model="foo")
    assert "model(foo) not found in supported models." in str(err.value)
def test_model_valid(set_cls_funcs):
    """A registered model name constructs without raising."""
    _ = ControlNetUnit(model="model1")
@pytest.mark.parametrize(
    "d",
    [
        # API payloads: base64-encoded strings in every accepted nesting —
        # bare dict, list/tuple pair, list of dicts, top-level mask aliases.
        dict(image={"image": "b64img1"}),
        dict(image={"image": "b64img1", "mask": "b64img2"}),
        dict(image=["b64img1", "b64img2"]),
        dict(image=("b64img1", "b64img2")),
        dict(image=[{"image": "b64img1", "mask": "b64img2"}]),
        dict(image=[{"image": "b64img1"}]),
        dict(image=[{"image": "b64img1", "mask": None}]),
        dict(
            image=[
                {"image": "b64img1", "mask": "b64img2"},
                {"image": "b64img1", "mask": "b64img2"},
            ]
        ),
        dict(
            image=[
                {"image": "b64img1", "mask": None},
                {"image": "b64img1", "mask": "b64img2"},
            ]
        ),
        dict(
            image=[
                {"image": "b64img1"},
                {"image": "b64img1", "mask": "b64img2"},
            ]
        ),
        dict(image="b64img1", mask="b64img2"),
        dict(image="b64img1"),
        dict(image="b64img1", mask_image="b64img2"),
        dict(image=None),
        # UI payloads: numpy arrays straight from the gradio canvas.
        dict(image=dict(image=img1)),
        dict(image=dict(image=img1, mask=img2)),
        # 2D (greyscale) mask should be accepted.
        dict(image=dict(image=img1, mask=mask_2d)),
        dict(image=img1, mask=mask_2d),
    ],
)
def test_valid_image_formats(set_cls_funcs, d):
    """Every supported image/mask format both constructs directly and
    round-trips through from_dict + get_input_images_rgba without error."""
    ControlNetUnit(**d)
    unit = ControlNetUnit.from_dict(d)
    unit.get_input_images_rgba()
@pytest.mark.parametrize(
    "d",
    [
        # Mask without an image.
        dict(image={"mask": "b64img1"}),
        # Unknown keys.
        dict(image={"foo": "b64img1", "bar": "b64img2"}),
        # A bare list must be an (image, mask) pair: wrong lengths rejected.
        dict(image=["b64img1"]),
        dict(image=("b64img1", "b64img2", "b64img1")),
        dict(image=[]),
        dict(image=[{"mask": "b64img1"}]),
        dict(image=None, mask="b64img2"),
        # image & mask have different H x W
        dict(image="b64img1", mask="b64mask_diff"),
    ],
)
def test_invalid_image_formats(set_cls_funcs, d):
    """Malformed inputs are tolerated at assignment but fail on evaluation."""
    # Setting field will be fine (content validation is deferred).
    ControlNetUnit(**d)
    unit = ControlNetUnit.from_dict(d)
    # Error on eval.
    with pytest.raises((ValueError, AssertionError)):
        unit.get_input_images_rgba()
def test_mask_alias_conflict():
    """`mask` and `mask_image` are aliases for the same field; supplying
    both in one payload must be rejected.

    Fix over the original: a stray trailing comma after the `from_dict(...)`
    call wrapped its result in a throwaway one-element tuple; removed.
    """
    with pytest.raises((ValueError, AssertionError)):
        ControlNetUnit.from_dict(
            dict(
                image="b64img1",
                mask="b64img1",
                mask_image="b64img1",
            )
        )
def test_resize_mode():
    """The human-readable resize-mode string form is accepted."""
    _ = ControlNetUnit(resize_mode="Just Resize")
def test_weight():
    """Weight must lie inside the allowed range; out-of-range values raise."""
    for accepted in (0.5, 0.0):
        ControlNetUnit(weight=accepted)
    for rejected in (-1, 100):
        with pytest.raises(ValueError):
            ControlNetUnit(weight=rejected)
def test_start_end():
    """guidance_start must not exceed guidance_end, and both stay in [0, 1]."""
    for start, end in ((0.0, 1.0), (0.5, 1.0), (0.5, 0.5)):
        ControlNetUnit(guidance_start=start, guidance_end=end)
    # Inverted interval is rejected.
    with pytest.raises(ValueError):
        ControlNetUnit(guidance_start=1.0, guidance_end=0.0)
    # Each bound individually rejects out-of-range values.
    for bad_kwargs in (dict(guidance_start=11), dict(guidance_end=11)):
        with pytest.raises(ValueError):
            ControlNetUnit(**bad_kwargs)
def test_effective_region_mask():
    """Effective-region mask accepts a base64 string, None, or a numpy
    array; any other type raises."""
    for accepted in ("b64img1", None, img1):
        ControlNetUnit(effective_region_mask=accepted)
    with pytest.raises(ValueError):
        ControlNetUnit(effective_region_mask=124)
def test_ipadapter_input():
    """ip-adapter input accepts a list of payloads, a bare payload, or
    None; an empty list is invalid."""
    for accepted in (["b64tensor1"], "b64tensor1", None):
        ControlNetUnit(ipadapter_input=accepted)
    with pytest.raises(ValueError):
        ControlNetUnit(ipadapter_input=[])
@dataclass
class MockSlider:
    # Mimics the slider config a preprocessor exposes: a default value plus
    # the allowed [minimum, maximum] range used for bound checks.
    value: float = 1
    minimum: float = 0
    maximum: float = 2
@dataclass
class MockPreprocessor:
    # Stand-in for a real preprocessor: exposes the three sliders
    # (resolution, threshold_a, threshold_b) that ControlNetUnit reads to
    # fill in default processor_res/threshold values.
    # NOTE(review): no annotations, so these are class attributes shared by
    # every instance — acceptable for a stateless mock, but confirm nothing
    # mutates them per-instance.
    slider_resolution = MockSlider()
    slider_1 = MockSlider()
    slider_2 = MockSlider()
def test_preprocessor_sliders(set_cls_funcs):
    """An enabled unit pulls default slider values from its preprocessor;
    MockSlider's default value is 1, so all three fields resolve to 1.

    Fix over the original: the `set_cls_funcs` fixture is now requested
    explicitly. The test previously relied on the class-level patches
    leaking from earlier tests and could fail when run in isolation.
    """
    unit = ControlNetUnit(enabled=True, module="none")
    assert unit.processor_res == 1
    assert unit.threshold_a == 1
    assert unit.threshold_b == 1
def test_preprocessor_sliders_disabled(set_cls_funcs):
    """A disabled unit skips slider resolution; fields keep the -1 sentinel.

    Fix over the original: requests `set_cls_funcs` explicitly instead of
    depending on patches leaked by earlier tests in the module.
    """
    unit = ControlNetUnit(enabled=False, module="none")
    assert unit.processor_res == -1
    assert unit.threshold_a == -1
    assert unit.threshold_b == -1
def test_infotext_parsing(set_cls_funcs):
    """Parsing an infotext string yields a unit equal to one constructed
    with the equivalent keyword arguments.

    Fix over the original: requests `set_cls_funcs` explicitly — the
    referenced module/model names exist only in the mocked registries, so
    the test broke when run without an earlier test activating the fixture.
    """
    infotext = (
        "Module: inpaint_only+lama, Model: control_v11p_sd15_inpaint [ebff9138], Weight: 1, "
        "Resize Mode: Resize and Fill, Low Vram: False, Guidance Start: 0, Guidance End: 1, "
        "Pixel Perfect: True, Control Mode: Balanced"
    )
    expected = ControlNetUnit(
        enabled=True,
        module="inpaint_only+lama",
        model="control_v11p_sd15_inpaint [ebff9138]",
        weight=1,
        resize_mode="Resize and Fill",
        low_vram=False,
        guidance_start=0,
        guidance_end=1,
        pixel_perfect=True,
        control_mode="Balanced",
    )
    assert expected == ControlNetUnit.parse(infotext)