Validate ControlNetUnit using pydantic (#2847)
* Add Pydantic ControlNetUnit Add test config Add images field Adjust image field handling fix various ui Fix most UI issues accept greyscale image/mask Fix infotext Fix preset Fix infotext nit Move infotext parsing test Remove preset Remove unused js code Adjust test payload By default disable unit refresh enum usage Align resize mode change test func name remove unused import nit Change default handling Skip bound check when not enabled Fix batch Various batch fix Disable batch hijack test adjust test fix test expectations Fix unit copy nit Fix test failures * Change script args back to ControlNetUnit for compatibility * import enum for compatibility * Fix unit test * simplify unfold * Add test coverage * handle directly set np image * re-enable batch test * Add back canvas scribble support * nit * Fix batch hijack test
parent
1b95e476ec
commit
e33c046158
|
|
@ -94,6 +94,10 @@ jobs:
|
|||
wait-for-it --service 127.0.0.1:7860 -t 600
|
||||
python -m pytest -v --junitxml=test/results.xml --cov ./extensions/sd-webui-controlnet --cov-report=xml --verify-base-url ./extensions/sd-webui-controlnet/tests
|
||||
working-directory: stable-diffusion-webui
|
||||
- name: Run unit tests
|
||||
run: |
|
||||
python -m pytest -v ./unit_tests/
|
||||
working-directory: stable-diffusion-webui/extensions/sd-webui-controlnet/
|
||||
- name: Kill test server
|
||||
if: always()
|
||||
run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
|
||||
|
|
|
|||
|
|
@ -0,0 +1,443 @@
|
|||
from __future__ import annotations
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Optional, List, Annotated, ClassVar, Callable, Any, Tuple, Union
|
||||
from pydantic import BaseModel, validator, root_validator, Field
|
||||
from PIL import Image
|
||||
from logging import Logger
|
||||
from copy import copy
|
||||
from enum import Enum
|
||||
|
||||
from scripts.enums import (
|
||||
InputMode,
|
||||
ResizeMode,
|
||||
ControlMode,
|
||||
HiResFixOption,
|
||||
PuLIDMode,
|
||||
)
|
||||
|
||||
|
||||
def _unimplemented_func(*args, **kwargs):
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
|
||||
def field_to_displaytext(fieldname: str) -> str:
    """Convert a snake_case field name to display text, e.g. 'resize_mode' -> 'Resize Mode'."""
    return " ".join(part.capitalize() for part in fieldname.split("_"))
|
||||
|
||||
|
||||
def displaytext_to_field(text: str) -> str:
    """Convert display text back to a snake_case field name, e.g. 'Resize Mode' -> 'resize_mode'."""
    return "_".join(part.lower() for part in text.split(" "))
|
||||
|
||||
|
||||
def serialize_value(value) -> str:
    """Serialize a field value for infotext; enum members serialize to their payload value."""
    return value.value if isinstance(value, Enum) else str(value)
|
||||
|
||||
|
||||
def parse_value(value: str) -> Union[str, float, int, bool]:
    """Parse an infotext token back to bool/int/float, falling back to the raw string."""
    if value == "True":
        return True
    if value == "False":
        return False
    # Try the narrower numeric type first so "3" stays an int.
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            continue
    return value  # Plain string.
|
||||
|
||||
|
||||
class ControlNetUnit(BaseModel):
    """
    Represents an entire ControlNet processing unit.

    The ``cls_*`` ClassVar hooks default to an unimplemented stub and are
    assigned at runtime elsewhere in the extension, keeping this module
    decoupled from the preprocessor/model registries.
    """

    class Config:
        arbitrary_types_allowed = True
        extra = "ignore"

    # Runtime-injected hooks; raise NotImplementedError until wired up.
    cls_match_module: ClassVar[Callable[[str], bool]] = _unimplemented_func
    cls_match_model: ClassVar[Callable[[str], bool]] = _unimplemented_func
    cls_decode_base64: ClassVar[Callable[[str], np.ndarray]] = _unimplemented_func
    cls_torch_load_base64: ClassVar[Callable[[Any], torch.Tensor]] = _unimplemented_func
    cls_get_preprocessor: ClassVar[Callable[[str], Any]] = _unimplemented_func
    cls_logger: ClassVar[Logger] = Logger("ControlNetUnit")

    # UI only fields.
    is_ui: bool = False
    input_mode: InputMode = InputMode.SIMPLE
    batch_images: Optional[Any] = None
    output_dir: str = ""
    loopback: bool = False

    # General fields.
    enabled: bool = False
    module: str = "none"

    @validator("module", always=True, pre=True)
    def check_module(cls, value: str) -> str:
        """Reject module names unknown to the injected registry hook."""
        if not ControlNetUnit.cls_match_module(value):
            raise ValueError(f"module({value}) not found in supported modules.")
        return value

    model: str = "None"

    @validator("model", always=True, pre=True)
    def check_model(cls, value: str) -> str:
        """Reject model names unknown to the injected registry hook."""
        if not ControlNetUnit.cls_match_model(value):
            raise ValueError(f"model({value}) not found in supported models.")
        return value

    weight: Annotated[float, Field(ge=0.0, le=2.0)] = 1.0

    # The image to be used for this ControlNetUnit.
    image: Optional[Any] = None

    resize_mode: ResizeMode = ResizeMode.INNER_FIT
    low_vram: bool = False
    processor_res: int = -1
    threshold_a: float = -1
    threshold_b: float = -1

    @root_validator
    def bound_check_params(cls, values: dict) -> dict:
        """
        Checks and corrects negative parameters in ControlNetUnit 'unit' in place.
        Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
        their default values if negative.

        Skipped entirely for disabled units or when no module is set.
        """
        enabled = values.get("enabled")
        if not enabled:
            return values

        module = values.get("module")
        if not module:
            return values

        preprocessor = cls.cls_get_preprocessor(module)
        assert preprocessor is not None
        for unit_param, param in zip(
            ("processor_res", "threshold_a", "threshold_b"),
            ("slider_resolution", "slider_1", "slider_2"),
        ):
            value = values.get(unit_param)
            cfg = getattr(preprocessor, param)
            if value < cfg.minimum or value > cfg.maximum:
                values[unit_param] = cfg.value
                # Only report warning when non-default value is used.
                if value != -1:
                    cls.cls_logger.info(
                        f"[{module}.{unit_param}] Invalid value({value}), using default value {cfg.value}."
                    )
        return values

    guidance_start: Annotated[float, Field(ge=0.0, le=1.0)] = 0.0
    guidance_end: Annotated[float, Field(ge=0.0, le=1.0)] = 1.0

    @root_validator
    def guidance_check(cls, values: dict) -> dict:
        """Ensure the guidance interval is well-formed (start <= end)."""
        start = values.get("guidance_start")
        end = values.get("guidance_end")
        if start > end:
            raise ValueError(f"guidance_start({start}) > guidance_end({end})")
        return values

    pixel_perfect: bool = False
    control_mode: ControlMode = ControlMode.BALANCED
    # Whether to crop input image based on A1111 img2img mask. This flag is only used when `inpaint area`
    # in A1111 is set to `Only masked`. In API, this correspond to `inpaint_full_res = True`.
    inpaint_crop_input_image: bool = True
    # If hires fix is enabled in A1111, how should this ControlNet unit be applied.
    # The value is ignored if the generation is not using hires fix.
    hr_option: HiResFixOption = HiResFixOption.BOTH

    # Whether save the detected map of this unit. Setting this option to False prevents saving the
    # detected map or sending detected map along with generated images via API.
    # Currently the option is only accessible in API calls.
    save_detected_map: bool = True

    # Weight for each layer of ControlNet params.
    # For ControlNet:
    # - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
    # - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
    # For T2IAdapter
    # - SD1.5: 5 weights (4 encoder block + 1 middle block)
    # - SDXL: 4 weights (3 encoder block + 1 middle block)
    # For IPAdapter
    # - SD15: 16 (6 input blocks + 9 output blocks + 1 middle block)
    # - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
    # Note1: Setting advanced weighting will disable `soft_injection`, i.e.
    # It is recommended to set ControlMode = BALANCED when using `advanced_weighting`.
    # Note2: The field `weight` is still used in some places, e.g. reference_only,
    # even advanced_weighting is set.
    advanced_weighting: Optional[List[float]] = None

    # The effective region mask that unit's effect should be restricted to.
    effective_region_mask: Optional[np.ndarray] = None

    @validator("effective_region_mask", pre=True)
    def parse_effective_region_mask(cls, value) -> Optional[np.ndarray]:
        """Accept either a base64 string (decoded via hook) or an ndarray/None."""
        if isinstance(value, str):
            return cls.cls_decode_base64(value)
        assert isinstance(value, np.ndarray) or value is None
        return value

    # The weight mode for PuLID.
    # https://github.com/ToTheBeginning/PuLID
    pulid_mode: PuLIDMode = PuLIDMode.FIDELITY

    # ------- API only fields -------
    # The tensor input for ipadapter. When this field is set in the API,
    # the base64string will be interpret by torch.load to reconstruct ipadapter
    # preprocessor output.
    # Currently the option is only accessible in API calls.
    ipadapter_input: Optional[List[torch.Tensor]] = None

    @validator("ipadapter_input", pre=True)
    def parse_ipadapter_input(cls, value) -> Optional[List[torch.Tensor]]:
        """Decode one base64 string (or a list of them) into tensors via hook."""
        if value is None:
            return None
        if isinstance(value, str):
            value = [value]
        result = [cls.cls_torch_load_base64(b) for b in value]
        assert result, "input cannot be empty"
        return result

    # The mask to be used on top of the image.
    mask: Optional[Any] = None

    @property
    def accepts_multiple_inputs(self) -> bool:
        """This unit can accept multiple input images."""
        return self.module in (
            "ip-adapter-auto",
            "ip-adapter_clip_sdxl",
            "ip-adapter_clip_sdxl_plus_vith",
            "ip-adapter_clip_sd15",
            "ip-adapter_face_id",
            "ip-adapter_face_id_plus",
            "ip-adapter_pulid",
            "instant_id_face_embedding",
        )

    @property
    def is_animate_diff_batch(self) -> bool:
        # Attribute may be stamped on by an external extension; default False.
        return getattr(self, "animatediff_batch", False)

    @property
    def uses_clip(self) -> bool:
        """Whether this unit uses clip preprocessor."""
        return any(
            (
                ("ip-adapter" in self.module and "face_id" not in self.module),
                self.module
                in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"),
            )
        )

    @property
    def is_inpaint(self) -> bool:
        """Whether this unit uses an inpaint preprocessor."""
        return "inpaint" in self.module

    def get_actual_preprocessor(self):
        """Resolve 'ip-adapter-auto' to the concrete preprocessor for the model."""
        if self.module == "ip-adapter-auto":
            return ControlNetUnit.cls_get_preprocessor(
                self.module
            ).get_preprocessor_by_model(self.model)
        return ControlNetUnit.cls_get_preprocessor(self.module)

    @classmethod
    def parse_image(cls, image) -> np.ndarray:
        """Normalize an ndarray / file path / base64 string into an RGB uint8 array."""
        if isinstance(image, np.ndarray):
            np_image = image
        elif isinstance(image, str):
            # Necessary for batch.
            if os.path.exists(image):
                np_image = np.array(Image.open(image)).astype("uint8")
            else:
                np_image = cls.cls_decode_base64(image)
        else:
            raise ValueError(f"Unrecognized image format {image}.")

        # [H, W] => [H, W, 3]
        if np_image.ndim == 2:
            np_image = np.stack([np_image, np_image, np_image], axis=-1)
        assert np_image.ndim == 3
        assert np_image.shape[2] == 3
        return np_image

    @classmethod
    def combine_image_and_mask(
        cls, np_image: np.ndarray, np_mask: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """RGB + Alpha(Optional) => RGBA"""
        # TODO: Change protocol to use 255 as A channel value.
        # Note: mask is by default zeros, as both inpaint and
        # clip mask does extra work on masked area.
        np_mask = (np.zeros_like(np_image) if np_mask is None else np_mask)[:, :, 0:1]
        if np_image.shape[:2] != np_mask.shape[:2]:
            raise ValueError(
                f"image shape ({np_image.shape[:2]}) not aligned with mask shape ({np_mask.shape[:2]})"
            )
        return np.concatenate([np_image, np_mask], axis=2)  # [H, W, 4]

    @classmethod
    def legacy_field_alias(cls, values: dict) -> dict:
        """Move deprecated field names onto their current equivalents, with a warning."""
        ext_compat_keys = {
            "guidance": "guidance_end",
            "lowvram": "low_vram",
            "input_image": "image",
        }
        for alias, key in ext_compat_keys.items():
            if alias in values:
                assert key not in values, f"Conflict of field '{alias}' and '{key}'"
                # Transfer the VALUE to the canonical key. (Previously the alias
                # name string itself was assigned, discarding the user's value.)
                values[key] = values.pop(alias)
                cls.cls_logger.warning(
                    f"Deprecated alias '{alias}' detected. This field will be removed on 2024-06-01. "
                    f"Please use '{key}' instead."
                )

        return values

    @classmethod
    def mask_alias(cls, values: dict) -> dict:
        """
        Field "mask_image" is the alias of field "mask".
        This is for compatibility with SD Forge API.
        """
        mask_image = values.get("mask_image")
        mask = values.get("mask")
        if mask_image is not None:
            if mask is not None:
                raise ValueError("Cannot specify both 'mask' and 'mask_image'!")
            values["mask"] = mask_image
        return values

    def get_input_images_rgba(self) -> Optional[List[np.ndarray]]:
        """
        RGBA images with potentially different size.
        Why we cannot have [B, H, W, C=4] here is that calculation of final
        resolution requires generation target's dimensions.

        Parse image with following formats.
        API
        - image = {"image": base64image, "mask": base64image,}
        - image = [image, mask]
        - image = (image, mask)
        - image = [{"image": ..., "mask": ...}, {"image": ..., "mask": ...}, ...]
        - image = base64image, mask = base64image

        UI:
        - image = {"image": np_image, "mask": np_image,}
        - image = np_image, mask = np_image
        """
        init_image = self.image
        init_mask = self.mask

        if init_image is None:
            assert init_mask is None
            return None

        if isinstance(init_image, (list, tuple)):
            if not init_image:
                raise ValueError(f"{init_image} is not a valid 'image' field value")
            if isinstance(init_image[0], dict):
                # [{"image": ..., "mask": ...}, {"image": ..., "mask": ...}, ...]
                images = init_image
            else:
                assert len(init_image) == 2
                # [image, mask]
                # (image, mask)
                images = [
                    {
                        "image": init_image[0],
                        "mask": init_image[1],
                    }
                ]
        elif isinstance(init_image, dict):
            # {"image": ..., "mask": ...}
            images = [init_image]
        elif isinstance(init_image, (str, np.ndarray)):
            # image = base64image, mask = base64image
            images = [
                {
                    "image": init_image,
                    "mask": init_mask,
                }
            ]
        else:
            raise ValueError(f"Unrecognized image field {init_image}")

        np_images = []
        for image_dict in images:
            assert isinstance(image_dict, dict)
            image = image_dict.get("image")
            mask = image_dict.get("mask")
            assert image is not None

            np_image = self.parse_image(image)
            np_mask = self.parse_image(mask) if mask is not None else None
            np_images.append(self.combine_image_and_mask(np_image, np_mask))  # [H, W, 4]

        return np_images

    @classmethod
    def from_dict(cls, values: dict) -> ControlNetUnit:
        """Construct a unit from a raw dict, resolving all legacy aliases first."""
        values = copy(values)
        values = cls.legacy_field_alias(values)
        values = cls.mask_alias(values)
        return ControlNetUnit(**values)

    @classmethod
    def from_infotext_args(cls, *args) -> ControlNetUnit:
        """Construct a unit from positional infotext values (order of infotext_fields)."""
        assert len(args) == len(ControlNetUnit.infotext_fields())
        return cls.from_dict(
            {k: v for k, v in zip(ControlNetUnit.infotext_fields(), args)}
        )

    @staticmethod
    def infotext_fields() -> Tuple[str, ...]:
        """Fields that should be included in infotext.
        You should define a Gradio element with exact same name in ControlNetUiGroup
        as well, so that infotext can wire the value to correct field when pasting
        infotext.
        """
        return (
            "module",
            "model",
            "weight",
            "resize_mode",
            "processor_res",
            "threshold_a",
            "threshold_b",
            "guidance_start",
            "guidance_end",
            "pixel_perfect",
            "control_mode",
        )

    def serialize(self) -> str:
        """Serialize the unit for infotext."""
        infotext_dict = {
            field_to_displaytext(field): serialize_value(getattr(self, field))
            for field in ControlNetUnit.infotext_fields()
        }
        # ',' and ':' are the infotext delimiters; a value containing either
        # would corrupt the round-trip, so refuse to serialize.
        if not all(
            "," not in str(v) and ":" not in str(v) for v in infotext_dict.values()
        ):
            self.cls_logger.error(f"Unexpected tokens encountered:\n{infotext_dict}")
            return ""

        return ", ".join(f"{field}: {value}" for field, value in infotext_dict.items())

    @classmethod
    def parse(cls, text: str) -> ControlNetUnit:
        """Inverse of serialize(): rebuild an enabled unit from infotext."""
        return ControlNetUnit(
            enabled=True,
            **{
                displaytext_to_field(key): parse_value(value)
                for item in text.split(",")
                for (key, value) in (item.strip().split(": "),)
            },
        )
|
||||
|
|
@ -1,23 +1,30 @@
|
|||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from copy import copy
|
||||
from typing import List, Any, Optional, Union, Tuple, Dict
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from modules import scripts, processing, shared
|
||||
from modules.safe import unsafe_torch_load
|
||||
from modules.api import api
|
||||
from .args import ControlNetUnit
|
||||
from scripts import global_state
|
||||
from scripts.logging import logger
|
||||
from scripts.enums import HiResFixOption, PuLIDMode, ControlMode, ResizeMode
|
||||
from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter
|
||||
from scripts.enums import (
|
||||
ResizeMode,
|
||||
BatchOption, # noqa: F401
|
||||
ControlMode, # noqa: F401
|
||||
)
|
||||
from scripts.supported_preprocessor import (
|
||||
Preprocessor,
|
||||
PreprocessorParameter, # noqa: F401
|
||||
)
|
||||
|
||||
from modules.api import api
|
||||
import torch
|
||||
import base64
|
||||
import io
|
||||
from modules.safe import unsafe_torch_load
|
||||
|
||||
|
||||
def get_api_version() -> int:
    """Version of the ControlNet external-code API.

    Bumped to 3 with the pydantic ControlNetUnit rework; the merged diff had
    left an unreachable duplicate `return` behind.
    """
    return 3
|
||||
|
||||
|
||||
resize_mode_aliases = {
|
||||
|
|
@ -47,15 +54,6 @@ def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode:
|
|||
return value
|
||||
|
||||
|
||||
def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode:
    """Coerce a raw representation (enum payload string or index) into ControlMode."""
    if isinstance(value, str):
        # Lookup by enum payload value.
        return ControlMode(value)
    if isinstance(value, int):
        # Lookup by position in declaration order.
        return list(ControlMode)[value]
    # Already a ControlMode; pass through unchanged.
    return value
|
||||
|
||||
|
||||
def visualize_inpaint_mask(img):
|
||||
if img.ndim == 3 and img.shape[2] == 4:
|
||||
result = img.copy()
|
||||
|
|
@ -117,153 +115,7 @@ def pixel_perfect_resolution(
|
|||
return int(np.round(estimation))
|
||||
|
||||
|
||||
InputImage = Union[np.ndarray, str]
|
||||
InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputImage]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlNetUnit:
|
||||
"""
|
||||
Represents an entire ControlNet processing unit.
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
module: str = "none"
|
||||
model: str = "None"
|
||||
weight: float = 1.0
|
||||
image: Optional[Union[InputImage, List[InputImage]]] = None
|
||||
resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
|
||||
low_vram: bool = False
|
||||
processor_res: int = -1
|
||||
threshold_a: float = -1
|
||||
threshold_b: float = -1
|
||||
guidance_start: float = 0.0
|
||||
guidance_end: float = 1.0
|
||||
pixel_perfect: bool = False
|
||||
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
|
||||
# Whether to crop input image based on A1111 img2img mask. This flag is only used when `inpaint area`
|
||||
# in A1111 is set to `Only masked`. In API, this correspond to `inpaint_full_res = True`.
|
||||
inpaint_crop_input_image: bool = True
|
||||
# If hires fix is enabled in A1111, how should this ControlNet unit be applied.
|
||||
# The value is ignored if the generation is not using hires fix.
|
||||
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
|
||||
|
||||
# Whether save the detected map of this unit. Setting this option to False prevents saving the
|
||||
# detected map or sending detected map along with generated images via API.
|
||||
# Currently the option is only accessible in API calls.
|
||||
save_detected_map: bool = True
|
||||
|
||||
# Weight for each layer of ControlNet params.
|
||||
# For ControlNet:
|
||||
# - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
|
||||
# - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
|
||||
# For T2IAdapter
|
||||
# - SD1.5: 5 weights (4 encoder block + 1 middle block)
|
||||
# - SDXL: 4 weights (3 encoder block + 1 middle block)
|
||||
# For IPAdapter
|
||||
# - SD15: 16 (6 input blocks + 9 output blocks + 1 middle block)
|
||||
# - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
|
||||
# Note1: Setting advanced weighting will disable `soft_injection`, i.e.
|
||||
# It is recommended to set ControlMode = BALANCED when using `advanced_weighting`.
|
||||
# Note2: The field `weight` is still used in some places, e.g. reference_only,
|
||||
# even advanced_weighting is set.
|
||||
advanced_weighting: Optional[List[float]] = None
|
||||
|
||||
# The effective region mask that unit's effect should be restricted to.
|
||||
effective_region_mask: Optional[np.ndarray] = None
|
||||
|
||||
# The weight mode for PuLID.
|
||||
# https://github.com/ToTheBeginning/PuLID
|
||||
pulid_mode: PuLIDMode = PuLIDMode.FIDELITY
|
||||
|
||||
# The tensor input for ipadapter. When this field is set in the API,
|
||||
# the base64string will be interpret by torch.load to reconstruct ipadapter
|
||||
# preprocessor output.
|
||||
# Currently the option is only accessible in API calls.
|
||||
ipadapter_input: Optional[List[Any]] = None
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, ControlNetUnit):
|
||||
return False
|
||||
|
||||
return vars(self) == vars(other)
|
||||
|
||||
def accepts_multiple_inputs(self) -> bool:
|
||||
"""This unit can accept multiple input images."""
|
||||
return self.module in (
|
||||
"ip-adapter-auto",
|
||||
"ip-adapter_clip_sdxl",
|
||||
"ip-adapter_clip_sdxl_plus_vith",
|
||||
"ip-adapter_clip_sd15",
|
||||
"ip-adapter_face_id",
|
||||
"ip-adapter_face_id_plus",
|
||||
"ip-adapter_pulid",
|
||||
"instant_id_face_embedding",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def infotext_excluded_fields() -> List[str]:
|
||||
return [
|
||||
"image",
|
||||
"enabled",
|
||||
# API-only fields.
|
||||
"advanced_weighting",
|
||||
"ipadapter_input",
|
||||
# End of API-only fields.
|
||||
# Note: "inpaint_crop_image" is img2img inpaint only flag, which does not
|
||||
# provide much information when restoring the unit.
|
||||
"inpaint_crop_input_image",
|
||||
"effective_region_mask",
|
||||
"pulid_mode",
|
||||
]
|
||||
|
||||
@property
|
||||
def is_animate_diff_batch(self) -> bool:
|
||||
return getattr(self, "animatediff_batch", False)
|
||||
|
||||
@property
|
||||
def uses_clip(self) -> bool:
|
||||
"""Whether this unit uses clip preprocessor."""
|
||||
return any(
|
||||
(
|
||||
("ip-adapter" in self.module and "face_id" not in self.module),
|
||||
self.module
|
||||
in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"),
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def is_inpaint(self) -> bool:
|
||||
return "inpaint" in self.module
|
||||
|
||||
def bound_check_params(self) -> None:
|
||||
"""
|
||||
Checks and corrects negative parameters in ControlNetUnit 'unit' in place.
|
||||
Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
|
||||
their default values if negative.
|
||||
"""
|
||||
preprocessor = Preprocessor.get_preprocessor(self.module)
|
||||
for unit_param, param in zip(
|
||||
("processor_res", "threshold_a", "threshold_b"),
|
||||
("slider_resolution", "slider_1", "slider_2"),
|
||||
):
|
||||
value = getattr(self, unit_param)
|
||||
cfg: PreprocessorParameter = getattr(preprocessor, param)
|
||||
if value < 0:
|
||||
setattr(self, unit_param, cfg.value)
|
||||
logger.info(
|
||||
f"[{self.module}.{unit_param}] Invalid value({value}), using default value {cfg.value}."
|
||||
)
|
||||
|
||||
def get_actual_preprocessor(self) -> Preprocessor:
|
||||
if self.module == "ip-adapter-auto":
|
||||
return Preprocessor.get_preprocessor(self.module).get_preprocessor_by_model(
|
||||
self.model
|
||||
)
|
||||
return Preprocessor.get_preprocessor(self.module)
|
||||
|
||||
|
||||
def to_base64_nparray(encoding: str):
|
||||
def to_base64_nparray(encoding: str) -> np.ndarray:
|
||||
"""
|
||||
Convert a base64 image into the image type the extension uses
|
||||
"""
|
||||
|
|
@ -368,73 +220,14 @@ def get_max_models_num():
|
|||
return max_models_num
|
||||
|
||||
|
||||
def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit:
|
||||
def to_processing_unit(unit: Union[Dict, ControlNetUnit]) -> ControlNetUnit:
|
||||
"""
|
||||
Convert different types to processing unit.
|
||||
If `unit` is a dict, alternative keys are supported. See `ext_compat_keys` in implementation for details.
|
||||
"""
|
||||
|
||||
ext_compat_keys = {
|
||||
"guessmode": "guess_mode",
|
||||
"guidance": "guidance_end",
|
||||
"lowvram": "low_vram",
|
||||
"input_image": "image",
|
||||
}
|
||||
|
||||
if isinstance(unit, dict):
|
||||
unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()}
|
||||
return ControlNetUnit.from_dict(unit)
|
||||
|
||||
# Handle mask
|
||||
mask = None
|
||||
if "mask" in unit:
|
||||
mask = unit["mask"]
|
||||
del unit["mask"]
|
||||
|
||||
if "mask_image" in unit:
|
||||
mask = unit["mask_image"]
|
||||
del unit["mask_image"]
|
||||
|
||||
if "image" in unit and not isinstance(unit["image"], dict):
|
||||
unit["image"] = (
|
||||
{"image": unit["image"], "mask": mask}
|
||||
if mask is not None
|
||||
else unit["image"] if unit["image"] else None
|
||||
)
|
||||
|
||||
# Parse ipadapter_input
|
||||
if "ipadapter_input" in unit and unit["ipadapter_input"] is not None:
|
||||
|
||||
def decode_base64(b: str) -> torch.Tensor:
|
||||
decoded_bytes = base64.b64decode(b)
|
||||
return unsafe_torch_load(io.BytesIO(decoded_bytes))
|
||||
|
||||
if isinstance(unit["ipadapter_input"], str):
|
||||
unit["ipadapter_input"] = [unit["ipadapter_input"]]
|
||||
|
||||
unit["ipadapter_input"] = [
|
||||
decode_base64(b) for b in unit["ipadapter_input"]
|
||||
]
|
||||
|
||||
if unit.get("effective_region_mask", None) is not None:
|
||||
base64img = unit["effective_region_mask"]
|
||||
assert isinstance(base64img, str)
|
||||
unit["effective_region_mask"] = to_base64_nparray(base64img)
|
||||
|
||||
if "guess_mode" in unit:
|
||||
logger.warning(
|
||||
"Guess Mode is removed since 1.1.136. Please use Control Mode instead."
|
||||
)
|
||||
|
||||
for k in unit.keys():
|
||||
if k not in vars(ControlNetUnit):
|
||||
logger.warn(f"Received unrecognized key '{k}' in API.")
|
||||
|
||||
unit = ControlNetUnit(
|
||||
**{k: v for k, v in unit.items() if k in vars(ControlNetUnit).keys()}
|
||||
)
|
||||
|
||||
# temporary, check #602
|
||||
# assert isinstance(unit, ControlNetUnit), f'bad argument to controlnet extension: {unit}\nexpected Union[dict[str, Any], ControlNetUnit]'
|
||||
assert isinstance(unit, ControlNetUnit)
|
||||
return unit
|
||||
|
||||
|
||||
|
|
@ -621,3 +414,23 @@ def is_cn_script(script: scripts.Script) -> bool:
|
|||
"""
|
||||
|
||||
return script.title().lower() == "controlnet"
|
||||
|
||||
|
||||
# TODO: Add model constraint
# Wire up the ControlNetUnit class-level hooks now that the preprocessor
# registry and codec helpers are importable in this module.
ControlNetUnit.cls_match_model = lambda model: True
ControlNetUnit.cls_match_module = (
    lambda module: Preprocessor.get_preprocessor(module) is not None
)
ControlNetUnit.cls_get_preprocessor = Preprocessor.get_preprocessor
ControlNetUnit.cls_decode_base64 = to_base64_nparray


def decode_base64(b: str) -> torch.Tensor:
    """Decode a base64 payload into a torch.Tensor via unsafe_torch_load.

    NOTE(review): unsafe_torch_load deserializes arbitrary pickled data —
    only safe because the API boundary is trusted; confirm before reuse.
    """
    decoded_bytes = base64.b64decode(b)
    return unsafe_torch_load(io.BytesIO(decoded_bytes))


ControlNetUnit.cls_torch_load_base64 = decode_base64
ControlNetUnit.cls_logger = logger

logger.debug("ControlNetUnit initialized")
|
||||
|
|
|
|||
|
|
@ -85,7 +85,6 @@
|
|||
this.attachImageUploadListener();
|
||||
this.attachImageStateChangeObserver();
|
||||
this.attachA1111SendInfoObserver();
|
||||
this.attachPresetDropdownObserver();
|
||||
}
|
||||
|
||||
getTabNavButton() {
|
||||
|
|
@ -303,26 +302,6 @@
|
|||
});
|
||||
}
|
||||
}
|
||||
|
||||
attachPresetDropdownObserver() {
    // Watch the preset dropdown's subtree; when nodes are removed (a preset
    // was applied/cleared), refresh the tab's derived UI state.
    const presetDropDown = this.tab.querySelector('.cnet-preset-dropdown');

    new MutationObserver((mutationsList) => {
        for (const mutation of mutationsList) {
            if (mutation.removedNodes.length > 0) {
                // Delay 1s so the framework finishes re-rendering before we
                // read the new state; then stop scanning this mutation batch.
                setTimeout(() => {
                    this.updateActiveState();
                    this.updateActiveUnitCount();
                    this.updateActiveControlType();
                }, 1000);
                return;
            }
        }
    }).observe(presetDropDown, {
        childList: true,
        subtree: true,
    });
}
|
||||
}
|
||||
|
||||
gradioApp().querySelectorAll('#controlnet').forEach(accordion => {
|
||||
|
|
|
|||
|
|
@ -137,12 +137,12 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
|
|||
)
|
||||
|
||||
unit = ControlNetUnit(
|
||||
enabled=True,
|
||||
module=preprocessor.label,
|
||||
processor_res=controlnet_processor_res,
|
||||
threshold_a=controlnet_threshold_a,
|
||||
threshold_b=controlnet_threshold_b,
|
||||
)
|
||||
unit.bound_check_params()
|
||||
|
||||
tensors = []
|
||||
images = []
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
import os
|
||||
from copy import copy
|
||||
from typing import Tuple, List
|
||||
|
||||
from modules import img2img, processing, shared, script_callbacks
|
||||
from scripts import external_code
|
||||
from scripts.enums import InputMode
|
||||
from scripts.logging import logger
|
||||
|
||||
class BatchHijack:
|
||||
def __init__(self):
|
||||
|
|
@ -194,7 +194,7 @@ def unhijack_function(module, name, new_name):
|
|||
|
||||
def get_cn_batches(p: processing.StableDiffusionProcessing) -> Tuple[bool, List[List[str]], str, List[str]]:
|
||||
units = external_code.get_all_units_in_processing(p)
|
||||
units = [copy(unit) for unit in units if getattr(unit, 'enabled', False)]
|
||||
units = [unit.copy() for unit in units if getattr(unit, 'enabled', False)]
|
||||
any_unit_is_batch = False
|
||||
output_dir = ''
|
||||
input_file_names = []
|
||||
|
|
@ -222,6 +222,8 @@ def get_cn_batches(p: processing.StableDiffusionProcessing) -> Tuple[bool, List[
|
|||
else:
|
||||
batches[i].append(unit.image)
|
||||
|
||||
if any_unit_is_batch:
|
||||
logger.info(f"Batch enabled ({len(batches)})")
|
||||
return any_unit_is_batch, batches, output_dir, input_file_names
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,10 +4,9 @@ import os
|
|||
import logging
|
||||
from collections import OrderedDict
|
||||
from copy import copy, deepcopy
|
||||
from typing import Dict, Optional, Tuple, List, Union
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
import modules.scripts as scripts
|
||||
from modules import shared, devices, script_callbacks, processing, masking, images
|
||||
from modules.api.api import decode_base64_to_image
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
|
|
@ -26,6 +25,7 @@ from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent
|
|||
from scripts.hook import ControlParams, UnetHook, HackedImageRNG
|
||||
from scripts.enums import (
|
||||
ControlModelType,
|
||||
InputMode,
|
||||
StableDiffusionVersion,
|
||||
HiResFixOption,
|
||||
PuLIDMode,
|
||||
|
|
@ -33,7 +33,7 @@ from scripts.enums import (
|
|||
BatchOption,
|
||||
ResizeMode,
|
||||
)
|
||||
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
|
||||
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
|
||||
from scripts.controlnet_ui.photopea import Photopea
|
||||
from scripts.logging import logger
|
||||
from scripts.supported_preprocessor import Preprocessor
|
||||
|
|
@ -101,44 +101,6 @@ def swap_img2img_pipeline(p: processing.StableDiffusionProcessingImg2Img):
|
|||
global_state.update_cn_models()
|
||||
|
||||
|
||||
def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
if isinstance(image, (tuple, list)):
|
||||
image = {'image': image[0], 'mask': image[1]}
|
||||
elif not isinstance(image, dict):
|
||||
image = {'image': image, 'mask': None}
|
||||
else: # type(image) is dict
|
||||
# copy to enable modifying the dict and prevent response serialization error
|
||||
image = dict(image)
|
||||
|
||||
if isinstance(image['image'], str):
|
||||
if os.path.exists(image['image']):
|
||||
image['image'] = np.array(Image.open(image['image'])).astype('uint8')
|
||||
elif image['image']:
|
||||
image['image'] = external_code.to_base64_nparray(image['image'])
|
||||
else:
|
||||
image['image'] = None
|
||||
|
||||
# If there is no image, return image with None image and None mask
|
||||
if image['image'] is None:
|
||||
image['mask'] = None
|
||||
return image
|
||||
|
||||
if 'mask' not in image or image['mask'] is None:
|
||||
image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
|
||||
elif isinstance(image['mask'], str):
|
||||
if os.path.exists(image['mask']):
|
||||
image['mask'] = np.array(Image.open(image['mask']).convert("RGB")).astype('uint8')
|
||||
elif image['mask']:
|
||||
image['mask'] = external_code.to_base64_nparray(image['mask'])
|
||||
else:
|
||||
image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def prepare_mask(
|
||||
mask: Image.Image, p: processing.StableDiffusionProcessing
|
||||
) -> Image.Image:
|
||||
|
|
@ -242,7 +204,7 @@ def get_control(
|
|||
h, w, hr_y, hr_x = Script.get_target_dimensions(p)
|
||||
input_image, resize_mode = Script.choose_input_image(p, unit, idx)
|
||||
if isinstance(input_image, list):
|
||||
assert unit.accepts_multiple_inputs() or unit.is_animate_diff_batch
|
||||
assert unit.accepts_multiple_inputs or unit.is_animate_diff_batch
|
||||
input_images = input_image
|
||||
else: # Following operations are only for single input image.
|
||||
input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode)
|
||||
|
|
@ -355,21 +317,8 @@ class Script(scripts.Script, metaclass=(
|
|||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
@staticmethod
|
||||
def get_default_ui_unit(is_ui=True):
|
||||
cls = UiControlNetUnit if is_ui else ControlNetUnit
|
||||
return cls(
|
||||
enabled=False,
|
||||
module="none",
|
||||
model="None"
|
||||
)
|
||||
|
||||
def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str, photopea: Optional[Photopea]) -> Tuple[ControlNetUiGroup, gr.State]:
|
||||
group = ControlNetUiGroup(
|
||||
is_img2img,
|
||||
Script.get_default_ui_unit(),
|
||||
photopea,
|
||||
)
|
||||
group = ControlNetUiGroup(is_img2img, photopea)
|
||||
return group, group.render(tabname, elem_id_tabname)
|
||||
|
||||
def ui_batch_options(self, is_img2img: bool, elem_id_tabname: str):
|
||||
|
|
@ -665,10 +614,31 @@ class Script(scripts.Script, metaclass=(
|
|||
|
||||
@staticmethod
|
||||
def get_enabled_units(p):
|
||||
def unfold_merged(unit: ControlNetUnit) -> List[ControlNetUnit]:
|
||||
"""Unfolds a merged unit to multiple units. Keeps the unit merged for
|
||||
preprocessors that can accept multiple input images.
|
||||
"""
|
||||
if unit.input_mode != InputMode.MERGE:
|
||||
return [unit]
|
||||
|
||||
if unit.accepts_multiple_inputs:
|
||||
unit.input_mode = InputMode.SIMPLE
|
||||
return [unit]
|
||||
|
||||
assert isinstance(unit.image, list)
|
||||
result = []
|
||||
for image in unit.image:
|
||||
u = unit.copy()
|
||||
u.image = [image]
|
||||
u.input_mode = InputMode.SIMPLE
|
||||
u.weight = unit.weight / len(unit.image)
|
||||
result.append(u)
|
||||
return result
|
||||
|
||||
units = external_code.get_all_units_in_processing(p)
|
||||
if len(units) == 0:
|
||||
# fill a null group
|
||||
remote_unit = Script.parse_remote_call(p, Script.get_default_ui_unit(), 0)
|
||||
remote_unit = Script.parse_remote_call(p, ControlNetUnit(), 0)
|
||||
if remote_unit.enabled:
|
||||
units.append(remote_unit)
|
||||
|
||||
|
|
@ -677,11 +647,7 @@ class Script(scripts.Script, metaclass=(
|
|||
local_unit = Script.parse_remote_call(p, unit, idx)
|
||||
if not local_unit.enabled:
|
||||
continue
|
||||
|
||||
if hasattr(local_unit, "unfold_merged"):
|
||||
enabled_units.extend(local_unit.unfold_merged())
|
||||
else:
|
||||
enabled_units.append(copy(local_unit))
|
||||
enabled_units.extend(unfold_merged(local_unit))
|
||||
|
||||
Infotext.write_infotext(enabled_units, p)
|
||||
return enabled_units
|
||||
|
|
@ -695,43 +661,31 @@ class Script(scripts.Script, metaclass=(
|
|||
""" Choose input image from following sources with descending priority:
|
||||
- p.image_control: [Deprecated] Lagacy way to pass image to controlnet.
|
||||
- p.control_net_input_image: [Deprecated] Lagacy way to pass image to controlnet.
|
||||
- unit.image: ControlNet tab input image.
|
||||
- p.init_images: A1111 img2img tab input image.
|
||||
- unit.image: ControlNet unit input image.
|
||||
- p.init_images: A1111 img2img input image.
|
||||
|
||||
Returns:
|
||||
- The input image in ndarray form.
|
||||
- The resize mode.
|
||||
"""
|
||||
def parse_unit_image(unit: ControlNetUnit) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
|
||||
unit_has_multiple_images = (
|
||||
isinstance(unit.image, list) and
|
||||
len(unit.image) > 0 and
|
||||
"image" in unit.image[0]
|
||||
)
|
||||
if unit_has_multiple_images:
|
||||
return [
|
||||
d
|
||||
for img in unit.image
|
||||
for d in (image_dict_from_any(img),)
|
||||
if d is not None
|
||||
]
|
||||
return image_dict_from_any(unit.image)
|
||||
|
||||
def decode_image(img) -> np.ndarray:
|
||||
"""Need to check the image for API compatibility."""
|
||||
if isinstance(img, str):
|
||||
return np.asarray(decode_base64_to_image(image['image']))
|
||||
else:
|
||||
assert isinstance(img, np.ndarray)
|
||||
return img
|
||||
def from_rgba_to_input(img: np.ndarray) -> np.ndarray:
|
||||
if (
|
||||
shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) or
|
||||
(img[:, :, 3] <= 5).all() or
|
||||
(img[:, :, 3] >= 250).all()
|
||||
):
|
||||
# Take RGB
|
||||
return img[:, :, :3]
|
||||
logger.info("Canvas scribble mode. Using mask scribble as input.")
|
||||
return HWC3(img[:, :, 3])
|
||||
|
||||
# 4 input image sources.
|
||||
p_image_control = getattr(p, "image_control", None)
|
||||
p_input_image = Script.get_remote_call(p, "control_net_input_image", None, idx)
|
||||
image = parse_unit_image(unit)
|
||||
image = unit.get_input_images_rgba()
|
||||
a1111_image = getattr(p, "init_images", [None])[0]
|
||||
|
||||
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
|
||||
resize_mode = unit.resize_mode
|
||||
|
||||
if batch_hijack.instance.is_batch and p_image_control is not None:
|
||||
logger.warning("Warn: Using legacy field 'p.image_control'.")
|
||||
|
|
@ -744,42 +698,18 @@ class Script(scripts.Script, metaclass=(
|
|||
input_image = np.concatenate([color, alpha], axis=2)
|
||||
else:
|
||||
input_image = HWC3(np.asarray(p_input_image))
|
||||
elif image:
|
||||
if isinstance(image, list):
|
||||
# Add mask logic if later there is a processor that accepts mask
|
||||
# on multiple inputs.
|
||||
input_image = [HWC3(decode_image(img['image'])) for img in image]
|
||||
if unit.is_animate_diff_batch and len(image) > 0 and 'mask' in image[0] and image[0]['mask'] is not None:
|
||||
for idx in range(len(input_image)):
|
||||
while len(image[idx]['mask'].shape) < 3:
|
||||
image[idx]['mask'] = image[idx]['mask'][..., np.newaxis]
|
||||
if unit.is_inpaint or unit.uses_clip:
|
||||
color = HWC3(image[idx]["image"])
|
||||
alpha = image[idx]['mask'][:, :, 0:1]
|
||||
input_image[idx] = np.concatenate([color, alpha], axis=2)
|
||||
elif image is not None:
|
||||
assert isinstance(image, list)
|
||||
# Inpaint mask or CLIP mask.
|
||||
if unit.is_inpaint or unit.uses_clip:
|
||||
# RGBA
|
||||
input_image = image
|
||||
else:
|
||||
input_image = HWC3(decode_image(image['image']))
|
||||
if 'mask' in image and image['mask'] is not None:
|
||||
while len(image['mask'].shape) < 3:
|
||||
image['mask'] = image['mask'][..., np.newaxis]
|
||||
if unit.is_inpaint or unit.uses_clip:
|
||||
logger.info("using mask")
|
||||
color = HWC3(image['image'])
|
||||
alpha = image['mask'][:, :, 0:1]
|
||||
input_image = np.concatenate([color, alpha], axis=2)
|
||||
elif (
|
||||
not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and
|
||||
# There is wield gradio issue that would produce mask that is
|
||||
# not pure color when no scribble is made on canvas.
|
||||
# See https://github.com/Mikubill/sd-webui-controlnet/issues/1638.
|
||||
not (
|
||||
(image['mask'][:, :, 0] <= 5).all() or
|
||||
(image['mask'][:, :, 0] >= 250).all()
|
||||
)
|
||||
):
|
||||
logger.info("using mask as input")
|
||||
input_image = HWC3(image['mask'][:, :, 0])
|
||||
unit.module = 'none' # Always use black bg and white line
|
||||
# RGB
|
||||
input_image = [from_rgba_to_input(img) for img in image]
|
||||
|
||||
if len(input_image) == 1:
|
||||
input_image = input_image[0]
|
||||
elif a1111_image is not None:
|
||||
input_image = HWC3(np.asarray(a1111_image))
|
||||
a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
|
||||
|
|
@ -957,7 +887,6 @@ class Script(scripts.Script, metaclass=(
|
|||
high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)
|
||||
|
||||
for idx, unit in enumerate(self.enabled_units):
|
||||
unit.bound_check_params()
|
||||
Script.check_sd_version_compatible(unit)
|
||||
if (
|
||||
'inpaint_only' == unit.module and
|
||||
|
|
@ -1011,7 +940,7 @@ class Script(scripts.Script, metaclass=(
|
|||
elif unit.is_animate_diff_batch or control_model_type in [ControlModelType.SparseCtrl]:
|
||||
cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None)
|
||||
def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx):
|
||||
if unit.accepts_multiple_inputs():
|
||||
if unit.accepts_multiple_inputs:
|
||||
ip_adapter_image_emb_cond = []
|
||||
model_net.ipadapter.image_proj_model.to(torch.float32) # noqa
|
||||
for c in cc:
|
||||
|
|
@ -1038,7 +967,7 @@ class Script(scripts.Script, metaclass=(
|
|||
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):
|
||||
logger.info(f"\t{frame_idx}: {frame_path}")
|
||||
c = SparseCtrl.create_cond_mask(cn_ad_keyframe_idx, c, p.batch_size).cpu()
|
||||
elif unit.accepts_multiple_inputs():
|
||||
elif unit.accepts_multiple_inputs:
|
||||
# ip-adapter should do prompt travel
|
||||
logger.info("IP-Adapter: control prompts will be traveled in the following way:")
|
||||
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):
|
||||
|
|
@ -1067,7 +996,7 @@ class Script(scripts.Script, metaclass=(
|
|||
c_full[cn_ad_keyframe_idx] = c
|
||||
c = c_full
|
||||
# handle batch condition and unconditional
|
||||
if shared.opts.batch_cond_uncond and not unit.accepts_multiple_inputs():
|
||||
if shared.opts.batch_cond_uncond and not unit.accepts_multiple_inputs:
|
||||
c = torch.cat([c, c], dim=0)
|
||||
return c
|
||||
|
||||
|
|
@ -1090,7 +1019,6 @@ class Script(scripts.Script, metaclass=(
|
|||
control_model_type.is_controlnet and
|
||||
model_net.control_model.global_average_pooling
|
||||
)
|
||||
control_mode = external_code.control_mode_from_value(unit.control_mode)
|
||||
forward_param = ControlParams(
|
||||
control_model=model_net,
|
||||
preprocessor=preprocessor_dict,
|
||||
|
|
@ -1103,9 +1031,9 @@ class Script(scripts.Script, metaclass=(
|
|||
control_model_type=control_model_type,
|
||||
global_average_pooling=global_average_pooling,
|
||||
hr_hint_cond=hr_control,
|
||||
hr_option=HiResFixOption.from_value(unit.hr_option) if high_res_fix else HiResFixOption.BOTH,
|
||||
soft_injection=control_mode != ControlMode.BALANCED,
|
||||
cfg_injection=control_mode == ControlMode.CONTROL,
|
||||
hr_option=unit.hr_option if high_res_fix else HiResFixOption.BOTH,
|
||||
soft_injection=unit.control_mode != ControlMode.BALANCED,
|
||||
cfg_injection=unit.control_mode == ControlMode.CONTROL,
|
||||
effective_region_mask=(
|
||||
get_pytorch_control(unit.effective_region_mask)[:, 0:1, :, :]
|
||||
if unit.effective_region_mask is not None
|
||||
|
|
@ -1217,16 +1145,10 @@ class Script(scripts.Script, metaclass=(
|
|||
weight = param.weight
|
||||
|
||||
h, w, hr_y, hr_x = Script.get_target_dimensions(p)
|
||||
# TODO: Fix all enum issue
|
||||
if unit.pulid_mode == "PuLIDMode.FIDELITY":
|
||||
pulid_mode = PuLIDMode.FIDELITY
|
||||
else:
|
||||
pulid_mode = PuLIDMode(unit.pulid_mode)
|
||||
|
||||
if pulid_mode == PuLIDMode.STYLE:
|
||||
if unit.pulid_mode == PuLIDMode.STYLE:
|
||||
pulid_attn_setting = PULID_SETTING_STYLE
|
||||
else:
|
||||
assert pulid_mode == PuLIDMode.FIDELITY
|
||||
assert unit.pulid_mode == PuLIDMode.FIDELITY
|
||||
pulid_attn_setting = PULID_SETTING_FIDELITY
|
||||
|
||||
param.control_model.hook(
|
||||
|
|
@ -1377,7 +1299,7 @@ class Script(scripts.Script, metaclass=(
|
|||
unit.batch_images = iter([batch[unit_i] for batch in batches])
|
||||
|
||||
def batch_tab_process_each(self, p, *args, **kwargs):
|
||||
for unit_i, unit in enumerate(self.enabled_units):
|
||||
for unit in self.enabled_units:
|
||||
if getattr(unit, 'loopback', False) and batch_hijack.instance.batch_index > 0:
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import json
|
||||
import gradio as gr
|
||||
import functools
|
||||
from copy import copy
|
||||
from typing import List, Optional, Union, Dict, Tuple, Literal
|
||||
import itertools
|
||||
from typing import List, Optional, Union, Dict, Tuple, Literal, Any
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -16,7 +16,6 @@ from annotator.util import HWC3
|
|||
from internal_controlnet.external_code import ControlNetUnit
|
||||
from scripts.logging import logger
|
||||
from scripts.controlnet_ui.openpose_editor import OpenposeEditor
|
||||
from scripts.controlnet_ui.preset import ControlNetPresetUI
|
||||
from scripts.controlnet_ui.photopea import Photopea
|
||||
from scripts.controlnet_ui.advanced_weight_control import AdvancedWeightControl
|
||||
from scripts.enums import (
|
||||
|
|
@ -128,66 +127,39 @@ class A1111Context:
|
|||
)
|
||||
|
||||
|
||||
class UiControlNetUnit(ControlNetUnit):
|
||||
"""The data class that stores all states of a ControlNetUnit."""
|
||||
def create_ui_unit(
|
||||
input_mode: InputMode = InputMode.SIMPLE,
|
||||
batch_images: Optional[Any] = None,
|
||||
output_dir: str = "",
|
||||
loopback: bool = False,
|
||||
merge_gallery_files: List[Dict[Union[Literal["name"], Literal["data"]], str]] = [],
|
||||
use_preview_as_input: bool = False,
|
||||
generated_image: Optional[np.ndarray] = None,
|
||||
*args,
|
||||
) -> ControlNetUnit:
|
||||
unit_dict = {
|
||||
k: v
|
||||
for k, v in zip(
|
||||
vars(ControlNetUnit()).keys(),
|
||||
itertools.chain(
|
||||
[True, input_mode, batch_images, output_dir, loopback], args
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_mode: InputMode = InputMode.SIMPLE,
|
||||
batch_images: Optional[Union[str, List[external_code.InputImage]]] = None,
|
||||
output_dir: str = "",
|
||||
loopback: bool = False,
|
||||
merge_gallery_files: List[
|
||||
Dict[Union[Literal["name"], Literal["data"]], str]
|
||||
] = [],
|
||||
use_preview_as_input: bool = False,
|
||||
generated_image: Optional[np.ndarray] = None,
|
||||
enabled: bool = True,
|
||||
module: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
weight: float = 1.0,
|
||||
image: Optional[Dict[str, np.ndarray]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if use_preview_as_input and generated_image is not None:
|
||||
input_image = generated_image
|
||||
module = "none"
|
||||
else:
|
||||
input_image = image
|
||||
if use_preview_as_input and generated_image is not None:
|
||||
input_image = generated_image
|
||||
unit_dict["module"] = "none"
|
||||
else:
|
||||
input_image = unit_dict["image"]
|
||||
|
||||
if merge_gallery_files and input_mode == InputMode.MERGE:
|
||||
input_image = [
|
||||
{"image": read_image(file["name"])} for file in merge_gallery_files
|
||||
]
|
||||
if merge_gallery_files and input_mode == InputMode.MERGE:
|
||||
input_image = [
|
||||
{"image": read_image(file["name"])} for file in merge_gallery_files
|
||||
]
|
||||
|
||||
super().__init__(enabled, module, model, weight, input_image, *args, **kwargs)
|
||||
self.is_ui = True
|
||||
self.input_mode = input_mode
|
||||
self.batch_images = batch_images
|
||||
self.output_dir = output_dir
|
||||
self.loopback = loopback
|
||||
|
||||
def unfold_merged(self) -> List[ControlNetUnit]:
|
||||
"""Unfolds a merged unit to multiple units. Keeps the unit merged for
|
||||
preprocessors that can accept multiple input images.
|
||||
"""
|
||||
if self.input_mode != InputMode.MERGE:
|
||||
return [copy(self)]
|
||||
|
||||
if self.accepts_multiple_inputs():
|
||||
self.input_mode = InputMode.SIMPLE
|
||||
return [copy(self)]
|
||||
|
||||
assert isinstance(self.image, list)
|
||||
result = []
|
||||
for image in self.image:
|
||||
unit = copy(self)
|
||||
unit.image = image["image"]
|
||||
unit.input_mode = InputMode.SIMPLE
|
||||
unit.weight = self.weight / len(self.image)
|
||||
result.append(unit)
|
||||
return result
|
||||
unit_dict["image"] = input_image
|
||||
return ControlNetUnit.from_dict(unit_dict)
|
||||
|
||||
|
||||
class ControlNetUiGroup(object):
|
||||
|
|
@ -221,7 +193,6 @@ class ControlNetUiGroup(object):
|
|||
def __init__(
|
||||
self,
|
||||
is_img2img: bool,
|
||||
default_unit: ControlNetUnit,
|
||||
photopea: Optional[Photopea],
|
||||
):
|
||||
# Whether callbacks have been registered.
|
||||
|
|
@ -230,13 +201,13 @@ class ControlNetUiGroup(object):
|
|||
self.ui_initialized: bool = False
|
||||
|
||||
self.is_img2img = is_img2img
|
||||
self.default_unit = default_unit
|
||||
self.default_unit = ControlNetUnit()
|
||||
self.photopea = photopea
|
||||
self.webcam_enabled = False
|
||||
self.webcam_mirrored = False
|
||||
|
||||
# Note: All gradio elements declared in `render` will be defined as member variable.
|
||||
# Update counter to trigger a force update of UiControlNetUnit.
|
||||
# Update counter to trigger a force update of ControlNetUnit.
|
||||
# This is useful when a field with no event subscriber available changes.
|
||||
# e.g. gr.Gallery, gr.State, etc.
|
||||
self.update_unit_counter = None
|
||||
|
|
@ -283,7 +254,6 @@ class ControlNetUiGroup(object):
|
|||
self.loopback = None
|
||||
self.use_preview_as_input = None
|
||||
self.openpose_editor = None
|
||||
self.preset_panel = None
|
||||
self.upload_independent_img_in_img2img = None
|
||||
self.image_upload_panel = None
|
||||
self.save_detected_map = None
|
||||
|
|
@ -330,11 +300,13 @@ class ControlNetUiGroup(object):
|
|||
tool="sketch",
|
||||
elem_id=f"{elem_id_tabname}_{tabname}_input_image",
|
||||
elem_classes=["cnet-image"],
|
||||
brush_color=shared.opts.img2img_inpaint_mask_brush_color
|
||||
if hasattr(
|
||||
shared.opts, "img2img_inpaint_mask_brush_color"
|
||||
)
|
||||
else None,
|
||||
brush_color=(
|
||||
shared.opts.img2img_inpaint_mask_brush_color
|
||||
if hasattr(
|
||||
shared.opts, "img2img_inpaint_mask_brush_color"
|
||||
)
|
||||
else None
|
||||
),
|
||||
)
|
||||
self.image.preprocess = functools.partial(
|
||||
svg_preprocess, preprocess=self.image.preprocess
|
||||
|
|
@ -515,7 +487,11 @@ class ControlNetUiGroup(object):
|
|||
)
|
||||
|
||||
with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
|
||||
self.type_filter = (gr.Dropdown if shared.opts.data.get("controlnet_control_type_dropdown", False) else gr.Radio)(
|
||||
self.type_filter = (
|
||||
gr.Dropdown
|
||||
if shared.opts.data.get("controlnet_control_type_dropdown", False)
|
||||
else gr.Radio
|
||||
)(
|
||||
Preprocessor.get_all_preprocessor_tags(),
|
||||
label="Control Type",
|
||||
value="All",
|
||||
|
|
@ -645,7 +621,7 @@ class ControlNetUiGroup(object):
|
|||
|
||||
self.loopback = gr.Checkbox(
|
||||
label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation",
|
||||
value=self.default_unit.loopback,
|
||||
value=False,
|
||||
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_automatically_send_generated_images_checkbox",
|
||||
elem_classes="controlnet_loopback_checkbox",
|
||||
visible=False,
|
||||
|
|
@ -653,10 +629,6 @@ class ControlNetUiGroup(object):
|
|||
|
||||
self.advanced_weight_control.render()
|
||||
|
||||
self.preset_panel = ControlNetPresetUI(
|
||||
id_prefix=f"{elem_id_tabname}_{tabname}_"
|
||||
)
|
||||
|
||||
self.batch_image_dir_state = gr.State("")
|
||||
self.output_dir_state = gr.State("")
|
||||
unit_args = (
|
||||
|
|
@ -693,32 +665,13 @@ class ControlNetUiGroup(object):
|
|||
self.pulid_mode,
|
||||
)
|
||||
|
||||
unit = gr.State(self.default_unit)
|
||||
for comp in unit_args + (self.update_unit_counter,):
|
||||
event_subscribers = []
|
||||
if hasattr(comp, "edit"):
|
||||
event_subscribers.append(comp.edit)
|
||||
elif hasattr(comp, "click"):
|
||||
event_subscribers.append(comp.click)
|
||||
elif isinstance(comp, gr.Slider) and hasattr(comp, "release"):
|
||||
event_subscribers.append(comp.release)
|
||||
elif hasattr(comp, "change"):
|
||||
event_subscribers.append(comp.change)
|
||||
|
||||
if hasattr(comp, "clear"):
|
||||
event_subscribers.append(comp.clear)
|
||||
|
||||
for event_subscriber in event_subscribers:
|
||||
event_subscriber(
|
||||
fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit
|
||||
)
|
||||
|
||||
unit = gr.State(ControlNetUnit())
|
||||
(
|
||||
ControlNetUiGroup.a1111_context.img2img_submit_button
|
||||
if self.is_img2img
|
||||
else ControlNetUiGroup.a1111_context.txt2img_submit_button
|
||||
).click(
|
||||
fn=UiControlNetUnit,
|
||||
fn=create_ui_unit,
|
||||
inputs=list(unit_args),
|
||||
outputs=unit,
|
||||
queue=False,
|
||||
|
|
@ -803,10 +756,12 @@ class ControlNetUiGroup(object):
|
|||
def register_build_sliders(self):
|
||||
def build_sliders(module: str, pp: bool):
|
||||
preprocessor = Preprocessor.get_preprocessor(module)
|
||||
slider_resolution_kwargs = preprocessor.slider_resolution.gradio_update_kwargs.copy()
|
||||
slider_resolution_kwargs = (
|
||||
preprocessor.slider_resolution.gradio_update_kwargs.copy()
|
||||
)
|
||||
|
||||
if pp:
|
||||
slider_resolution_kwargs['visible'] = False
|
||||
slider_resolution_kwargs["visible"] = False
|
||||
|
||||
grs = [
|
||||
gr.update(**slider_resolution_kwargs),
|
||||
|
|
@ -852,9 +807,7 @@ class ControlNetUiGroup(object):
|
|||
gr.Dropdown.update(
|
||||
value=default_option, choices=filtered_preprocessor_list
|
||||
),
|
||||
gr.Dropdown.update(
|
||||
value=default_model, choices=filtered_model_list
|
||||
),
|
||||
gr.Dropdown.update(value=default_model, choices=filtered_model_list),
|
||||
]
|
||||
|
||||
self.type_filter.change(
|
||||
|
|
@ -893,7 +846,9 @@ class ControlNetUiGroup(object):
|
|||
)
|
||||
|
||||
def register_run_annotator(self):
|
||||
def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm, model: str):
|
||||
def run_annotator(
|
||||
image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm, model: str
|
||||
):
|
||||
if image is None:
|
||||
return (
|
||||
gr.update(value=None, visible=True),
|
||||
|
|
@ -956,16 +911,16 @@ class ControlNetUiGroup(object):
|
|||
and shared.opts.data.get("controlnet_clip_detector_on_cpu", False)
|
||||
),
|
||||
json_pose_callback=(
|
||||
json_acceptor.accept
|
||||
if is_openpose(module)
|
||||
else None
|
||||
json_acceptor.accept if is_openpose(module) else None
|
||||
),
|
||||
model=model,
|
||||
)
|
||||
|
||||
return (
|
||||
# Update to `generated_image`
|
||||
gr.update(value=result.display_images[0], visible=True, interactive=False),
|
||||
gr.update(
|
||||
value=result.display_images[0], visible=True, interactive=False
|
||||
),
|
||||
# preprocessor_preview
|
||||
gr.update(value=True),
|
||||
# openpose editor
|
||||
|
|
@ -980,12 +935,16 @@ class ControlNetUiGroup(object):
|
|||
self.processor_res,
|
||||
self.threshold_a,
|
||||
self.threshold_b,
|
||||
ControlNetUiGroup.a1111_context.img2img_w_slider
|
||||
if self.is_img2img
|
||||
else ControlNetUiGroup.a1111_context.txt2img_w_slider,
|
||||
ControlNetUiGroup.a1111_context.img2img_h_slider
|
||||
if self.is_img2img
|
||||
else ControlNetUiGroup.a1111_context.txt2img_h_slider,
|
||||
(
|
||||
ControlNetUiGroup.a1111_context.img2img_w_slider
|
||||
if self.is_img2img
|
||||
else ControlNetUiGroup.a1111_context.txt2img_w_slider
|
||||
),
|
||||
(
|
||||
ControlNetUiGroup.a1111_context.img2img_h_slider
|
||||
if self.is_img2img
|
||||
else ControlNetUiGroup.a1111_context.txt2img_h_slider
|
||||
),
|
||||
self.pixel_perfect,
|
||||
self.resize_mode,
|
||||
self.model,
|
||||
|
|
@ -1256,14 +1215,6 @@ class ControlNetUiGroup(object):
|
|||
self.model,
|
||||
)
|
||||
assert self.type_filter is not None
|
||||
self.preset_panel.register_callbacks(
|
||||
self,
|
||||
self.type_filter,
|
||||
*[
|
||||
getattr(self, key)
|
||||
for key in vars(ControlNetUnit()).keys()
|
||||
],
|
||||
)
|
||||
self.advanced_weight_control.register_callbacks(
|
||||
self.weight,
|
||||
self.advanced_weighting,
|
||||
|
|
|
|||
|
|
@ -1,305 +0,0 @@
|
|||
import os
|
||||
import gradio as gr
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from modules import scripts
|
||||
from modules.ui_components import ToolButton
|
||||
from internal_controlnet.external_code import ControlNetUnit
|
||||
from scripts.infotext import parse_unit, serialize_unit
|
||||
from scripts.logging import logger
|
||||
from scripts.supported_preprocessor import Preprocessor
|
||||
|
||||
save_symbol = "\U0001f4be" # 💾
|
||||
delete_symbol = "\U0001f5d1\ufe0f" # 🗑️
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
reset_symbol = "\U000021A9" # ↩
|
||||
|
||||
NEW_PRESET = "New Preset"
|
||||
|
||||
|
||||
def load_presets(preset_dir: str) -> Dict[str, str]:
|
||||
if not os.path.exists(preset_dir):
|
||||
os.makedirs(preset_dir)
|
||||
return {}
|
||||
|
||||
presets = {}
|
||||
for filename in os.listdir(preset_dir):
|
||||
if filename.endswith(".txt"):
|
||||
with open(os.path.join(preset_dir, filename), "r") as f:
|
||||
name = filename.replace(".txt", "")
|
||||
if name == NEW_PRESET:
|
||||
continue
|
||||
presets[name] = f.read()
|
||||
return presets
|
||||
|
||||
|
||||
def infer_control_type(module: str, model: str) -> str:
|
||||
p = Preprocessor.get_preprocessor(module)
|
||||
assert p is not None
|
||||
matched_tags = [
|
||||
tag
|
||||
for tag in p.tags
|
||||
if any(f in model.lower() for f in Preprocessor.tag_to_filters(tag))
|
||||
]
|
||||
if len(matched_tags) != 1:
|
||||
raise ValueError(
|
||||
f"Unable to infer control type from module {module} and model {model}"
|
||||
)
|
||||
return matched_tags[0]
|
||||
|
||||
|
||||
class ControlNetPresetUI(object):
|
||||
preset_directory = os.path.join(scripts.basedir(), "presets")
|
||||
presets = load_presets(preset_directory)
|
||||
|
||||
def __init__(self, id_prefix: str):
|
||||
with gr.Row():
|
||||
self.dropdown = gr.Dropdown(
|
||||
label="Presets",
|
||||
show_label=True,
|
||||
elem_classes=["cnet-preset-dropdown"],
|
||||
choices=ControlNetPresetUI.dropdown_choices(),
|
||||
value=NEW_PRESET,
|
||||
)
|
||||
self.reset_button = ToolButton(
|
||||
value=reset_symbol,
|
||||
elem_classes=["cnet-preset-reset"],
|
||||
tooltip="Reset preset",
|
||||
visible=False,
|
||||
)
|
||||
self.save_button = ToolButton(
|
||||
value=save_symbol,
|
||||
elem_classes=["cnet-preset-save"],
|
||||
tooltip="Save preset",
|
||||
)
|
||||
self.delete_button = ToolButton(
|
||||
value=delete_symbol,
|
||||
elem_classes=["cnet-preset-delete"],
|
||||
tooltip="Delete preset",
|
||||
)
|
||||
self.refresh_button = ToolButton(
|
||||
value=refresh_symbol,
|
||||
elem_classes=["cnet-preset-refresh"],
|
||||
tooltip="Refresh preset",
|
||||
)
|
||||
|
||||
with gr.Box(
|
||||
elem_classes=["popup-dialog", "cnet-preset-enter-name"],
|
||||
elem_id=f"{id_prefix}_cnet_preset_enter_name",
|
||||
) as self.name_dialog:
|
||||
with gr.Row():
|
||||
self.preset_name = gr.Textbox(
|
||||
label="Preset name",
|
||||
show_label=True,
|
||||
lines=1,
|
||||
elem_classes=["cnet-preset-name"],
|
||||
)
|
||||
self.confirm_preset_name = ToolButton(
|
||||
value=save_symbol,
|
||||
elem_classes=["cnet-preset-confirm-name"],
|
||||
tooltip="Save preset",
|
||||
)
|
||||
|
||||
def register_callbacks(
|
||||
self,
|
||||
uigroup,
|
||||
control_type: gr.Radio,
|
||||
*ui_states,
|
||||
):
|
||||
def apply_preset(name: str, control_type: str, *ui_states):
|
||||
if name == NEW_PRESET:
|
||||
return (
|
||||
gr.update(visible=False),
|
||||
*(
|
||||
(gr.skip(),)
|
||||
* (len(vars(ControlNetUnit()).keys()) + 1)
|
||||
),
|
||||
)
|
||||
|
||||
assert name in ControlNetPresetUI.presets
|
||||
|
||||
infotext = ControlNetPresetUI.presets[name]
|
||||
preset_unit = parse_unit(infotext)
|
||||
current_unit = ControlNetUnit(*ui_states)
|
||||
preset_unit.image = None
|
||||
current_unit.image = None
|
||||
|
||||
# Do not compare module param that are not used in preset.
|
||||
for module_param in ("processor_res", "threshold_a", "threshold_b"):
|
||||
if getattr(preset_unit, module_param) == -1:
|
||||
setattr(current_unit, module_param, -1)
|
||||
|
||||
# No update necessary.
|
||||
if vars(current_unit) == vars(preset_unit):
|
||||
return (
|
||||
gr.update(visible=False),
|
||||
*(
|
||||
(gr.skip(),)
|
||||
* (len(vars(ControlNetUnit()).keys()) + 1)
|
||||
),
|
||||
)
|
||||
|
||||
unit = preset_unit
|
||||
|
||||
try:
|
||||
new_control_type = infer_control_type(unit.module, unit.model)
|
||||
except ValueError as e:
|
||||
logger.error(e)
|
||||
new_control_type = control_type
|
||||
|
||||
return (
|
||||
gr.update(visible=True),
|
||||
gr.update(value=new_control_type),
|
||||
*[
|
||||
gr.update(value=value) if value is not None else gr.update()
|
||||
for value in vars(unit).values()
|
||||
],
|
||||
)
|
||||
|
||||
for element, action in (
|
||||
(self.dropdown, "change"),
|
||||
(self.reset_button, "click"),
|
||||
):
|
||||
getattr(element, action)(
|
||||
fn=apply_preset,
|
||||
inputs=[self.dropdown, control_type, *ui_states],
|
||||
outputs=[self.delete_button, control_type, *ui_states],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
fn=lambda: gr.update(visible=False),
|
||||
inputs=None,
|
||||
outputs=[self.reset_button],
|
||||
)
|
||||
|
||||
def save_preset(name: str, *ui_states):
|
||||
if name == NEW_PRESET:
|
||||
return gr.update(visible=True), gr.update(), gr.update()
|
||||
|
||||
ControlNetPresetUI.save_preset(
|
||||
name, ControlNetUnit(*ui_states)
|
||||
)
|
||||
return (
|
||||
gr.update(), # name dialog
|
||||
gr.update(choices=ControlNetPresetUI.dropdown_choices(), value=name),
|
||||
gr.update(visible=False), # Reset button
|
||||
)
|
||||
|
||||
self.save_button.click(
|
||||
fn=save_preset,
|
||||
inputs=[self.dropdown, *ui_states],
|
||||
outputs=[self.name_dialog, self.dropdown, self.reset_button],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
fn=None,
|
||||
_js=f"""
|
||||
(name) => {{
|
||||
if (name === "{NEW_PRESET}")
|
||||
popup(gradioApp().getElementById('{self.name_dialog.elem_id}'));
|
||||
}}""",
|
||||
inputs=[self.dropdown],
|
||||
)
|
||||
|
||||
def delete_preset(name: str):
|
||||
ControlNetPresetUI.delete_preset(name)
|
||||
return gr.Dropdown.update(
|
||||
choices=ControlNetPresetUI.dropdown_choices(),
|
||||
value=NEW_PRESET,
|
||||
), gr.update(visible=False)
|
||||
|
||||
self.delete_button.click(
|
||||
fn=delete_preset,
|
||||
inputs=[self.dropdown],
|
||||
outputs=[self.dropdown, self.reset_button],
|
||||
show_progress="hidden",
|
||||
)
|
||||
|
||||
self.name_dialog.visible = False
|
||||
|
||||
def save_new_preset(new_name: str, *ui_states):
|
||||
if new_name == NEW_PRESET:
|
||||
logger.warn(f"Cannot save preset with reserved name '{NEW_PRESET}'")
|
||||
return gr.update(visible=False), gr.update()
|
||||
|
||||
ControlNetPresetUI.save_preset(
|
||||
new_name, ControlNetUnit(*ui_states)
|
||||
)
|
||||
return gr.update(visible=False), gr.update(
|
||||
choices=ControlNetPresetUI.dropdown_choices(), value=new_name
|
||||
)
|
||||
|
||||
self.confirm_preset_name.click(
|
||||
fn=save_new_preset,
|
||||
inputs=[self.preset_name, *ui_states],
|
||||
outputs=[self.name_dialog, self.dropdown],
|
||||
show_progress="hidden",
|
||||
).then(fn=None, _js="closePopup")
|
||||
|
||||
self.refresh_button.click(
|
||||
fn=ControlNetPresetUI.refresh_preset,
|
||||
inputs=None,
|
||||
outputs=[self.dropdown],
|
||||
show_progress="hidden",
|
||||
)
|
||||
|
||||
def update_reset_button(preset_name: str, *ui_states):
|
||||
if preset_name == NEW_PRESET:
|
||||
return gr.update(visible=False)
|
||||
|
||||
infotext = ControlNetPresetUI.presets[preset_name]
|
||||
preset_unit = parse_unit(infotext)
|
||||
current_unit = ControlNetUnit(*ui_states)
|
||||
preset_unit.image = None
|
||||
current_unit.image = None
|
||||
|
||||
# Do not compare module param that are not used in preset.
|
||||
for module_param in ("processor_res", "threshold_a", "threshold_b"):
|
||||
if getattr(preset_unit, module_param) == -1:
|
||||
setattr(current_unit, module_param, -1)
|
||||
|
||||
return gr.update(visible=vars(current_unit) != vars(preset_unit))
|
||||
|
||||
for ui_state in ui_states:
|
||||
if isinstance(ui_state, gr.Image):
|
||||
continue
|
||||
|
||||
for action in ("edit", "click", "change", "clear", "release"):
|
||||
if action == "release" and not isinstance(ui_state, gr.Slider):
|
||||
continue
|
||||
|
||||
if hasattr(ui_state, action):
|
||||
getattr(ui_state, action)(
|
||||
fn=update_reset_button,
|
||||
inputs=[self.dropdown, *ui_states],
|
||||
outputs=[self.reset_button],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def dropdown_choices() -> List[str]:
|
||||
return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET]
|
||||
|
||||
@staticmethod
|
||||
def save_preset(name: str, unit: ControlNetUnit):
|
||||
infotext = serialize_unit(unit)
|
||||
with open(
|
||||
os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w"
|
||||
) as f:
|
||||
f.write(infotext)
|
||||
|
||||
ControlNetPresetUI.presets[name] = infotext
|
||||
|
||||
@staticmethod
|
||||
def delete_preset(name: str):
|
||||
if name not in ControlNetPresetUI.presets:
|
||||
return
|
||||
|
||||
del ControlNetPresetUI.presets[name]
|
||||
|
||||
file = os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt")
|
||||
if os.path.exists(file):
|
||||
os.unlink(file)
|
||||
|
||||
@staticmethod
|
||||
def refresh_preset():
|
||||
ControlNetPresetUI.presets = load_presets(ControlNetPresetUI.preset_directory)
|
||||
return gr.update(choices=ControlNetPresetUI.dropdown_choices())
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from enum import Enum
|
||||
from typing import Any, List, NamedTuple
|
||||
from typing import List, NamedTuple
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
|
|
@ -224,19 +224,6 @@ class HiResFixOption(Enum):
|
|||
LOW_RES_ONLY = "Low res only"
|
||||
HIGH_RES_ONLY = "High res only"
|
||||
|
||||
@staticmethod
|
||||
def from_value(value: Any) -> "HiResFixOption":
|
||||
if isinstance(value, str) and value.startswith("HiResFixOption."):
|
||||
_, field = value.split(".")
|
||||
return getattr(HiResFixOption, field)
|
||||
if isinstance(value, str):
|
||||
return HiResFixOption(value)
|
||||
elif isinstance(value, int):
|
||||
return [x for x in HiResFixOption][value]
|
||||
else:
|
||||
assert isinstance(value, HiResFixOption)
|
||||
return value
|
||||
|
||||
|
||||
class InputMode(Enum):
|
||||
# Single image to a single ControlNet unit.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import List, Tuple, Union
|
||||
|
||||
from typing import List, Tuple
|
||||
from enum import Enum
|
||||
import gradio as gr
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
|
|
@ -8,53 +8,6 @@ from internal_controlnet.external_code import ControlNetUnit
|
|||
from scripts.logging import logger
|
||||
|
||||
|
||||
def field_to_displaytext(fieldname: str) -> str:
|
||||
return " ".join([word.capitalize() for word in fieldname.split("_")])
|
||||
|
||||
|
||||
def displaytext_to_field(text: str) -> str:
|
||||
return "_".join([word.lower() for word in text.split(" ")])
|
||||
|
||||
|
||||
def parse_value(value: str) -> Union[str, float, int, bool]:
|
||||
if value in ("True", "False"):
|
||||
return value == "True"
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return value # Plain string.
|
||||
|
||||
|
||||
def serialize_unit(unit: ControlNetUnit) -> str:
|
||||
excluded_fields = ControlNetUnit.infotext_excluded_fields()
|
||||
|
||||
log_value = {
|
||||
field_to_displaytext(field): getattr(unit, field)
|
||||
for field in vars(ControlNetUnit()).keys()
|
||||
if field not in excluded_fields and getattr(unit, field) != -1
|
||||
# Note: exclude hidden slider values.
|
||||
}
|
||||
if not all("," not in str(v) and ":" not in str(v) for v in log_value.values()):
|
||||
logger.error(f"Unexpected tokens encountered:\n{log_value}")
|
||||
return ""
|
||||
|
||||
return ", ".join(f"{field}: {value}" for field, value in log_value.items())
|
||||
|
||||
|
||||
def parse_unit(text: str) -> ControlNetUnit:
|
||||
return ControlNetUnit(
|
||||
enabled=True,
|
||||
**{
|
||||
displaytext_to_field(key): parse_value(value)
|
||||
for item in text.split(",")
|
||||
for (key, value) in (item.strip().split(": "),)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class Infotext(object):
|
||||
def __init__(self) -> None:
|
||||
self.infotext_fields: List[Tuple[gr.components.IOComponent, str]] = []
|
||||
|
|
@ -74,11 +27,7 @@ class Infotext(object):
|
|||
iocomponents.
|
||||
"""
|
||||
unit_prefix = Infotext.unit_prefix(unit_index)
|
||||
for field in vars(ControlNetUnit()).keys():
|
||||
# Exclude image for infotext.
|
||||
if field == "image":
|
||||
continue
|
||||
|
||||
for field in ControlNetUnit.infotext_fields():
|
||||
# Every field in ControlNetUnit should have a cooresponding
|
||||
# IOComponent in ControlNetUiGroup.
|
||||
io_component = getattr(uigroup, field)
|
||||
|
|
@ -87,13 +36,11 @@ class Infotext(object):
|
|||
self.paste_field_names.append(component_locator)
|
||||
|
||||
@staticmethod
|
||||
def write_infotext(
|
||||
units: List[ControlNetUnit], p: StableDiffusionProcessing
|
||||
):
|
||||
def write_infotext(units: List[ControlNetUnit], p: StableDiffusionProcessing):
|
||||
"""Write infotext to `p`."""
|
||||
p.extra_generation_params.update(
|
||||
{
|
||||
Infotext.unit_prefix(i): serialize_unit(unit)
|
||||
Infotext.unit_prefix(i): unit.serialize()
|
||||
for i, unit in enumerate(units)
|
||||
if unit.enabled
|
||||
}
|
||||
|
|
@ -109,14 +56,19 @@ class Infotext(object):
|
|||
|
||||
assert isinstance(v, str), f"Expect string but got {v}."
|
||||
try:
|
||||
for field, value in vars(parse_unit(v)).items():
|
||||
if field == "image":
|
||||
for field, value in vars(ControlNetUnit.parse(v)).items():
|
||||
if field not in ControlNetUnit.infotext_fields():
|
||||
continue
|
||||
if value is None:
|
||||
logger.debug(f"InfoText: Skipping {field} because value is None.")
|
||||
logger.debug(
|
||||
f"InfoText: Skipping {field} because value is None."
|
||||
)
|
||||
continue
|
||||
|
||||
component_locator = f"{k} {field}"
|
||||
if isinstance(value, Enum):
|
||||
value = value.value
|
||||
|
||||
updates[component_locator] = value
|
||||
logger.debug(f"InfoText: Setting {component_locator} = {value}")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import numpy as np
|
||||
import unittest.mock
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
|
@ -7,13 +8,17 @@ utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'u
|
|||
|
||||
from modules import processing, scripts, shared
|
||||
from internal_controlnet.external_code import ControlNetUnit
|
||||
from scripts import controlnet, external_code, batch_hijack
|
||||
from scripts import controlnet, batch_hijack
|
||||
|
||||
|
||||
batch_hijack.instance.undo_hijack()
|
||||
original_process_images_inner = processing.process_images_inner
|
||||
|
||||
|
||||
def create_unit(**kwargs) -> ControlNetUnit:
|
||||
return ControlNetUnit(enabled=True, **kwargs)
|
||||
|
||||
|
||||
class TestBatchHijack(unittest.TestCase):
|
||||
@unittest.mock.patch('modules.script_callbacks.on_script_unloaded')
|
||||
def setUp(self, on_script_unloaded_mock):
|
||||
|
|
@ -59,9 +64,18 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
|
|||
is_cn_batch, batches, output_dir, _ = batch_hijack.get_cn_batches(self.p)
|
||||
batch_hijack.instance.dispatch_callbacks(batch_hijack.instance.process_batch_callbacks, self.p, batches, output_dir)
|
||||
|
||||
batch_units = [unit for unit in self.p.script_args if getattr(unit, 'input_mode', batch_hijack.InputMode.SIMPLE) == batch_hijack.InputMode.BATCH]
|
||||
batch_units = [
|
||||
unit
|
||||
for unit in self.p.script_args
|
||||
if getattr(unit, 'input_mode', batch_hijack.InputMode.SIMPLE) == batch_hijack.InputMode.BATCH
|
||||
]
|
||||
# Convert iterator to list to avoid double eval of iterator exhausting
|
||||
# the iterator in following checks.
|
||||
for unit in batch_units:
|
||||
unit.batch_images = list(unit.batch_images)
|
||||
|
||||
if batch_units:
|
||||
self.assertEqual(min(len(list(unit.batch_images)) for unit in batch_units), len(batches))
|
||||
self.assertEqual(min(len(unit.batch_images) for unit in batch_units), len(batches))
|
||||
else:
|
||||
self.assertEqual(1, len(batches))
|
||||
|
||||
|
|
@ -74,15 +88,15 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
|
|||
self.assertEqual(is_batch, False)
|
||||
|
||||
def test_get_cn_batches__1_simple(self):
|
||||
self.p.script_args.append(ControlNetUnit(image=get_dummy_image()))
|
||||
self.p.script_args.append(create_unit(image=get_dummy_image()))
|
||||
self.assert_get_cn_batches_works([
|
||||
[self.p.script_args[0].image],
|
||||
[get_dummy_image()],
|
||||
])
|
||||
|
||||
def test_get_cn_batches__2_simples(self):
|
||||
self.p.script_args.extend([
|
||||
ControlNetUnit(image=get_dummy_image(0)),
|
||||
ControlNetUnit(image=get_dummy_image(1)),
|
||||
create_unit(image=get_dummy_image(0)),
|
||||
create_unit(image=get_dummy_image(1)),
|
||||
])
|
||||
self.assert_get_cn_batches_works([
|
||||
[get_dummy_image(0)],
|
||||
|
|
@ -91,7 +105,7 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
|
|||
|
||||
def test_get_cn_batches__1_batch(self):
|
||||
self.p.script_args.extend([
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(0),
|
||||
|
|
@ -108,14 +122,14 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
|
|||
|
||||
def test_get_cn_batches__2_batches(self):
|
||||
self.p.script_args.extend([
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(0),
|
||||
get_dummy_image(1),
|
||||
],
|
||||
),
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(2),
|
||||
|
|
@ -136,8 +150,8 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
|
|||
|
||||
def test_get_cn_batches__2_mixed(self):
|
||||
self.p.script_args.extend([
|
||||
ControlNetUnit(image=get_dummy_image(0)),
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(image=get_dummy_image(0)),
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(1),
|
||||
|
|
@ -158,8 +172,8 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
|
|||
|
||||
def test_get_cn_batches__3_mixed(self):
|
||||
self.p.script_args.extend([
|
||||
ControlNetUnit(image=get_dummy_image(0)),
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(image=get_dummy_image(0)),
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(1),
|
||||
|
|
@ -167,7 +181,7 @@ class TestGetControlNetBatchesWorks(unittest.TestCase):
|
|||
get_dummy_image(3),
|
||||
],
|
||||
),
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(4),
|
||||
|
|
@ -243,14 +257,14 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
|
|||
|
||||
def test_process_images__only_simple_units__forwards(self):
|
||||
self.p.script_args = [
|
||||
ControlNetUnit(image=get_dummy_image()),
|
||||
ControlNetUnit(image=get_dummy_image()),
|
||||
create_unit(image=get_dummy_image()),
|
||||
create_unit(image=get_dummy_image()),
|
||||
]
|
||||
self.assert_process_images_hijack_called(batch_count=0)
|
||||
|
||||
def test_process_images__1_batch_1_unit__runs_1_batch(self):
|
||||
self.p.script_args = [
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(),
|
||||
|
|
@ -261,7 +275,7 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
|
|||
|
||||
def test_process_images__2_batches_1_unit__runs_2_batches(self):
|
||||
self.p.script_args = [
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(0),
|
||||
|
|
@ -274,7 +288,7 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
|
|||
def test_process_images__8_batches_1_unit__runs_8_batches(self):
|
||||
batch_count = 8
|
||||
self.p.script_args = [
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[get_dummy_image(i) for i in range(batch_count)]
|
||||
),
|
||||
|
|
@ -283,11 +297,11 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
|
|||
|
||||
def test_process_images__1_batch_2_units__runs_1_batch(self):
|
||||
self.p.script_args = [
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[get_dummy_image(0)]
|
||||
),
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[get_dummy_image(1)]
|
||||
),
|
||||
|
|
@ -296,14 +310,14 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
|
|||
|
||||
def test_process_images__2_batches_2_units__runs_2_batches(self):
|
||||
self.p.script_args = [
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(0),
|
||||
get_dummy_image(1),
|
||||
],
|
||||
),
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(2),
|
||||
|
|
@ -315,7 +329,7 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
|
|||
|
||||
def test_process_images__3_batches_2_mixed_units__runs_3_batches(self):
|
||||
self.p.script_args = [
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(0),
|
||||
|
|
@ -323,7 +337,7 @@ class TestProcessImagesPatchWorks(unittest.TestCase):
|
|||
get_dummy_image(2),
|
||||
],
|
||||
),
|
||||
controlnet.UiControlNetUnit(
|
||||
create_unit(
|
||||
input_mode=batch_hijack.InputMode.SIMPLE,
|
||||
image=get_dummy_image(3),
|
||||
),
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import importlib
|
|||
utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils")
|
||||
|
||||
|
||||
from scripts import external_code
|
||||
from scripts.enums import ResizeMode
|
||||
from scripts.controlnet import prepare_mask, Script, set_numpy_seed
|
||||
from internal_controlnet.external_code import ControlNetUnit
|
||||
|
|
@ -119,9 +118,7 @@ class TestScript(unittest.TestCase):
|
|||
"AAAAAAAAAAAAAAAAAAAAAAAA/wZOlAAB5tU+nAAAAABJRU5ErkJggg=="
|
||||
)
|
||||
|
||||
sample_np_image = np.array(
|
||||
[[100, 200, 50], [150, 75, 225], [30, 120, 180]], dtype=np.uint8
|
||||
)
|
||||
sample_np_image = np.zeros(shape=[8, 8, 3], dtype=np.uint8)
|
||||
|
||||
def test_choose_input_image(self):
|
||||
with self.subTest(name="no image"):
|
||||
|
|
@ -139,7 +136,7 @@ class TestScript(unittest.TestCase):
|
|||
resize_mode=ResizeMode.OUTER_FIT,
|
||||
),
|
||||
unit=ControlNetUnit(
|
||||
image=TestScript.sample_base64_image,
|
||||
image=TestScript.sample_np_image,
|
||||
module="none",
|
||||
resize_mode=ResizeMode.INNER_FIT,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -1,34 +0,0 @@
|
|||
import unittest
|
||||
import importlib
|
||||
|
||||
utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils")
|
||||
|
||||
from scripts.infotext import parse_unit
|
||||
from scripts.external_code import ControlNetUnit
|
||||
|
||||
|
||||
class TestInfotext(unittest.TestCase):
|
||||
def test_parsing(self):
|
||||
infotext = (
|
||||
"Module: inpaint_only+lama, Model: control_v11p_sd15_inpaint [ebff9138], Weight: 1, "
|
||||
"Resize Mode: Resize and Fill, Low Vram: False, Guidance Start: 0, Guidance End: 1, "
|
||||
"Pixel Perfect: True, Control Mode: Balanced, Hr Option: Both, Save Detected Map: True"
|
||||
)
|
||||
self.assertEqual(
|
||||
vars(
|
||||
ControlNetUnit(
|
||||
module="inpaint_only+lama",
|
||||
model="control_v11p_sd15_inpaint [ebff9138]",
|
||||
weight=1,
|
||||
resize_mode="Resize and Fill",
|
||||
low_vram=False,
|
||||
guidance_start=0,
|
||||
guidance_end=1,
|
||||
pixel_perfect=True,
|
||||
control_mode="Balanced",
|
||||
hr_option="Both",
|
||||
save_detected_map=True,
|
||||
)
|
||||
),
|
||||
vars(parse_unit(infotext)),
|
||||
)
|
||||
|
|
@ -54,70 +54,6 @@ class TestExternalCodeWorking(unittest.TestCase):
|
|||
self.assert_update_in_place_ok()
|
||||
|
||||
|
||||
class TestControlNetUnitConversion(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.dummy_image = 'base64...'
|
||||
self.input = {}
|
||||
self.expected = ControlNetUnit()
|
||||
|
||||
def assert_converts_to_expected(self):
|
||||
self.assertEqual(vars(external_code.to_processing_unit(self.input)), vars(self.expected))
|
||||
|
||||
def test_empty_dict_works(self):
|
||||
self.assert_converts_to_expected()
|
||||
|
||||
def test_image_works(self):
|
||||
self.input = {
|
||||
'image': self.dummy_image
|
||||
}
|
||||
self.expected = ControlNetUnit(image=self.dummy_image)
|
||||
self.assert_converts_to_expected()
|
||||
|
||||
def test_image_alias_works(self):
|
||||
self.input = {
|
||||
'input_image': self.dummy_image
|
||||
}
|
||||
self.expected = ControlNetUnit(image=self.dummy_image)
|
||||
self.assert_converts_to_expected()
|
||||
|
||||
def test_masked_image_works(self):
|
||||
self.input = {
|
||||
'image': self.dummy_image,
|
||||
'mask': self.dummy_image,
|
||||
}
|
||||
self.expected = ControlNetUnit(image={'image': self.dummy_image, 'mask': self.dummy_image})
|
||||
self.assert_converts_to_expected()
|
||||
|
||||
|
||||
class TestControlNetUnitImageToDict(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.dummy_image = utils.readImage("test/test_files/img2img_basic.png")
|
||||
self.input = ControlNetUnit()
|
||||
self.expected_image = external_code.to_base64_nparray(self.dummy_image)
|
||||
self.expected_mask = external_code.to_base64_nparray(self.dummy_image)
|
||||
|
||||
def assert_dict_is_valid(self):
|
||||
actual_dict = controlnet.image_dict_from_any(self.input.image)
|
||||
self.assertEqual(actual_dict['image'].tolist(), self.expected_image.tolist())
|
||||
self.assertEqual(actual_dict['mask'].tolist(), self.expected_mask.tolist())
|
||||
|
||||
def test_none(self):
|
||||
self.assertEqual(controlnet.image_dict_from_any(self.input.image), None)
|
||||
|
||||
def test_image_without_mask(self):
|
||||
self.input.image = self.dummy_image
|
||||
self.expected_mask = np.zeros_like(self.expected_image, dtype=np.uint8)
|
||||
self.assert_dict_is_valid()
|
||||
|
||||
def test_masked_image_tuple(self):
|
||||
self.input.image = (self.dummy_image, self.dummy_image,)
|
||||
self.assert_dict_is_valid()
|
||||
|
||||
def test_masked_image_dict(self):
|
||||
self.input.image = {'image': self.dummy_image, 'mask': self.dummy_image}
|
||||
self.assert_dict_is_valid()
|
||||
|
||||
|
||||
class TestPixelPerfectResolution(unittest.TestCase):
|
||||
def test_outer_fit(self):
|
||||
image = np.zeros((100, 100, 3))
|
||||
|
|
@ -136,37 +72,5 @@ class TestPixelPerfectResolution(unittest.TestCase):
|
|||
self.assertEqual(result, expected)
|
||||
|
||||
|
||||
class TestGetAllUnitsFrom(unittest.TestCase):
|
||||
def test_none(self):
|
||||
self.assertListEqual(external_code.get_all_units_from([None]), [])
|
||||
|
||||
def test_bool(self):
|
||||
self.assertListEqual(external_code.get_all_units_from([True]), [])
|
||||
|
||||
def test_inheritance(self):
|
||||
class Foo(ControlNetUnit):
|
||||
def __init__(self):
|
||||
super().__init__(self)
|
||||
self.bar = 'a'
|
||||
|
||||
foo = Foo()
|
||||
self.assertListEqual(external_code.get_all_units_from([foo]), [foo])
|
||||
|
||||
def test_dict(self):
|
||||
units = external_code.get_all_units_from([{}])
|
||||
self.assertGreater(len(units), 0)
|
||||
self.assertIsInstance(units[0], ControlNetUnit)
|
||||
|
||||
def test_unitlike(self):
|
||||
class Foo(object):
|
||||
""" bar """
|
||||
|
||||
foo = Foo()
|
||||
for key in vars(ControlNetUnit()).keys():
|
||||
setattr(foo, key, True)
|
||||
setattr(foo, 'bar', False)
|
||||
self.assertListEqual(external_code.get_all_units_from([foo]), [foo])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
import unittest
|
||||
import importlib
|
||||
utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'utils')
|
||||
|
||||
|
||||
from scripts import external_code
|
||||
from scripts.enums import ControlMode
|
||||
from internal_controlnet.external_code import ControlNetUnit
|
||||
|
||||
|
||||
class TestGetAllUnitsFrom(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.control_unit = {
|
||||
"module": "none",
|
||||
"model": utils.get_model("canny"),
|
||||
"image": utils.readImage("test/test_files/img2img_basic.png"),
|
||||
"resize_mode": 1,
|
||||
"low_vram": False,
|
||||
"processor_res": 64,
|
||||
"control_mode": ControlMode.BALANCED.value,
|
||||
}
|
||||
self.object_unit = ControlNetUnit(**self.control_unit)
|
||||
|
||||
def test_empty_converts(self):
|
||||
script_args = []
|
||||
units = external_code.get_all_units_from(script_args)
|
||||
self.assertListEqual(units, [])
|
||||
|
||||
def test_object_forwards(self):
|
||||
script_args = [self.object_unit]
|
||||
units = external_code.get_all_units_from(script_args)
|
||||
self.assertListEqual(units, [self.object_unit])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -143,16 +143,16 @@ def test_detect_default_param():
|
|||
dict(
|
||||
controlnet_input_images=[realistic_girl_face_img],
|
||||
controlnet_module="canny", # Canny does not require model download.
|
||||
controlnet_threshold_a=-1,
|
||||
controlnet_threshold_b=-1,
|
||||
controlnet_processor_res=-1,
|
||||
controlnet_threshold_a=-100,
|
||||
controlnet_threshold_b=-100,
|
||||
controlnet_processor_res=-100,
|
||||
),
|
||||
"default_param",
|
||||
)
|
||||
assert log_context.is_in_console_logs(
|
||||
[
|
||||
"[canny.processor_res] Invalid value(-1), using default value 512.",
|
||||
"[canny.threshold_a] Invalid value(-1.0), using default value 100.",
|
||||
"[canny.threshold_b] Invalid value(-1.0), using default value 200.",
|
||||
"[canny.processor_res] Invalid value(-100), using default value 512.",
|
||||
"[canny.threshold_a] Invalid value(-100.0), using default value 100.",
|
||||
"[canny.threshold_b] Invalid value(-100.0), using default value 200.",
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -169,14 +169,14 @@ def expect_same_image(img1, img2, diff_img_path: str) -> bool:
|
|||
|
||||
|
||||
default_unit = {
|
||||
"control_mode": 0,
|
||||
"control_mode": "Balanced",
|
||||
"enabled": True,
|
||||
"guidance_end": 1,
|
||||
"guidance_start": 0,
|
||||
"low_vram": False,
|
||||
"pixel_perfect": True,
|
||||
"processor_res": 512,
|
||||
"resize_mode": 1,
|
||||
"resize_mode": "Crop and Resize",
|
||||
"threshold_a": -1,
|
||||
"threshold_b": -1,
|
||||
"weight": 1,
|
||||
|
|
|
|||
|
|
@ -87,12 +87,13 @@ def test_invalid_param(gen_type, param_name):
|
|||
f"test_invalid_param{(gen_type, param_name)}",
|
||||
gen_type,
|
||||
payload_overrides={},
|
||||
unit_overrides={param_name: -1},
|
||||
unit_overrides={param_name: -100},
|
||||
input_image=girl_img,
|
||||
).exec()
|
||||
number = "-100" if param_name == "processor_res" else "-100.0"
|
||||
assert log_context.is_in_console_logs(
|
||||
[
|
||||
f"[canny.{param_name}] Invalid value(-1), using default value",
|
||||
f"[canny.{param_name}] Invalid value({number}), using default value",
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -192,7 +193,7 @@ def test_hr_option():
|
|||
"enable_hr": True,
|
||||
"denoising_strength": 0.75,
|
||||
},
|
||||
unit_overrides={"hr_option": "HiResFixOption.BOTH"},
|
||||
unit_overrides={"hr_option": "Both"},
|
||||
input_image=girl_img,
|
||||
).exec(expected_output_num=3)
|
||||
|
||||
|
|
@ -203,7 +204,7 @@ def test_hr_option_default():
|
|||
"test_hr_option_default",
|
||||
"txt2img",
|
||||
payload_overrides={"enable_hr": False},
|
||||
unit_overrides={"hr_option": "HiResFixOption.BOTH"},
|
||||
unit_overrides={"hr_option": "Both"},
|
||||
input_image=girl_img,
|
||||
).exec(expected_output_num=2)
|
||||
|
||||
|
|
|
|||
|
|
@ -295,13 +295,13 @@ def get_model(model_name: str) -> str:
|
|||
|
||||
|
||||
default_unit = {
|
||||
"control_mode": 0,
|
||||
"control_mode": "Balanced",
|
||||
"enabled": True,
|
||||
"guidance_end": 1,
|
||||
"guidance_start": 0,
|
||||
"pixel_perfect": True,
|
||||
"processor_res": 512,
|
||||
"resize_mode": 1,
|
||||
"resize_mode": "Crop and Resize",
|
||||
"threshold_a": -1,
|
||||
"threshold_b": -1,
|
||||
"weight": 1,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,241 @@
|
|||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
|
||||
from internal_controlnet.args import ControlNetUnit
|
||||
|
||||
H = W = 128
|
||||
|
||||
img1 = np.ones(shape=[H, W, 3], dtype=np.uint8)
|
||||
img2 = np.ones(shape=[H, W, 3], dtype=np.uint8) * 2
|
||||
mask_diff = np.ones(shape=[H - 1, W - 1, 3], dtype=np.uint8) * 2
|
||||
mask_2d = np.ones(shape=[H, W])
|
||||
img_bad_channel = np.ones(shape=[H, W, 2], dtype=np.uint8) * 2
|
||||
img_bad_dim = np.ones(shape=[1, H, W, 3], dtype=np.uint8) * 2
|
||||
ui_img_diff = np.ones(shape=[H - 1, W - 1, 4], dtype=np.uint8) * 2
|
||||
ui_img = np.ones(shape=[H, W, 4], dtype=np.uint8)
|
||||
tensor1 = torch.zeros(size=[1, 1], dtype=torch.float16)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def set_cls_funcs():
|
||||
ControlNetUnit.cls_match_model = lambda s: s in {
|
||||
"None",
|
||||
"model1",
|
||||
"model2",
|
||||
"control_v11p_sd15_inpaint [ebff9138]",
|
||||
}
|
||||
ControlNetUnit.cls_match_module = lambda s: s in {
|
||||
"none",
|
||||
"module1",
|
||||
"inpaint_only+lama",
|
||||
}
|
||||
ControlNetUnit.cls_decode_base64 = lambda s: {
|
||||
"b64img1": img1,
|
||||
"b64img2": img2,
|
||||
"b64mask_diff": mask_diff,
|
||||
}[s]
|
||||
ControlNetUnit.cls_torch_load_base64 = lambda s: {
|
||||
"b64tensor1": tensor1,
|
||||
}[s]
|
||||
ControlNetUnit.cls_get_preprocessor = lambda s: {
|
||||
"module1": MockPreprocessor(),
|
||||
"none": MockPreprocessor(),
|
||||
"inpaint_only+lama": MockPreprocessor(),
|
||||
}[s]
|
||||
|
||||
|
||||
def test_module_invalid(set_cls_funcs):
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
ControlNetUnit(module="foo")
|
||||
|
||||
assert "module(foo) not found in supported modules." in str(excinfo.value)
|
||||
|
||||
|
||||
def test_module_valid(set_cls_funcs):
|
||||
ControlNetUnit(module="module1")
|
||||
|
||||
|
||||
def test_model_invalid(set_cls_funcs):
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
ControlNetUnit(model="foo")
|
||||
|
||||
assert "model(foo) not found in supported models." in str(excinfo.value)
|
||||
|
||||
|
||||
def test_model_valid(set_cls_funcs):
|
||||
ControlNetUnit(model="model1")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"d",
|
||||
[
|
||||
# API
|
||||
dict(image={"image": "b64img1"}),
|
||||
dict(image={"image": "b64img1", "mask": "b64img2"}),
|
||||
dict(image=["b64img1", "b64img2"]),
|
||||
dict(image=("b64img1", "b64img2")),
|
||||
dict(image=[{"image": "b64img1", "mask": "b64img2"}]),
|
||||
dict(image=[{"image": "b64img1"}]),
|
||||
dict(image=[{"image": "b64img1", "mask": None}]),
|
||||
dict(
|
||||
image=[
|
||||
{"image": "b64img1", "mask": "b64img2"},
|
||||
{"image": "b64img1", "mask": "b64img2"},
|
||||
]
|
||||
),
|
||||
dict(
|
||||
image=[
|
||||
{"image": "b64img1", "mask": None},
|
||||
{"image": "b64img1", "mask": "b64img2"},
|
||||
]
|
||||
),
|
||||
dict(
|
||||
image=[
|
||||
{"image": "b64img1"},
|
||||
{"image": "b64img1", "mask": "b64img2"},
|
||||
]
|
||||
),
|
||||
dict(image="b64img1", mask="b64img2"),
|
||||
dict(image="b64img1"),
|
||||
dict(image="b64img1", mask_image="b64img2"),
|
||||
dict(image=None),
|
||||
# UI
|
||||
dict(image=dict(image=img1)),
|
||||
dict(image=dict(image=img1, mask=img2)),
|
||||
# 2D mask should be accepted.
|
||||
dict(image=dict(image=img1, mask=mask_2d)),
|
||||
dict(image=img1, mask=mask_2d),
|
||||
],
|
||||
)
|
||||
def test_valid_image_formats(set_cls_funcs, d):
|
||||
ControlNetUnit(**d)
|
||||
unit = ControlNetUnit.from_dict(d)
|
||||
unit.get_input_images_rgba()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "payload",
    [
        {"image": {"mask": "b64img1"}},
        {"image": {"foo": "b64img1", "bar": "b64img2"}},
        {"image": ["b64img1"]},
        {"image": ("b64img1", "b64img2", "b64img1")},
        {"image": []},
        {"image": [{"mask": "b64img1"}]},
        {"image": None, "mask": "b64img2"},
        # Image & mask with mismatched H x W.
        {"image": "b64img1", "mask": "b64mask_diff"},
    ],
)
def test_invalid_image_formats(set_cls_funcs, payload):
    """Malformed payloads are tolerated on assignment but rejected on eval."""
    # Merely setting the fields must not raise.
    ControlNetUnit(**payload)
    parsed = ControlNetUnit.from_dict(payload)
    # The error surfaces only when the input images are actually resolved.
    with pytest.raises((ValueError, AssertionError)):
        parsed.get_input_images_rgba()
def test_mask_alias_conflict():
    """Supplying both `mask` and its alias `mask_image` must be rejected.

    Fix: the original had a stray trailing comma after the `from_dict(...)`
    call, which wrapped the expression in a pointless one-element tuple.
    """
    with pytest.raises((ValueError, AssertionError)):
        ControlNetUnit.from_dict(
            dict(
                image="b64img1",
                mask="b64img1",
                mask_image="b64img1",
            )
        )
def test_resize_mode():
    """A resize mode given as its display string is accepted."""
    assert ControlNetUnit(resize_mode="Just Resize") is not None
def test_weight():
    """Weights inside the accepted range pass; out-of-range values raise."""
    for accepted in (0.5, 0.0):
        ControlNetUnit(weight=accepted)

    for rejected in (-1, 100):
        with pytest.raises(ValueError):
            ControlNetUnit(weight=rejected)
def test_start_end():
    """guidance_start <= guidance_end is required; both stay within range."""
    for start, end in ((0.0, 1.0), (0.5, 1.0), (0.5, 0.5)):
        ControlNetUnit(guidance_start=start, guidance_end=end)

    # start > end is rejected.
    with pytest.raises(ValueError):
        ControlNetUnit(guidance_start=1.0, guidance_end=0.0)
    # Out-of-range values are rejected individually.
    with pytest.raises(ValueError):
        ControlNetUnit(guidance_start=11)
    with pytest.raises(ValueError):
        ControlNetUnit(guidance_end=11)
def test_effective_region_mask():
    """Accepts a base64 string, None, or an image; rejects other types."""
    for accepted in ("b64img1", None, img1):
        ControlNetUnit(effective_region_mask=accepted)

    with pytest.raises(ValueError):
        ControlNetUnit(effective_region_mask=124)
def test_ipadapter_input():
    """Accepts a tensor string, a non-empty list of them, or None."""
    for accepted in (["b64tensor1"], "b64tensor1", None):
        ControlNetUnit(ipadapter_input=accepted)

    # An empty list is explicitly invalid.
    with pytest.raises(ValueError):
        ControlNetUnit(ipadapter_input=[])
@dataclass
class MockSlider:
    """Minimal stand-in for a UI slider with a value and its bounds.

    NOTE(review): the defaults (value=1, min=0, max=2) line up with the
    values asserted in test_preprocessor_sliders -- presumably the unit
    reads `.value` off these sliders when resolving defaults; confirm
    against the ControlNetUnit implementation.
    """

    value: float = 1
    minimum: float = 0
    maximum: float = 2
||||
@dataclass
class MockPreprocessor:
    """Stand-in preprocessor exposing the three slider attributes.

    NOTE(review): these attributes carry no type annotations, so
    @dataclass generates no fields for them -- they remain class-level
    attributes shared by all instances. That looks intentional for
    read-only test access; do NOT "fix" by annotating them, since an
    annotated mutable default would make @dataclass raise at class
    creation time (it would require field(default_factory=...)).
    """

    slider_resolution = MockSlider()
    slider_1 = MockSlider()
    slider_2 = MockSlider()
def test_preprocessor_sliders():
    """An enabled unit resolves slider-backed defaults to the mock value."""
    unit = ControlNetUnit(enabled=True, module="none")
    assert (unit.processor_res, unit.threshold_a, unit.threshold_b) == (1, 1, 1)
def test_preprocessor_sliders_disabled():
    """A disabled unit skips slider resolution, leaving sentinel -1 values."""
    unit = ControlNetUnit(enabled=False, module="none")
    assert (unit.processor_res, unit.threshold_a, unit.threshold_b) == (-1, -1, -1)
def test_infotext_parsing():
    """parse() reconstructs an enabled unit from a generation infotext line."""
    infotext = (
        "Module: inpaint_only+lama, Model: control_v11p_sd15_inpaint [ebff9138], Weight: 1, "
        "Resize Mode: Resize and Fill, Low Vram: False, Guidance Start: 0, Guidance End: 1, "
        "Pixel Perfect: True, Control Mode: Balanced"
    )
    expected = ControlNetUnit(
        enabled=True,
        module="inpaint_only+lama",
        model="control_v11p_sd15_inpaint [ebff9138]",
        weight=1,
        resize_mode="Resize and Fill",
        low_vram=False,
        guidance_start=0,
        guidance_end=1,
        pixel_perfect=True,
        control_mode="Balanced",
    )
    assert ControlNetUnit.parse(infotext) == expected
Loading…
Reference in New Issue