# sd-webui-controlnet/internal_controlnet/args.py
#
# 474 lines
# 16 KiB
# Python

from __future__ import annotations
import os
import torch
import numpy as np
from typing import Optional, List, Annotated, ClassVar, Callable, Any, Tuple, Union
from pydantic import BaseModel, validator, root_validator, Field
from PIL import Image
from logging import Logger
from copy import copy
from enum import Enum
from scripts.enums import (
InputMode,
ResizeMode,
ControlMode,
HiResFixOption,
PuLIDMode,
)
from annotator.util import HWC3
def _unimplemented_func(*_args, **_kwargs):
    """Placeholder hook that always raises NotImplementedError, whatever it receives."""
    raise NotImplementedError("Not implemented.")
def field_to_displaytext(fieldname: str) -> str:
    """Convert a snake_case field name to space-separated, capitalized words.

    >>> field_to_displaytext("guidance_start")
    'Guidance Start'
    """
    words = fieldname.split("_")
    return " ".join(map(str.capitalize, words))
def displaytext_to_field(text: str) -> str:
    """Convert space-separated display text back to a lower-case snake_case name.

    >>> displaytext_to_field("Guidance Start")
    'guidance_start'
    """
    lowered_words = map(str.lower, text.split(" "))
    return "_".join(lowered_words)
def serialize_value(value) -> str:
    """Render a value for infotext.

    Enum members are rendered as their ``.value``; anything else goes through ``str()``.
    """
    return value.value if isinstance(value, Enum) else str(value)
def parse_value(value: str) -> Union[str, float, int, bool]:
    """Parse an infotext value string into the most specific type.

    Resolution order:
    1. Exactly "True" / "False" become booleans.
    2. Otherwise the value is tried as an int, then as a float.
    3. Failing those, the original string is returned unchanged.
    """
    if value == "True":
        return True
    if value == "False":
        return False
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            continue
    return value  # Plain string.
class ControlNetUnit(BaseModel):
    """
    Represents an entire ControlNet processing unit.

    The ``cls_*`` class variables are pluggable hooks: module/model matching,
    base64 decoding, preprocessor lookup and logging. The callable hooks
    default to ``_unimplemented_func``, which raises NotImplementedError.
    NOTE(review): they are presumably replaced by the host extension before
    any unit is validated; the assignment site is not in this file, so confirm.
    """

    class Config:
        # Allow non-pydantic field types such as np.ndarray / torch.Tensor.
        arbitrary_types_allowed = True
        # Unknown input keys (e.g. legacy aliases left behind by
        # `legacy_field_alias`) are silently dropped.
        extra = "ignore"

    cls_match_module: ClassVar[Callable[[str], bool]] = _unimplemented_func
    cls_match_model: ClassVar[Callable[[str], bool]] = _unimplemented_func
    cls_decode_base64: ClassVar[Callable[[str], np.ndarray]] = _unimplemented_func
    cls_torch_load_base64: ClassVar[Callable[[Any], torch.Tensor]] = _unimplemented_func
    cls_get_preprocessor: ClassVar[Callable[[str], Any]] = _unimplemented_func
    cls_logger: ClassVar[Logger] = Logger("ControlNetUnit")

    # UI only fields.
    is_ui: bool = False
    input_mode: InputMode = InputMode.SIMPLE
    batch_images: Optional[Any] = None
    output_dir: str = ""
    loopback: bool = False

    # General fields.
    enabled: bool = False
    module: str = "none"

    @validator("module", always=True, pre=True)
    def check_module(cls, value: str) -> str:
        """Raise ValueError if the injected module matcher rejects `value`."""
        if not ControlNetUnit.cls_match_module(value):
            raise ValueError(f"module({value}) not found in supported modules.")
        return value

    model: str = "None"

    @validator("model", always=True, pre=True)
    def check_model(cls, value: str) -> str:
        """Raise ValueError if the injected model matcher rejects `value`."""
        if not ControlNetUnit.cls_match_model(value):
            raise ValueError(f"model({value}) not found in supported models.")
        return value

    weight: Annotated[float, Field(ge=0.0, le=2.0)] = 1.0

    # The image to be used for this ControlNetUnit.
    image: Optional[Any] = None

    resize_mode: ResizeMode = ResizeMode.INNER_FIT

    @validator("resize_mode", always=True, pre=True)
    def check_resize_mode(cls, value) -> ResizeMode:
        """Accept a ResizeMode or its string value, translating legacy display names."""
        resize_mode_aliases = {
            "Inner Fit (Scale to Fit)": "Crop and Resize",
            "Outer Fit (Shrink to Fit)": "Resize and Fill",
            "Scale to Fit (Inner Fit)": "Crop and Resize",
            "Envelope (Outer Fit)": "Resize and Fill",
        }
        if isinstance(value, str):
            return ResizeMode(resize_mode_aliases.get(value, value))
        assert isinstance(value, ResizeMode)
        return value

    low_vram: bool = False
    # -1 is the default sentinel. When a unit is enabled, `bound_check_params`
    # replaces any value outside the preprocessor slider range (which
    # presumably includes -1) with the slider's default value.
    processor_res: int = -1
    threshold_a: float = -1
    threshold_b: float = -1

    @root_validator
    def bound_check_params(cls, values: dict) -> dict:
        """
        Clamp slider-backed parameters of an enabled unit to valid values.

        For each of 'processor_res', 'threshold_a' and 'threshold_b', the value
        is compared with the preprocessor's corresponding slider config
        ('slider_resolution', 'slider_1', 'slider_2'). If it falls outside
        [minimum, maximum], it is replaced by that slider's default value.
        Disabled units and units without a module are returned unchanged.
        """
        enabled = values.get("enabled")
        if not enabled:
            return values
        module = values.get("module")
        if not module:
            return values
        preprocessor = cls.cls_get_preprocessor(module)
        assert preprocessor is not None
        for unit_param, param in zip(
            ("processor_res", "threshold_a", "threshold_b"),
            ("slider_resolution", "slider_1", "slider_2"),
        ):
            value = values.get(unit_param)
            if value is None:
                # The field failed its own validation, and pydantic already
                # reports that error. Comparing None would only add a
                # misleading TypeError.
                continue
            cfg = getattr(preprocessor, param)
            if value < cfg.minimum or value > cfg.maximum:
                values[unit_param] = cfg.value
                # Only report warning when non-default value is used.
                if value != -1:
                    cls.cls_logger.info(
                        f"[{module}.{unit_param}] Invalid value({value}), using default value {cfg.value}."
                    )
        return values

    guidance_start: Annotated[float, Field(ge=0.0, le=1.0)] = 0.0
    guidance_end: Annotated[float, Field(ge=0.0, le=1.0)] = 1.0

    @root_validator
    def guidance_check(cls, values: dict) -> dict:
        """Raise ValueError when guidance_start is greater than guidance_end."""
        start = values.get("guidance_start")
        end = values.get("guidance_end")
        # A missing value means that field already failed validation (e.g. it
        # was out of [0, 1]); skip the comparison instead of raising TypeError.
        if start is not None and end is not None and start > end:
            raise ValueError(f"guidance_start({start}) > guidance_end({end})")
        return values

    pixel_perfect: bool = False
    control_mode: ControlMode = ControlMode.BALANCED
    # Whether to crop the input image based on the A1111 img2img mask. This flag
    # is only used when `inpaint area` in A1111 is set to `Only masked`. In the
    # API, this corresponds to `inpaint_full_res = True`.
    inpaint_crop_input_image: bool = True
    # If hires fix is enabled in A1111, how this ControlNet unit should be applied.
    # The value is ignored if the generation is not using hires fix.
    hr_option: HiResFixOption = HiResFixOption.BOTH
    # Whether to save the detected map of this unit. Setting this option to False
    # prevents saving the detected map and sending it along with generated
    # images via the API.
    # Currently the option is only accessible in API calls.
    save_detected_map: bool = True
    # Weight for each layer of ControlNet params.
    # For ControlNet:
    # - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
    # - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
    # For T2IAdapter
    # - SD1.5: 5 weights (4 encoder block + 1 middle block)
    # - SDXL: 4 weights (3 encoder block + 1 middle block)
    # For IPAdapter
    # - SD1.5: 16 weights (6 input blocks + 9 output blocks + 1 middle block)
    # - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
    # Note1: Setting advanced weighting disables `soft_injection`, so it is
    # recommended to set ControlMode = BALANCED when using `advanced_weighting`.
    # Note2: The field `weight` is still used in some places, e.g. reference_only,
    # even when advanced_weighting is set.
    advanced_weighting: Optional[List[float]] = None

    # The effective region mask that the unit's effect should be restricted to.
    effective_region_mask: Optional[np.ndarray] = None

    @validator("effective_region_mask", pre=True)
    def parse_effective_region_mask(cls, value) -> Optional[np.ndarray]:
        """Decode a base64 string mask; pass through an ndarray or None."""
        if isinstance(value, str):
            return cls.cls_decode_base64(value)
        assert isinstance(value, np.ndarray) or value is None
        return value

    # The weight mode for PuLID.
    # https://github.com/ToTheBeginning/PuLID
    pulid_mode: PuLIDMode = PuLIDMode.FIDELITY

    # ------- API only fields -------
    # The tensor input for ipadapter. When this field is set in the API,
    # the base64 string is interpreted by torch.load to reconstruct the
    # ipadapter preprocessor output.
    # Currently the option is only accessible in API calls.
    ipadapter_input: Optional[List[torch.Tensor]] = None

    @validator("ipadapter_input", pre=True)
    def parse_ipadapter_input(cls, value) -> Optional[List[torch.Tensor]]:
        """Load one or more encoded tensors via the `cls_torch_load_base64` hook."""
        if value is None:
            return None
        if isinstance(value, str):
            value = [value]
        # NOTE(review): torch.load on API-supplied data can execute arbitrary
        # code unless the hook restricts it (e.g. weights_only) — confirm.
        result = [cls.cls_torch_load_base64(b) for b in value]
        assert result, "input cannot be empty"
        return result

    # The mask to be used on top of the image.
    mask: Optional[Any] = None

    # AnimateDiff compatibility fields.
    # TODO: Find a better way in AnimateDiff to deal with these extra fields.
    batch_mask_dir: Optional[str] = None
    animatediff_batch: bool = False
    # pydantic (v1) copies field defaults per instance, so these [] literals
    # are not shared between units.
    batch_modifiers: list = []
    batch_image_files: list = []

    @property
    def accepts_multiple_inputs(self) -> bool:
        """This unit can accept multiple input images."""
        return self.module in (
            "ip-adapter-auto",
            "ip-adapter_clip_sdxl",
            "ip-adapter_clip_sdxl_plus_vith",
            "ip-adapter_clip_sd15",
            "ip-adapter_face_id",
            "ip-adapter_face_id_plus",
            "ip-adapter_pulid",
            "instant_id_face_embedding",
        )

    @property
    def is_animate_diff_batch(self) -> bool:
        return getattr(self, "animatediff_batch", False)

    @property
    def uses_clip(self) -> bool:
        """Whether this unit uses clip preprocessor."""
        return any(
            (
                ("ip-adapter" in self.module and "face_id" not in self.module),
                self.module
                in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"),
            )
        )

    @property
    def is_inpaint(self) -> bool:
        return "inpaint" in self.module

    def get_actual_preprocessors(self) -> List[Any]:
        """Return the resolved preprocessor followed by all of its dependencies."""
        # Hooks are read through the class, not `self`: a plain function read
        # through an instance would be bound and receive `self` as first arg.
        p = ControlNetUnit.cls_get_preprocessor(self.module)
        # Map "ip-adapter-auto" to actual preprocessor.
        if self.module == "ip-adapter-auto":
            p = p.get_preprocessor_by_model(self.model)
        # Add all dependencies.
        return [p] + [
            ControlNetUnit.cls_get_preprocessor(dep) for dep in p.preprocessor_deps
        ]

    @classmethod
    def parse_image(cls, image) -> np.ndarray:
        """
        Convert an ndarray, a file path or a base64 string to an [H, W, 3] array.

        Raises:
            ValueError: if `image` is none of the supported types.
        """
        if isinstance(image, np.ndarray):
            np_image = image
        elif isinstance(image, str):
            # Necessary for batch.
            if os.path.exists(image):
                # Use a context manager so the file handle opened by PIL is
                # released once the pixels have been copied into numpy.
                with Image.open(image) as pil_image:
                    np_image = np.array(pil_image).astype("uint8")
            else:
                np_image = cls.cls_decode_base64(image)
        else:
            raise ValueError(f"Unrecognized image format {image}.")

        # Convert following image shapes to shape [H, W, C=3].
        # - [H, W]
        # - [H, W, 1]
        # - [H, W, 4]
        np_image = HWC3(np_image)
        assert np_image.ndim == 3
        assert np_image.shape[2] == 3
        return np_image

    @classmethod
    def combine_image_and_mask(
        cls, np_image: np.ndarray, np_mask: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """RGB + Alpha(Optional) => RGBA

        Only channel 0 of the mask is used as the alpha channel.

        Raises:
            ValueError: if image and mask differ in height/width.
        """
        # TODO: Change protocol to use 255 as A channel value.
        # Note: mask is by default zeros, as both inpaint and
        # clip mask does extra work on masked area.
        np_mask = (np.zeros_like(np_image) if np_mask is None else np_mask)[:, :, 0:1]
        if np_image.shape[:2] != np_mask.shape[:2]:
            raise ValueError(
                f"image shape ({np_image.shape[:2]}) not aligned with mask shape ({np_mask.shape[:2]})"
            )
        return np.concatenate([np_image, np_mask], axis=2)  # [H, W, 4]

    @classmethod
    def legacy_field_alias(cls, values: dict) -> dict:
        """
        Copy deprecated field names onto their current names in `values`.

        Mutates and returns `values`. The alias key itself is left in place and
        later dropped by `Config.extra = "ignore"`.
        """
        ext_compat_keys = {
            "guidance": "guidance_end",
            "lowvram": "low_vram",
            "input_image": "image",
        }
        for alias, key in ext_compat_keys.items():
            if alias in values:
                assert key not in values, f"Conflict of field '{alias}' and '{key}'"
                values[key] = values[alias]
                # `Logger.warn` is a deprecated alias; use `warning`.
                cls.cls_logger.warning(
                    f"Deprecated alias '{alias}' detected. This field will be removed on 2024-06-01. "
                    f"Please use '{key}' instead."
                )
        return values

    @classmethod
    def mask_alias(cls, values: dict) -> dict:
        """
        Field "mask_image" is the alias of field "mask".
        This is for compatibility with SD Forge API.

        Raises:
            ValueError: if both 'mask' and 'mask_image' are given.
        """
        mask_image = values.get("mask_image")
        mask = values.get("mask")
        if mask_image is not None:
            if mask is not None:
                raise ValueError("Cannot specify both 'mask' and 'mask_image'!")
            values["mask"] = mask_image
        return values

    def get_input_images_rgba(self) -> Optional[List[np.ndarray]]:
        """
        RGBA images with potentially different size.
        Why we cannot have [B, H, W, C=4] here is that calculation of final
        resolution requires generation target's dimensions.

        Parse image with following formats.
        API
        - image = {"image": base64image, "mask": base64image,}
        - image = [image, mask]
        - image = (image, mask)
        - image = [{"image": ..., "mask": ...}, {"image": ..., "mask": ...}, ...]
        - image = base64image, mask = base64image

        UI:
        - image = {"image": np_image, "mask": np_image,}
        - image = np_image, mask = np_image

        Returns None when no image is set.
        """
        init_image = self.image
        init_mask = self.mask

        if init_image is None:
            assert init_mask is None
            return None

        if isinstance(init_image, (list, tuple)):
            if not init_image:
                raise ValueError(f"{init_image} is not a valid 'image' field value")
            if isinstance(init_image[0], dict):
                # [{"image": ..., "mask": ...}, {"image": ..., "mask": ...}, ...]
                images = init_image
            else:
                assert len(init_image) == 2
                # [image, mask]
                # (image, mask)
                images = [
                    {
                        "image": init_image[0],
                        "mask": init_image[1],
                    }
                ]
        elif isinstance(init_image, dict):
            # {"image": ..., "mask": ...}
            images = [init_image]
        elif isinstance(init_image, (str, np.ndarray)):
            # image = base64image, mask = base64image
            images = [
                {
                    "image": init_image,
                    "mask": init_mask,
                }
            ]
        else:
            raise ValueError(f"Unrecognized image field {init_image}")

        np_images = []
        for image_dict in images:
            assert isinstance(image_dict, dict)
            image = image_dict.get("image")
            mask = image_dict.get("mask")
            assert image is not None

            np_image = self.parse_image(image)
            np_mask = self.parse_image(mask) if mask is not None else None
            np_images.append(
                self.combine_image_and_mask(np_image, np_mask)
            )  # [H, W, 4]
        return np_images

    @classmethod
    def from_dict(cls, values: dict) -> ControlNetUnit:
        """Build a unit from a dict, resolving legacy and Forge aliases first."""
        # Shallow copy so the alias helpers do not mutate the caller's dict.
        values = copy(values)
        values = cls.legacy_field_alias(values)
        values = cls.mask_alias(values)
        return ControlNetUnit(**values)

    @classmethod
    def from_infotext_args(cls, *args) -> ControlNetUnit:
        """Build a unit from positional values ordered like `infotext_fields()`."""
        assert len(args) == len(ControlNetUnit.infotext_fields())
        return cls.from_dict(
            {k: v for k, v in zip(ControlNetUnit.infotext_fields(), args)}
        )

    @staticmethod
    def infotext_fields() -> Tuple[str, ...]:
        """Fields that should be included in infotext.
        You should define a Gradio element with exact same name in ControlNetUiGroup
        as well, so that infotext can wire the value to correct field when pasting
        infotext.
        """
        return (
            "module",
            "model",
            "weight",
            "resize_mode",
            "processor_res",
            "threshold_a",
            "threshold_b",
            "guidance_start",
            "guidance_end",
            "pixel_perfect",
            "control_mode",
        )

    def serialize(self) -> str:
        """Serialize the unit for infotext.

        Produces "Field Name: value, ..." and returns "" (logging an error)
        when any value contains ',' or ':', which would break `parse`.
        """
        infotext_dict = {
            field_to_displaytext(field): serialize_value(getattr(self, field))
            for field in ControlNetUnit.infotext_fields()
        }
        if not all(
            "," not in str(v) and ":" not in str(v) for v in infotext_dict.values()
        ):
            self.cls_logger.error(f"Unexpected tokens encountered:\n{infotext_dict}")
            return ""

        return ", ".join(f"{field}: {value}" for field, value in infotext_dict.items())

    @classmethod
    def parse(cls, text: str) -> ControlNetUnit:
        """Inverse of `serialize`: build an enabled unit from infotext."""
        return ControlNetUnit(
            enabled=True,
            **{
                displaytext_to_field(key): parse_value(value)
                for item in text.split(",")
                # Single-element tuple binds (key, value) inside the comprehension.
                for (key, value) in (item.strip().split(": "),)
            },
        )