Port SDForge's preprocessor structure (#2754)

* wip

* wip2

* wip3

* fix issues

* Add scribble xdog

* port legacy processors

* nit

* Add tests

* Fix modules test

* Add back normal_dsine

* Remove legacy code

* Remove code

* Add tests

* rename param

* Linter ignore

* fix is_image

* fix is_image

* nit

* nit

* Better assertion message

* Add back ip-adapter-auto

* Add test

* Fix various tag matching

* fix

* Add back preprocessor cache

* Add back sparse ctrl

* fix test failure

* Add log
pull/2763/head
Chenlei Hu 2024-04-17 22:28:56 -04:00 committed by GitHub
parent bbcae309d1
commit 442398bb9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 2315 additions and 1129 deletions

View File

@ -9,7 +9,7 @@ from annotator.annotator_path import models_path
import torchvision.transforms as transforms import torchvision.transforms as transforms
import dsine.utils.utils as utils import dsine.utils.utils as utils
from dsine.models.dsine import DSINE from dsine.models.dsine import DSINE
from scripts.processor import resize_image_with_pad from scripts.utils import resize_image_with_pad
class NormalDsineDetector: class NormalDsineDetector:

View File

@ -10,9 +10,9 @@ import numpy as np
from modules import scripts, processing, shared from modules import scripts, processing, shared
from modules.safe import unsafe_torch_load from modules.safe import unsafe_torch_load
from scripts import global_state from scripts import global_state
from scripts.processor import preprocessor_sliders_config, model_free_preprocessors
from scripts.logging import logger from scripts.logging import logger
from scripts.enums import HiResFixOption from scripts.enums import HiResFixOption
from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter
from modules.api import api from modules.api import api
@ -56,10 +56,10 @@ class ResizeMode(Enum):
resize_mode_aliases = { resize_mode_aliases = {
'Inner Fit (Scale to Fit)': 'Crop and Resize', "Inner Fit (Scale to Fit)": "Crop and Resize",
'Outer Fit (Shrink to Fit)': 'Resize and Fill', "Outer Fit (Shrink to Fit)": "Resize and Fill",
'Scale to Fit (Inner Fit)': 'Crop and Resize', "Scale to Fit (Inner Fit)": "Crop and Resize",
'Envelope (Outer Fit)': 'Resize and Fill', "Envelope (Outer Fit)": "Resize and Fill",
} }
@ -72,7 +72,9 @@ def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode:
return ResizeMode.RESIZE return ResizeMode.RESIZE
if value >= len(ResizeMode): if value >= len(ResizeMode):
logger.warning(f'Unrecognized ResizeMode int value {value}. Fall back to RESIZE.') logger.warning(
f"Unrecognized ResizeMode int value {value}. Fall back to RESIZE."
)
return ResizeMode.RESIZE return ResizeMode.RESIZE
return [e for e in ResizeMode][value] return [e for e in ResizeMode][value]
@ -159,6 +161,7 @@ class ControlNetUnit:
""" """
Represents an entire ControlNet processing unit. Represents an entire ControlNet processing unit.
""" """
enabled: bool = True enabled: bool = True
module: str = "none" module: str = "none"
model: str = "None" model: str = "None"
@ -242,10 +245,13 @@ class ControlNetUnit:
@property @property
def uses_clip(self) -> bool: def uses_clip(self) -> bool:
"""Whether this unit uses clip preprocessor.""" """Whether this unit uses clip preprocessor."""
return any(( return any(
(
("ip-adapter" in self.module and "face_id" not in self.module), ("ip-adapter" in self.module and "face_id" not in self.module),
self.module in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"), self.module
)) in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"),
)
)
@property @property
def is_inpaint(self) -> bool: def is_inpaint(self) -> bool:
@ -257,18 +263,18 @@ class ControlNetUnit:
Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
their default values if negative. their default values if negative.
""" """
cfg = preprocessor_sliders_config.get(global_state.get_module_basename(self.module), []) preprocessor = Preprocessor.get_preprocessor(self.module)
defaults = { for unit_param, param in zip(
param: cfg_default['value'] ("processor_res", "threshold_a", "threshold_b"),
for param, cfg_default in zip( ("slider_resolution", "slider_1", "slider_2"),
("processor_res", 'threshold_a', 'threshold_b'), cfg) ):
if cfg_default is not None value = getattr(self, unit_param)
} cfg: PreprocessorParameter = getattr(preprocessor, param)
for param, default_value in defaults.items():
value = getattr(self, param)
if value < 0: if value < 0:
setattr(self, param, default_value) setattr(self, unit_param, cfg.value)
logger.info(f'[{self.module}.{param}] Invalid value({value}), using default value {default_value}.') logger.info(
f"[{self.module}.{unit_param}] Invalid value({value}), using default value {cfg.value}."
)
def to_base64_nparray(encoding: str): def to_base64_nparray(encoding: str):
@ -276,10 +282,12 @@ def to_base64_nparray(encoding: str):
Convert a base64 image into the image type the extension uses Convert a base64 image into the image type the extension uses
""" """
return np.array(api.decode_base64_to_image(encoding)).astype('uint8') return np.array(api.decode_base64_to_image(encoding)).astype("uint8")
def get_all_units_in_processing(p: processing.StableDiffusionProcessing) -> List[ControlNetUnit]: def get_all_units_in_processing(
p: processing.StableDiffusionProcessing,
) -> List[ControlNetUnit]:
""" """
Fetch ControlNet processing units from a StableDiffusionProcessing. Fetch ControlNet processing units from a StableDiffusionProcessing.
""" """
@ -287,7 +295,9 @@ def get_all_units_in_processing(p: processing.StableDiffusionProcessing) -> List
return get_all_units(p.scripts, p.script_args) return get_all_units(p.scripts, p.script_args)
def get_all_units(script_runner: scripts.ScriptRunner, script_args: List[Any]) -> List[ControlNetUnit]: def get_all_units(
script_runner: scripts.ScriptRunner, script_args: List[Any]
) -> List[ControlNetUnit]:
""" """
Fetch ControlNet processing units from an existing script runner. Fetch ControlNet processing units from an existing script runner.
Use this function to fetch units from the list of all scripts arguments. Use this function to fetch units from the list of all scripts arguments.
@ -295,7 +305,7 @@ def get_all_units(script_runner: scripts.ScriptRunner, script_args: List[Any]) -
cn_script = find_cn_script(script_runner) cn_script = find_cn_script(script_runner)
if cn_script: if cn_script:
return get_all_units_from(script_args[cn_script.args_from:cn_script.args_to]) return get_all_units_from(script_args[cn_script.args_from : cn_script.args_to])
return [] return []
@ -307,22 +317,19 @@ def get_all_units_from(script_args: List[Any]) -> List[ControlNetUnit]:
""" """
def is_stale_unit(script_arg: Any) -> bool: def is_stale_unit(script_arg: Any) -> bool:
""" Returns whether the script_arg is potentially an stale version of """Returns whether the script_arg is potentially an stale version of
ControlNetUnit created before module reload.""" ControlNetUnit created before module reload."""
return ( return "ControlNetUnit" in type(script_arg).__name__ and not isinstance(
'ControlNetUnit' in type(script_arg).__name__ and script_arg, ControlNetUnit
not isinstance(script_arg, ControlNetUnit)
) )
def is_controlnet_unit(script_arg: Any) -> bool: def is_controlnet_unit(script_arg: Any) -> bool:
""" Returns whether the script_arg is ControlNetUnit or anything that """Returns whether the script_arg is ControlNetUnit or anything that
can be treated like ControlNetUnit. """ can be treated like ControlNetUnit."""
return ( return isinstance(script_arg, (ControlNetUnit, dict)) or (
isinstance(script_arg, (ControlNetUnit, dict)) or hasattr(script_arg, "__dict__")
( and set(vars(ControlNetUnit()).keys()).issubset(
hasattr(script_arg, '__dict__') and set(vars(script_arg).keys())
set(vars(ControlNetUnit()).keys()).issubset(
set(vars(script_arg).keys()))
) )
) )
@ -334,7 +341,8 @@ def get_all_units_from(script_args: List[Any]) -> List[ControlNetUnit]:
if not all_units: if not all_units:
logger.warning( logger.warning(
"No ControlNetUnit detected in args. It is very likely that you are having an extension conflict." "No ControlNetUnit detected in args. It is very likely that you are having an extension conflict."
f"Here are args received by ControlNet: {script_args}.") f"Here are args received by ControlNet: {script_args}."
)
if any(is_stale_unit(script_arg) for script_arg in script_args): if any(is_stale_unit(script_arg) for script_arg in script_args):
logger.debug( logger.debug(
"Stale version of ControlNetUnit detected. The ControlNetUnit received" "Stale version of ControlNetUnit detected. The ControlNetUnit received"
@ -346,7 +354,9 @@ def get_all_units_from(script_args: List[Any]) -> List[ControlNetUnit]:
return all_units return all_units
def get_single_unit_from(script_args: List[Any], index: int = 0) -> Optional[ControlNetUnit]: def get_single_unit_from(
script_args: List[Any], index: int = 0
) -> Optional[ControlNetUnit]:
""" """
Fetch a single ControlNet processing unit from ControlNet script arguments. Fetch a single ControlNet processing unit from ControlNet script arguments.
The list must not contain script positional arguments. It must only contain processing units. The list must not contain script positional arguments. It must only contain processing units.
@ -379,10 +389,10 @@ def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNe
""" """
ext_compat_keys = { ext_compat_keys = {
'guessmode': 'guess_mode', "guessmode": "guess_mode",
'guidance': 'guidance_end', "guidance": "guidance_end",
'lowvram': 'low_vram', "lowvram": "low_vram",
'input_image': 'image' "input_image": "image",
} }
if isinstance(unit, dict): if isinstance(unit, dict):
@ -390,20 +400,24 @@ def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNe
# Handle mask # Handle mask
mask = None mask = None
if 'mask' in unit: if "mask" in unit:
mask = unit['mask'] mask = unit["mask"]
del unit['mask'] del unit["mask"]
if "mask_image" in unit: if "mask_image" in unit:
mask = unit["mask_image"] mask = unit["mask_image"]
del unit["mask_image"] del unit["mask_image"]
if 'image' in unit and not isinstance(unit['image'], dict): if "image" in unit and not isinstance(unit["image"], dict):
unit['image'] = {'image': unit['image'], 'mask': mask} if mask is not None else unit['image'] if unit[ unit["image"] = (
'image'] else None {"image": unit["image"], "mask": mask}
if mask is not None
else unit["image"] if unit["image"] else None
)
# Parse ipadapter_input # Parse ipadapter_input
if "ipadapter_input" in unit: if "ipadapter_input" in unit:
def decode_base64(b: str) -> torch.Tensor: def decode_base64(b: str) -> torch.Tensor:
decoded_bytes = base64.b64decode(b) decoded_bytes = base64.b64decode(b)
return unsafe_torch_load(io.BytesIO(decoded_bytes)) return unsafe_torch_load(io.BytesIO(decoded_bytes))
@ -411,12 +425,18 @@ def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNe
if isinstance(unit["ipadapter_input"], str): if isinstance(unit["ipadapter_input"], str):
unit["ipadapter_input"] = [unit["ipadapter_input"]] unit["ipadapter_input"] = [unit["ipadapter_input"]]
unit["ipadapter_input"] = [decode_base64(b) for b in unit["ipadapter_input"]] unit["ipadapter_input"] = [
decode_base64(b) for b in unit["ipadapter_input"]
]
if 'guess_mode' in unit: if "guess_mode" in unit:
logger.warning('Guess Mode is removed since 1.1.136. Please use Control Mode instead.') logger.warning(
"Guess Mode is removed since 1.1.136. Please use Control Mode instead."
)
unit = ControlNetUnit(**{k: v for k, v in unit.items() if k in vars(ControlNetUnit).keys()}) unit = ControlNetUnit(
**{k: v for k, v in unit.items() if k in vars(ControlNetUnit).keys()}
)
# temporary, check #602 # temporary, check #602
# assert isinstance(unit, ControlNetUnit), f'bad argument to controlnet extension: {unit}\nexpected Union[dict[str, Any], ControlNetUnit]' # assert isinstance(unit, ControlNetUnit), f'bad argument to controlnet extension: {unit}\nexpected Union[dict[str, Any], ControlNetUnit]'
@ -464,13 +484,17 @@ def update_cn_script(
# fill in remaining parameters to satisfy max models, just in case script needs it. # fill in remaining parameters to satisfy max models, just in case script needs it.
max_models = shared.opts.data.get("control_net_unit_count", 3) max_models = shared.opts.data.get("control_net_unit_count", 3)
cn_units = cn_units + [ControlNetUnit(enabled=False)] * max(max_models - len(cn_units), 0) cn_units = cn_units + [ControlNetUnit(enabled=False)] * max(
max_models - len(cn_units), 0
)
cn_script_args_diff = 0 cn_script_args_diff = 0
for script in script_runner.alwayson_scripts: for script in script_runner.alwayson_scripts:
if script is cn_script: if script is cn_script:
cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from) cn_script_args_diff = len(cn_units) - (
updated_script_args[script.args_from:script.args_to] = cn_units cn_script.args_to - cn_script.args_from
)
updated_script_args[script.args_from : script.args_to] = cn_units
script.args_to = script.args_from + len(cn_units) script.args_to = script.args_from + len(cn_units)
else: else:
script.args_from += cn_script_args_diff script.args_from += cn_script_args_diff
@ -503,13 +527,17 @@ def update_cn_script_in_place(
# fill in remaining parameters to satisfy max models, just in case script needs it. # fill in remaining parameters to satisfy max models, just in case script needs it.
max_models = shared.opts.data.get("control_net_unit_count", 3) max_models = shared.opts.data.get("control_net_unit_count", 3)
cn_units = cn_units + [ControlNetUnit(enabled=False)] * max(max_models - len(cn_units), 0) cn_units = cn_units + [ControlNetUnit(enabled=False)] * max(
max_models - len(cn_units), 0
)
cn_script_args_diff = 0 cn_script_args_diff = 0
for script in script_runner.alwayson_scripts: for script in script_runner.alwayson_scripts:
if script is cn_script: if script is cn_script:
cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from) cn_script_args_diff = len(cn_units) - (
script_args[script.args_from:script.args_to] = cn_units cn_script.args_to - cn_script.args_from
)
script_args[script.args_from : script.args_to] = cn_units
script.args_to = script.args_from + len(cn_units) script.args_to = script.args_from + len(cn_units)
else: else:
script.args_from += cn_script_args_diff script.args_from += cn_script_args_diff
@ -539,13 +567,10 @@ def get_modules(alias_names: bool = False) -> List[str]:
Keyword arguments: Keyword arguments:
alias_names -- Whether to get the ui alias names instead of internal keys alias_names -- Whether to get the ui alias names instead of internal keys
""" """
return [
modules = list(global_state.cn_preprocessor_modules.keys()) (p.label if alias_names else p.name)
for p in Preprocessor.get_sorted_preprocessors()
if alias_names: ]
modules = [global_state.preprocessor_aliases.get(module, module) for module in modules]
return modules
def get_modules_detail(alias_names: bool = False) -> Dict[str, Any]: def get_modules_detail(alias_names: bool = False) -> Dict[str, Any]:
@ -562,17 +587,22 @@ def get_modules_detail(alias_names: bool = False) -> Dict[str, Any]:
_module_list_alias = get_modules(True) _module_list_alias = get_modules(True)
_output_list = _module_list if not alias_names else _module_list_alias _output_list = _module_list if not alias_names else _module_list_alias
for index, module in enumerate(_output_list): for module_name in _output_list:
if _module_list[index] in preprocessor_sliders_config: preprocessor = Preprocessor.get_preprocessor(module_name)
_module_detail[module] = { assert preprocessor is not None
"model_free": module in model_free_preprocessors, _module_detail[module_name] = dict(
"sliders": preprocessor_sliders_config[_module_list[index]] model_free=preprocessor.do_not_need_model,
} sliders=[
else: s.api_json
_module_detail[module] = { for s in (
"model_free": False, preprocessor.slider_resolution,
"sliders": [] preprocessor.slider_1,
} preprocessor.slider_2,
preprocessor.slider_3,
)
if s.visible
],
)
return _module_detail return _module_detail
@ -595,4 +625,4 @@ def is_cn_script(script: scripts.Script) -> bool:
Determine whether `script` is a ControlNet script. Determine whether `script` is a ControlNet script.
""" """
return script.title().lower() == 'controlnet' return script.title().lower() == "controlnet"

View File

@ -8,8 +8,9 @@ exclude = [
"web_tests", "web_tests",
"example", "example",
"extract_controlnet_diff.py", "extract_controlnet_diff.py",
"scripts/global_state.py",
"scripts/movie2movie.py", "scripts/movie2movie.py",
"scripts/preprocessor/legacy/preprocessor_compiled.py",
"scripts/preprocessor/__init__.py",
] ]
ignore = [ ignore = [

View File

@ -15,9 +15,9 @@ from modules.api.models import * # noqa:F403
from modules.api import api from modules.api import api
from scripts import external_code, global_state from scripts import external_code, global_state
from scripts.processor import preprocessor_filters
from scripts.logging import logger from scripts.logging import logger
from scripts.external_code import ControlNetUnit from scripts.external_code import ControlNetUnit
from scripts.supported_preprocessor import Preprocessor
from annotator.openpose import draw_poses, decode_json_as_poses from annotator.openpose import draw_poses, decode_json_as_poses
from annotator.openpose.animalpose import draw_animalposes from annotator.openpose.animalpose import draw_animalposes
@ -87,7 +87,7 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
control_type: format_control_type( control_type: format_control_type(
*global_state.select_control_type(control_type) *global_state.select_control_type(control_type)
) )
for control_type in preprocessor_filters.keys() for control_type in Preprocessor.get_all_preprocessor_tags()
} }
} }
@ -96,10 +96,6 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
max_models_num = external_code.get_max_models_num() max_models_num = external_code.get_max_models_num()
return {"control_net_unit_count": max_models_num} return {"control_net_unit_count": max_models_num}
cached_cn_preprocessors = global_state.cache_preprocessors(
global_state.cn_preprocessor_modules
)
@app.post("/controlnet/detect") @app.post("/controlnet/detect")
async def detect( async def detect(
controlnet_module: str = Body("none", title="Controlnet Module"), controlnet_module: str = Body("none", title="Controlnet Module"),
@ -111,14 +107,17 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
controlnet_threshold_b: float = Body(-1, title="Controlnet Threshold b"), controlnet_threshold_b: float = Body(-1, title="Controlnet Threshold b"),
low_vram: bool = Body(False, title="Low vram"), low_vram: bool = Body(False, title="Low vram"),
): ):
controlnet_module = global_state.reverse_preprocessor_aliases.get( preprocessor = Preprocessor.get_preprocessor(controlnet_module)
controlnet_module, controlnet_module
)
if controlnet_module not in cached_cn_preprocessors: if preprocessor is None:
raise HTTPException(status_code=422, detail="Module not available") raise HTTPException(status_code=422, detail="Module not available")
if controlnet_module in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"): if controlnet_module in (
"clip_vision",
"revision_clipvision",
"revision_ignore_prompt",
"ip-adapter-auto",
):
raise HTTPException(status_code=422, detail="Module not supported") raise HTTPException(status_code=422, detail="Module not supported")
if len(controlnet_input_images) == 0: if len(controlnet_input_images) == 0:
@ -129,7 +128,7 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
) )
unit = ControlNetUnit( unit = ControlNetUnit(
module=controlnet_module, module=preprocessor.label,
processor_res=controlnet_processor_res, processor_res=controlnet_processor_res,
threshold_a=controlnet_threshold_a, threshold_a=controlnet_threshold_a,
threshold_b=controlnet_threshold_b, threshold_b=controlnet_threshold_b,
@ -139,8 +138,6 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
results = [] results = []
poses = [] poses = []
processor_module = cached_cn_preprocessors[controlnet_module]
for input_image in controlnet_input_images: for input_image in controlnet_input_images:
img = external_code.to_base64_nparray(input_image) img = external_code.to_base64_nparray(input_image)
@ -152,11 +149,11 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
self.value = json_dict self.value = json_dict
json_acceptor = JsonAcceptor() json_acceptor = JsonAcceptor()
detected_map, is_image = processor_module( detected_map = preprocessor.cached_call(
img, img,
res=unit.processor_res, resolution=unit.processor_res,
thr_a=unit.threshold_a, slider_1=unit.threshold_a,
thr_b=unit.threshold_b, slider_2=unit.threshold_b,
json_pose_callback=json_acceptor.accept, json_pose_callback=json_acceptor.accept,
low_vram=low_vram, low_vram=low_vram,
) )
@ -166,9 +163,8 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
assert json_acceptor.value is not None assert json_acceptor.value is not None
poses.append(json_acceptor.value) poses.append(json_acceptor.value)
global_state.cn_preprocessor_unloadable.get(controlnet_module, lambda: None)()
res = {"info": "Success"} res = {"info": "Success"}
if is_image: if preprocessor.returns_image:
res["images"] = [encode_to_base64(r) for r in results] res["images"] = [encode_to_base64(r) for r in results]
if poses: if poses:
res["poses"] = poses res["poses"] = poses
@ -176,7 +172,6 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
res["tensor"] = [encode_tensor_to_base64(r) for r in results] res["tensor"] = [encode_tensor_to_base64(r) for r in results]
return res return res
class Person(BaseModel): class Person(BaseModel):
pose_keypoints_2d: List[float] pose_keypoints_2d: List[float]
hand_right_keypoints_2d: Optional[List[float]] hand_right_keypoints_2d: Optional[List[float]]

View File

@ -12,18 +12,21 @@ import gradio as gr
import time import time
from einops import rearrange from einops import rearrange
# Register all preprocessors.
import scripts.preprocessor as preprocessor_init # noqa
from annotator.util import HWC3
from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils
from scripts.controlnet_lora import bind_control_lora, unbind_control_lora from scripts.controlnet_lora import bind_control_lora, unbind_control_lora
from scripts.processor import HWC3
from scripts.controlnet_lllite import clear_all_lllite from scripts.controlnet_lllite import clear_all_lllite
from scripts.ipadapter.plugable_ipadapter import ImageEmbed, clear_all_ip_adapter from scripts.ipadapter.plugable_ipadapter import ImageEmbed, clear_all_ip_adapter
from scripts.ipadapter.presets import IPAdapterPreset
from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent
from scripts.hook import ControlParams, UnetHook, HackedImageRNG from scripts.hook import ControlParams, UnetHook, HackedImageRNG
from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from scripts.controlnet_ui.photopea import Photopea from scripts.controlnet_ui.photopea import Photopea
from scripts.logging import logger from scripts.logging import logger
from scripts.supported_preprocessor import Preprocessor
from scripts.animate_diff.batch import add_animate_diff_batch_input from scripts.animate_diff.batch import add_animate_diff_batch_input
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, StableDiffusionProcessing from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, StableDiffusionProcessing
from modules.images import save_image from modules.images import save_image
@ -35,7 +38,6 @@ import torch
from PIL import Image, ImageFilter, ImageOps from PIL import Image, ImageFilter, ImageOps
from scripts.lvminthin import lvmin_thin, nake_nms from scripts.lvminthin import lvmin_thin, nake_nms
from scripts.processor import model_free_preprocessors
from scripts.controlnet_model_guess import build_model_by_guess, ControlModel from scripts.controlnet_model_guess import build_model_by_guess, ControlModel
from scripts.hook import torch_dfs from scripts.hook import torch_dfs
@ -220,7 +222,7 @@ def get_control(
unit: external_code.ControlNetUnit, unit: external_code.ControlNetUnit,
idx: int, idx: int,
control_model_type: ControlModelType, control_model_type: ControlModelType,
preprocessor, preprocessor: Preprocessor,
): ):
"""Get input for a ControlNet unit.""" """Get input for a ControlNet unit."""
if unit.is_animate_diff_batch: if unit.is_animate_diff_batch:
@ -264,16 +266,18 @@ def get_control(
def preprocess_input_image(input_image: np.ndarray): def preprocess_input_image(input_image: np.ndarray):
""" Preprocess single input image. """ """ Preprocess single input image. """
detected_map, is_image = preprocessor( detected_map = preprocessor.cached_call(
input_image, input_image,
res=unit.processor_res, resolution=unit.processor_res,
thr_a=unit.threshold_a, slider_1=unit.threshold_a,
thr_b=unit.threshold_b, slider_2=unit.threshold_b,
low_vram=( low_vram=(
("clip" in unit.module or unit.module == "ip-adapter_face_id_plus") and ("clip" in unit.module or unit.module == "ip-adapter_face_id_plus") and
shared.opts.data.get("controlnet_clip_detector_on_cpu", False) shared.opts.data.get("controlnet_clip_detector_on_cpu", False)
), ),
model=unit.model,
) )
is_image = preprocessor.returns_image
if high_res_fix: if high_res_fix:
if is_image: if is_image:
hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x) hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x)
@ -320,8 +324,6 @@ class Script(scripts.Script, metaclass=(
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.latest_network = None self.latest_network = None
self.preprocessor = global_state.cache_preprocessors(global_state.cn_preprocessor_modules)
self.unloadable = global_state.cn_preprocessor_unloadable
self.input_image = None self.input_image = None
self.latest_model_hash = "" self.latest_model_hash = ""
self.enabled_units: List[external_code.ControlNetUnit] = [] self.enabled_units: List[external_code.ControlNetUnit] = []
@ -353,7 +355,6 @@ class Script(scripts.Script, metaclass=(
group = ControlNetUiGroup( group = ControlNetUiGroup(
is_img2img, is_img2img,
Script.get_default_ui_unit(), Script.get_default_ui_unit(),
self.preprocessor,
photopea, photopea,
) )
return group, group.render(tabname, elem_id_tabname) return group, group.render(tabname, elem_id_tabname)
@ -664,11 +665,6 @@ class Script(scripts.Script, metaclass=(
if not local_unit.enabled: if not local_unit.enabled:
continue continue
# Consolidate meta preprocessors.
if local_unit.module == "ip-adapter-auto":
local_unit.module = IPAdapterPreset.match_model(local_unit.model).module
logger.info(f"ip-adapter-auto => {local_unit.module}")
if hasattr(local_unit, "unfold_merged"): if hasattr(local_unit, "unfold_merged"):
enabled_units.extend(local_unit.unfold_merged()) enabled_units.extend(local_unit.unfold_merged())
else: else:
@ -938,15 +934,6 @@ class Script(scripts.Script, metaclass=(
if self.latest_model_hash != p.sd_model.sd_model_hash: if self.latest_model_hash != p.sd_model.sd_model_hash:
Script.clear_control_model_cache() Script.clear_control_model_cache()
for idx, unit in enumerate(self.enabled_units):
unit.module = global_state.get_module_basename(unit.module)
# unload unused preproc
module_list = [unit.module for unit in self.enabled_units]
for key in self.unloadable:
if key not in module_list:
self.unloadable.get(key, lambda:None)()
self.latest_model_hash = p.sd_model.sd_model_hash self.latest_model_hash = p.sd_model.sd_model_hash
high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False) high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)
@ -961,7 +948,10 @@ class Script(scripts.Script, metaclass=(
logger.warning('A1111 inpaint and ControlNet inpaint duplicated. Falls back to inpaint_global_harmonious.') logger.warning('A1111 inpaint and ControlNet inpaint duplicated. Falls back to inpaint_global_harmonious.')
unit.module = 'inpaint' unit.module = 'inpaint'
if unit.module in model_free_preprocessors: preprocessor = Preprocessor.get_preprocessor(unit.module)
assert preprocessor is not None
if preprocessor.do_not_need_model:
model_net = None model_net = None
if 'reference' in unit.module: if 'reference' in unit.module:
control_model_type = ControlModelType.AttentionInjection control_model_type = ControlModelType.AttentionInjection
@ -990,7 +980,7 @@ class Script(scripts.Script, metaclass=(
hr_controls = unit.ipadapter_input hr_controls = unit.ipadapter_input
else: else:
controls, hr_controls, additional_maps = get_control( controls, hr_controls, additional_maps = get_control(
p, unit, idx, control_model_type, self.preprocessor[unit.module]) p, unit, idx, control_model_type, preprocessor)
detected_maps.extend(additional_maps) detected_maps.extend(additional_maps)
if len(controls) == len(hr_controls) == 1 and control_model_type not in [ControlModelType.SparseCtrl]: if len(controls) == len(hr_controls) == 1 and control_model_type not in [ControlModelType.SparseCtrl]:

View File

@ -2,23 +2,17 @@ import json
import gradio as gr import gradio as gr
import functools import functools
from copy import copy from copy import copy
from typing import List, Optional, Union, Callable, Dict, Tuple, Literal from typing import List, Optional, Union, Dict, Tuple, Literal
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
from scripts.supported_preprocessor import Preprocessor
from scripts.utils import svg_preprocess, read_image from scripts.utils import svg_preprocess, read_image
from scripts import ( from scripts import (
global_state, global_state,
external_code, external_code,
) )
from scripts.processor import ( from annotator.util import HWC3
preprocessor_sliders_config,
no_control_mode_preprocessors,
flag_preprocessor_resolution,
model_free_preprocessors,
preprocessor_filters,
HWC3,
)
from scripts.logging import logger from scripts.logging import logger
from scripts.controlnet_ui.openpose_editor import OpenposeEditor from scripts.controlnet_ui.openpose_editor import OpenposeEditor
from scripts.controlnet_ui.preset import ControlNetPresetUI from scripts.controlnet_ui.preset import ControlNetPresetUI
@ -227,7 +221,6 @@ class ControlNetUiGroup(object):
self, self,
is_img2img: bool, is_img2img: bool,
default_unit: external_code.ControlNetUnit, default_unit: external_code.ControlNetUnit,
preprocessors: List[Callable],
photopea: Optional[Photopea], photopea: Optional[Photopea],
): ):
# Whether callbacks have been registered. # Whether callbacks have been registered.
@ -237,7 +230,6 @@ class ControlNetUiGroup(object):
self.is_img2img = is_img2img self.is_img2img = is_img2img
self.default_unit = default_unit self.default_unit = default_unit
self.preprocessors = preprocessors
self.photopea = photopea self.photopea = photopea
self.webcam_enabled = False self.webcam_enabled = False
self.webcam_mirrored = False self.webcam_mirrored = False
@ -300,10 +292,6 @@ class ControlNetUiGroup(object):
self.batch_image_dir_state = None self.batch_image_dir_state = None
self.output_dir_state = None self.output_dir_state = None
# Internal states for UI state pasting.
self.prevent_next_n_module_update = 0
self.prevent_next_n_slider_value_update = 0
# API-only fields # API-only fields
self.advanced_weighting = gr.State(None) self.advanced_weighting = gr.State(None)
self.ipadapter_input = gr.State(None) self.ipadapter_input = gr.State(None)
@ -526,7 +514,7 @@ class ControlNetUiGroup(object):
with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]): with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
self.type_filter = gr.Radio( self.type_filter = gr.Radio(
list(preprocessor_filters.keys()), Preprocessor.get_all_preprocessor_tags(),
label="Control Type", label="Control Type",
value="All", value="All",
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_type_filter_radio", elem_id=f"{elem_id_tabname}_{tabname}_controlnet_type_filter_radio",
@ -535,7 +523,7 @@ class ControlNetUiGroup(object):
with gr.Row(elem_classes=["controlnet_preprocessor_model", "controlnet_row"]): with gr.Row(elem_classes=["controlnet_preprocessor_model", "controlnet_row"]):
self.module = gr.Dropdown( self.module = gr.Dropdown(
global_state.ui_preprocessor_keys, [p.label for p in Preprocessor.get_sorted_preprocessors()],
label="Preprocessor", label="Preprocessor",
value=self.default_unit.module, value=self.default_unit.module,
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_dropdown", elem_id=f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_dropdown",
@ -798,82 +786,21 @@ class ControlNetUiGroup(object):
def register_build_sliders(self): def register_build_sliders(self):
def build_sliders(module: str, pp: bool): def build_sliders(module: str, pp: bool):
logger.debug( preprocessor = Preprocessor.get_preprocessor(module)
f"Prevent update slider value: {self.prevent_next_n_slider_value_update}" slider_resolution_kwargs = preprocessor.slider_resolution.gradio_update_kwargs.copy()
)
logger.debug(f"Build slider for module: {module} - {pp}")
# Clear old slider values so that they do not cause confusion in if pp:
# infotext. slider_resolution_kwargs['visible'] = False
clear_slider_update = gr.update(
visible=False,
interactive=True,
minimum=-1,
maximum=-1,
value=-1,
)
grs = [] grs = [
module = global_state.get_module_basename(module) gr.update(**slider_resolution_kwargs),
if module not in preprocessor_sliders_config: gr.update(**preprocessor.slider_1.gradio_update_kwargs.copy()),
default_res_slider_config = dict( gr.update(**preprocessor.slider_2.gradio_update_kwargs.copy()),
label=flag_preprocessor_resolution,
minimum=64,
maximum=2048,
step=1,
)
if self.prevent_next_n_slider_value_update == 0:
default_res_slider_config["value"] = 512
grs += [
gr.update(
**default_res_slider_config,
visible=not pp,
interactive=True,
),
copy(clear_slider_update),
copy(clear_slider_update),
gr.update(visible=True), gr.update(visible=True),
gr.update(visible=not preprocessor.do_not_need_model),
gr.update(visible=not preprocessor.do_not_need_model),
gr.update(visible=preprocessor.show_control_mode),
] ]
else:
for slider_config in preprocessor_sliders_config[module]:
if isinstance(slider_config, dict):
visible = True
if slider_config["name"] == flag_preprocessor_resolution:
visible = not pp
slider_update = gr.update(
label=slider_config["name"],
minimum=slider_config["min"],
maximum=slider_config["max"],
step=slider_config["step"]
if "step" in slider_config
else 1,
visible=visible,
interactive=True,
)
if self.prevent_next_n_slider_value_update == 0:
slider_update["value"] = slider_config["value"]
grs.append(slider_update)
else:
grs.append(copy(clear_slider_update))
while len(grs) < 3:
grs.append(copy(clear_slider_update))
grs.append(gr.update(visible=True))
if module in model_free_preprocessors:
grs += [
gr.update(visible=False, value="None"),
gr.update(visible=False),
]
else:
grs += [gr.update(visible=True), gr.update(visible=True)]
self.prevent_next_n_slider_value_update = max(
0, self.prevent_next_n_slider_value_update - 1
)
grs += [gr.update(visible=module not in no_control_mode_preprocessors)]
return grs return grs
@ -898,7 +825,6 @@ class ControlNetUiGroup(object):
) )
def filter_selected(k: str): def filter_selected(k: str):
logger.debug(f"Prevent update {self.prevent_next_n_module_update}")
logger.debug(f"Switch to control type {k}") logger.debug(f"Switch to control type {k}")
( (
filtered_preprocessor_list, filtered_preprocessor_list,
@ -906,14 +832,6 @@ class ControlNetUiGroup(object):
default_option, default_option,
default_model, default_model,
) = global_state.select_control_type(k, global_state.get_sd_version()) ) = global_state.select_control_type(k, global_state.get_sd_version())
if self.prevent_next_n_module_update > 0:
self.prevent_next_n_module_update -= 1
return [
gr.Dropdown.update(choices=filtered_preprocessor_list),
gr.Dropdown.update(choices=filtered_model_list),
]
else:
return [ return [
gr.Dropdown.update( gr.Dropdown.update(
value=default_option, choices=filtered_preprocessor_list value=default_option, choices=filtered_preprocessor_list
@ -959,7 +877,7 @@ class ControlNetUiGroup(object):
) )
def register_run_annotator(self): def register_run_annotator(self):
def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm): def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm, model: str):
if image is None: if image is None:
return ( return (
gr.update(value=None, visible=True), gr.update(value=None, visible=True),
@ -981,8 +899,7 @@ class ControlNetUiGroup(object):
): ):
img = HWC3(image["mask"][:, :, 0]) img = HWC3(image["mask"][:, :, 0])
module = global_state.get_module_basename(module) preprocessor = Preprocessor.get_preprocessor(module)
preprocessor = self.preprocessors[module]
if pp: if pp:
pres = external_code.pixel_perfect_resolution( pres = external_code.pixel_perfect_resolution(
@ -1013,23 +930,25 @@ class ControlNetUiGroup(object):
# effect. # effect.
# TODO: Maybe we should let `preprocessor` return a Dict to alleviate this issue? # TODO: Maybe we should let `preprocessor` return a Dict to alleviate this issue?
# This requires changing all callsites though. # This requires changing all callsites though.
result, is_image = preprocessor( result = preprocessor.cached_call(
img, img,
res=pres, resolution=pres,
thr_a=pthr_a, slider_1=pthr_a,
thr_b=pthr_b, slider_2=pthr_b,
low_vram=( low_vram=(
("clip" in module or module == "ip-adapter_face_id_plus") ("clip" in module or module == "ip-adapter_face_id_plus")
and shared.opts.data.get("controlnet_clip_detector_on_cpu", False) and shared.opts.data.get("controlnet_clip_detector_on_cpu", False)
), ),
json_pose_callback=json_acceptor.accept json_pose_callback=(
json_acceptor.accept
if is_openpose(module) if is_openpose(module)
else None, else None
),
model=model,
) )
if not is_image: if not preprocessor.returns_image:
result = img result = img
is_image = True
result = external_code.visualize_inpaint_mask(result) result = external_code.visualize_inpaint_mask(result)
return ( return (
@ -1057,6 +976,7 @@ class ControlNetUiGroup(object):
else ControlNetUiGroup.a1111_context.txt2img_h_slider, else ControlNetUiGroup.a1111_context.txt2img_h_slider,
self.pixel_perfect, self.pixel_perfect,
self.resize_mode, self.resize_mode,
self.model,
], ],
outputs=[ outputs=[
self.generated_image, self.generated_image,

View File

@ -7,8 +7,8 @@ from modules import scripts
from scripts.infotext import parse_unit, serialize_unit from scripts.infotext import parse_unit, serialize_unit
from scripts.controlnet_ui.tool_button import ToolButton from scripts.controlnet_ui.tool_button import ToolButton
from scripts.logging import logger from scripts.logging import logger
from scripts.processor import preprocessor_filters
from scripts import external_code from scripts import external_code
from scripts.supported_preprocessor import Preprocessor
save_symbol = "\U0001f4be" # 💾 save_symbol = "\U0001f4be" # 💾
delete_symbol = "\U0001f5d1\ufe0f" # 🗑️ delete_symbol = "\U0001f5d1\ufe0f" # 🗑️
@ -38,7 +38,7 @@ def infer_control_type(module: str, model: str) -> str:
def matches_control_type(input_string: str, control_type: str) -> bool: def matches_control_type(input_string: str, control_type: str) -> bool:
return any(t.lower() in input_string for t in control_type.split("/")) return any(t.lower() in input_string for t in control_type.split("/"))
control_types = preprocessor_filters.keys() control_types = Preprocessor.get_all_preprocessor_tags()
control_type_candidates = [ control_type_candidates = [
control_type control_type
for control_type in control_types for control_type in control_types

View File

@ -1,17 +1,14 @@
import os.path import os.path
import stat import stat
import functools
from collections import OrderedDict from collections import OrderedDict
from modules import shared, scripts, sd_models from modules import shared, scripts, sd_models
from modules.paths import models_path from modules.paths import models_path
from scripts.processor import * # noqa: E403
import scripts.processor as processor
from scripts.utils import ndarray_lru_cache
from scripts.logging import logger
from scripts.enums import StableDiffusionVersion
from typing import Dict, Callable, Optional, Tuple, List from scripts.enums import StableDiffusionVersion
from scripts.supported_preprocessor import Preprocessor
from typing import Dict, Tuple, List
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"] CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]
cn_models_dir = os.path.join(models_path, "ControlNet") cn_models_dir = os.path.join(models_path, "ControlNet")
@ -19,187 +16,6 @@ cn_models_dir_old = os.path.join(scripts.basedir(), "models")
cn_models = OrderedDict() # "My_Lora(abcd1234)" -> C:/path/to/model.safetensors cn_models = OrderedDict() # "My_Lora(abcd1234)" -> C:/path/to/model.safetensors
cn_models_names = {} # "my_lora" -> "My_Lora(abcd1234)" cn_models_names = {} # "my_lora" -> "My_Lora(abcd1234)"
def cache_preprocessors(preprocessor_modules: Dict[str, Callable]) -> Dict[str, Callable]:
""" We want to share the preprocessor results in a single big cache, instead of a small
cache for each preprocessor function. """
CACHE_SIZE = getattr(shared.cmd_opts, "controlnet_preprocessor_cache_size", 0)
# Set CACHE_SIZE = 0 will completely remove the caching layer. This can be
# helpful when debugging preprocessor code.
if CACHE_SIZE == 0:
return preprocessor_modules
logger.debug(f'Create LRU cache (max_size={CACHE_SIZE}) for preprocessor results.')
@ndarray_lru_cache(max_size=CACHE_SIZE)
def unified_preprocessor(preprocessor_name: str, *args, **kwargs):
logger.debug(f'Calling preprocessor {preprocessor_name} outside of cache.')
return preprocessor_modules[preprocessor_name](*args, **kwargs)
# TODO: Introduce a seed parameter for shuffle preprocessor?
uncacheable_preprocessors = ['shuffle']
return {
k: (
v if k in uncacheable_preprocessors
else functools.partial(unified_preprocessor, k)
)
for k, v
in preprocessor_modules.items()
}
cn_preprocessor_modules = {
"none": lambda x, *args, **kwargs: (x, True),
"canny": canny,
"depth": midas,
"depth_leres": functools.partial(leres, boost=False),
"depth_leres++": functools.partial(leres, boost=True),
"depth_hand_refiner": g_hand_refiner_model.run_model,
"depth_anything": functools.partial(depth_anything, colored=False),
"hed": hed,
"hed_safe": hed_safe,
"mediapipe_face": mediapipe_face,
"mlsd": mlsd,
"normal_map": midas_normal,
"openpose": functools.partial(g_openpose_model.run_model, include_body=True, include_hand=False, include_face=False),
"openpose_hand": functools.partial(g_openpose_model.run_model, include_body=True, include_hand=True, include_face=False),
"openpose_face": functools.partial(g_openpose_model.run_model, include_body=True, include_hand=False, include_face=True),
"openpose_faceonly": functools.partial(g_openpose_model.run_model, include_body=False, include_hand=False, include_face=True),
"openpose_full": functools.partial(g_openpose_model.run_model, include_body=True, include_hand=True, include_face=True),
"dw_openpose_full": functools.partial(g_openpose_model.run_model, include_body=True, include_hand=True, include_face=True, use_dw_pose=True),
"animal_openpose": functools.partial(g_openpose_model.run_model, include_body=True, include_hand=False, include_face=False, use_animal_pose=True),
"clip_vision": functools.partial(clip, config='clip_vitl'),
"revision_clipvision": functools.partial(clip, config='clip_g'),
"revision_ignore_prompt": functools.partial(clip, config='clip_g'),
"ip-adapter_clip_sd15": functools.partial(clip, config='clip_h'),
"ip-adapter_clip_sdxl_plus_vith": functools.partial(clip, config='clip_h'),
"ip-adapter_clip_sdxl": functools.partial(clip, config='clip_g'),
"ip-adapter_face_id": g_insight_face_model.run_model,
"ip-adapter_face_id_plus": face_id_plus,
"instant_id_face_keypoints": functools.partial(g_insight_face_instant_id_model.run_model_instant_id, return_keypoints=True),
"instant_id_face_embedding": functools.partial(g_insight_face_instant_id_model.run_model_instant_id, return_keypoints=False),
"color": color,
"pidinet": pidinet,
"pidinet_safe": pidinet_safe,
"pidinet_sketch": pidinet_ts,
"pidinet_scribble": scribble_pidinet,
"scribble_xdog": scribble_xdog,
"scribble_hed": scribble_hed,
"segmentation": uniformer,
"threshold": threshold,
"depth_zoe": zoe_depth,
"normal_bae": normal_bae,
"oneformer_coco": oneformer_coco,
"oneformer_ade20k": oneformer_ade20k,
"lineart": lineart,
"lineart_coarse": lineart_coarse,
"lineart_anime": lineart_anime,
"lineart_standard": lineart_standard,
"shuffle": shuffle,
"tile_resample": tile_resample,
"invert": invert,
"lineart_anime_denoise": lineart_anime_denoise,
"reference_only": identity,
"reference_adain": identity,
"reference_adain+attn": identity,
"inpaint": identity,
"inpaint_only": identity,
"inpaint_only+lama": lama_inpaint,
"tile_colorfix": identity,
"tile_colorfix+sharp": identity,
"recolor_luminance": recolor_luminance,
"recolor_intensity": recolor_intensity,
"blur_gaussian": blur_gaussian,
"anime_face_segment": anime_face_segment,
"densepose": functools.partial(densepose, cmap="viridis"),
"densepose_parula": functools.partial(densepose, cmap="parula"),
"te_hed":te_hed,
"normal_dsine": normal_dsine,
}
cn_preprocessor_unloadable = {
"hed": unload_hed,
"fake_scribble": unload_hed,
"mlsd": unload_mlsd,
"clip_vision": functools.partial(unload_clip, config='clip_vitl'),
"revision_clipvision": functools.partial(unload_clip, config='clip_g'),
"revision_ignore_prompt": functools.partial(unload_clip, config='clip_g'),
"ip-adapter_clip_sd15": functools.partial(unload_clip, config='clip_h'),
"ip-adapter_clip_sdxl_plus_vith": functools.partial(unload_clip, config='clip_h'),
"ip-adapter_face_id_plus": functools.partial(unload_clip, config='clip_h'),
"ip-adapter_clip_sdxl": functools.partial(unload_clip, config='clip_g'),
"depth": unload_midas,
"depth_leres": unload_leres,
"depth_anything": unload_depth_anything,
"normal_map": unload_midas,
"pidinet": unload_pidinet,
"openpose": g_openpose_model.unload,
"openpose_hand": g_openpose_model.unload,
"openpose_face": g_openpose_model.unload,
"openpose_full": g_openpose_model.unload,
"dw_openpose_full": g_openpose_model.unload,
"animal_openpose": g_openpose_model.unload,
"segmentation": unload_uniformer,
"depth_zoe": unload_zoe_depth,
"normal_bae": unload_normal_bae,
"oneformer_coco": unload_oneformer_coco,
"oneformer_ade20k": unload_oneformer_ade20k,
"lineart": unload_lineart,
"lineart_coarse": unload_lineart_coarse,
"lineart_anime": unload_lineart_anime,
"lineart_anime_denoise": unload_lineart_anime_denoise,
"inpaint_only+lama": unload_lama_inpaint,
"anime_face_segment": unload_anime_face_segment,
"densepose": unload_densepose,
"densepose_parula": unload_densepose,
"depth_hand_refiner": g_hand_refiner_model.unload,
"te_hed":unload_te_hed,
"normal_dsine": unload_normal_dsine,
}
preprocessor_aliases = {
"invert": "invert (from white bg & black line)",
"lineart_standard": "lineart_standard (from white bg & black line)",
"lineart": "lineart_realistic",
"color": "t2ia_color_grid",
"clip_vision": "t2ia_style_clipvision",
"pidinet_sketch": "t2ia_sketch_pidi",
"depth": "depth_midas",
"normal_map": "normal_midas",
"hed": "softedge_hed",
"hed_safe": "softedge_hedsafe",
"pidinet": "softedge_pidinet",
"pidinet_safe": "softedge_pidisafe",
"segmentation": "seg_ufade20k",
"oneformer_coco": "seg_ofcoco",
"oneformer_ade20k": "seg_ofade20k",
"pidinet_scribble": "scribble_pidinet",
"inpaint": "inpaint_global_harmonious",
"anime_face_segment": "seg_anime_face",
"densepose": "densepose (pruple bg & purple torso)",
"densepose_parula": "densepose_parula (black bg & blue torso)",
"te_hed": "softedge_teed",
"ip-adapter_clip_sd15": "ip-adapter_clip_h",
"ip-adapter_clip_sdxl": "ip-adapter_clip_g",
}
# Preprocessor that automatically maps to other preprocessors.
meta_preprocessors = ["ip-adapter-auto"]
ui_preprocessor_keys = ['none', preprocessor_aliases['invert']]
ui_preprocessor_keys += meta_preprocessors
ui_preprocessor_keys += sorted([preprocessor_aliases.get(k, k)
for k in cn_preprocessor_modules.keys()
if preprocessor_aliases.get(k, k) not in ui_preprocessor_keys])
reverse_preprocessor_aliases = {preprocessor_aliases[k]: k for k in preprocessor_aliases.keys()}
def get_module_basename(module: Optional[str]) -> str:
if module is None:
module = 'none'
return reverse_preprocessor_aliases.get(module, module)
default_detectedmap_dir = os.path.join("detected_maps") default_detectedmap_dir = os.path.join("detected_maps")
script_dir = scripts.basedir() script_dir = scripts.basedir()
@ -300,50 +116,28 @@ def select_control_type(
sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN, sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN,
cn_models: Dict = cn_models, # Override or testing cn_models: Dict = cn_models, # Override or testing
) -> Tuple[List[str], List[str], str, str]: ) -> Tuple[List[str], List[str], str, str]:
default_option = processor.preprocessor_filters[control_type]
pattern = control_type.lower() pattern = control_type.lower()
preprocessor_list = ui_preprocessor_keys
all_models = list(cn_models.keys()) all_models = list(cn_models.keys())
if pattern == "all": if pattern == "all":
return [ return [
preprocessor_list, [p.label for p in Preprocessor.get_sorted_preprocessors()],
all_models, all_models,
'none', #default option 'none', #default option
"None" #default model "None" #default model
] ]
filtered_preprocessor_list = [
x
for x in preprocessor_list
if ((
pattern in x.lower() or
any(a in x.lower() for a in processor.preprocessor_filters_aliases.get(pattern, [])) or
x.lower() == "none"
) and (
sd_version.is_compatible_with(StableDiffusionVersion.detect_from_model_name(x))
))
]
if pattern in ["canny", "lineart", "scribble/sketch", "mlsd"]:
filtered_preprocessor_list += [
x for x in preprocessor_list if "invert" in x.lower()
]
if pattern in ["sparsectrl"]:
filtered_preprocessor_list += [
x for x in preprocessor_list if "scribble" in x.lower()
]
filtered_model_list = [ filtered_model_list = [
model for model in all_models model for model in all_models
if model.lower() == "none" or if model.lower() == "none" or
(( ((
pattern in model.lower() or pattern in model.lower() or
any(a in model.lower() for a in processor.preprocessor_filters_aliases.get(pattern, [])) any(a in model.lower() for a in Preprocessor.tag_to_filters(control_type))
) and ( ) and (
sd_version.is_compatible_with(StableDiffusionVersion.detect_from_model_name(model)) sd_version.is_compatible_with(StableDiffusionVersion.detect_from_model_name(model))
)) ))
] ]
assert len(filtered_model_list) > 0, "'None' model should always be available." assert len(filtered_model_list) > 0, "'None' model should always be available."
if default_option not in filtered_preprocessor_list:
default_option = filtered_preprocessor_list[0]
if len(filtered_model_list) == 1: if len(filtered_model_list) == 1:
default_model = "None" default_model = "None"
else: else:
@ -354,8 +148,10 @@ def select_control_type(
break break
return ( return (
filtered_preprocessor_list, [p.label for p in Preprocessor.get_filtered_preprocessors(control_type)],
filtered_model_list, filtered_model_list,
default_option, Preprocessor.get_default_preprocessor(control_type).label,
default_model default_model
) )

View File

@ -0,0 +1,4 @@
from .ip_adapter_auto import *
from .normal_dsine import *
from .model_free_preprocessors import *
from .legacy.legacy_preprocessors import *

View File

@ -0,0 +1,25 @@
from ..ipadapter.presets import IPAdapterPreset
from ..supported_preprocessor import Preprocessor
from ..logging import logger
class PreprocessorIPAdapterAuto(Preprocessor):
def __init__(self):
super().__init__(name="ip-adapter-auto")
self.tags = ["IP-Adapter"]
self.sorting_priority = 1000
self.returns_image = False
self.show_control_mode = False
def __call__(self, *args, **kwargs):
assert "model" in kwargs
model: str = kwargs["model"]
module: str = IPAdapterPreset.match_model(model).module
logger.info(f"ip-adapter-auto => {module}")
p = Preprocessor.get_preprocessor(module)
assert p is not None
return p(*args, **kwargs)
Preprocessor.add_supported_preprocessor(PreprocessorIPAdapterAuto())

View File

@ -0,0 +1,112 @@
# This is a python script to convert all old preprocessors to new format.
# However, the old preprocessors are not very memory effective
# and eventually we should move all old preprocessors to new format manually
# see also the forge_preprocessor_normalbae/scripts/preprocessor_normalbae for
# how to make better implementation of preprocessors.
# No newer preprocessors should be written in this legacy way.
# Never add new leagcy preprocessors please.
# The new forge_preprocessor_normalbae/scripts/preprocessor_normalbae
# is much more effective and maintainable
from annotator.util import HWC3
from .preprocessor_compiled import legacy_preprocessors
from ...supported_preprocessor import Preprocessor, PreprocessorParameter
###
# This file has lots of unreasonable historical designs and should be viewed as a frozen blackbox library.
# If you want to add preprocessor,
# please instead look at `extensions-builtin/forge_preprocessor_normalbae/scripts/preprocessor_normalbae`
# If you want to use preprocessor,
# please instead use `from modules_forge.shared import supported_preprocessors`
# and then use any preprocessor like: depth_midas = supported_preprocessors['depth_midas']
# Please do not hack/edit/modify/rely-on any codes in this file.
# Never use methods in this file to add anything!
# This file will be eventually removed but the workload is super high and we need more time to do this.
###
class LegacyPreprocessor(Preprocessor):
def __init__(self, name: str, legacy_dict):
super().__init__(name)
self._label = legacy_dict["label"]
self.call_function = legacy_dict["call_function"]
self.unload_function = legacy_dict["unload_function"]
self.managed_model = legacy_dict["managed_model"]
self.do_not_need_model = legacy_dict["model_free"]
self.show_control_mode = not legacy_dict["no_control_mode"]
self.sorting_priority = legacy_dict["priority"]
self.tags = legacy_dict["tags"]
self.returns_image = legacy_dict.get("returns_image", True)
if legacy_dict.get("use_soft_projection_in_hr_fix", False):
self.use_soft_projection_in_hr_fix = True
if legacy_dict["resolution"] is None:
self.resolution = PreprocessorParameter(visible=False)
else:
legacy_dict["resolution"]["label"] = "Resolution"
legacy_dict["resolution"]["step"] = 8
self.resolution = PreprocessorParameter(
**legacy_dict["resolution"], visible=True
)
if legacy_dict["slider_1"] is None:
self.slider_1 = PreprocessorParameter(visible=False)
else:
self.slider_1 = PreprocessorParameter(
**legacy_dict["slider_1"], visible=True
)
if legacy_dict["slider_2"] is None:
self.slider_2 = PreprocessorParameter(visible=False)
else:
self.slider_2 = PreprocessorParameter(
**legacy_dict["slider_2"], visible=True
)
if legacy_dict["slider_3"] is None:
self.slider_3 = PreprocessorParameter(visible=False)
else:
self.slider_3 = PreprocessorParameter(
**legacy_dict["slider_3"], visible=True
)
def __call__(
self,
input_image,
resolution,
slider_1=None,
slider_2=None,
slider_3=None,
**kwargs
):
# Legacy Preprocessors does not have slider 3
del slider_3
if self.managed_model is not None:
assert self.unload_function is not None
result, is_image = self.call_function(
img=input_image, res=resolution, thr_a=slider_1, thr_b=slider_2, **kwargs
)
if is_image:
result = HWC3(result)
if self.unload_function is not None:
self.unload_function()
return result
for name, data in legacy_preprocessors.items():
p = LegacyPreprocessor(name, data)
Preprocessor.add_supported_preprocessor(p)

File diff suppressed because it is too large Load Diff

View File

@ -51,24 +51,9 @@ def resize_image_with_pad(input_image, resolution, skip_hwc3=False):
return safer_memory(img_padded), remove_pad return safer_memory(img_padded), remove_pad
model_canny = None
def canny(img, res=512, thr_a=100, thr_b=200, **kwargs): def canny(img, res=512, thr_a=100, thr_b=200, **kwargs):
l, h = thr_a, thr_b # noqa: E741
img, remove_pad = resize_image_with_pad(img, res) img, remove_pad = resize_image_with_pad(img, res)
global model_canny result = cv2.Canny(img, thr_a, thr_b)
if model_canny is None:
from annotator.canny import apply_canny
model_canny = apply_canny
result = model_canny(img, l, h)
return remove_pad(result), True
def scribble_thr(img, res=512, **kwargs):
img, remove_pad = resize_image_with_pad(img, res)
result = np.zeros_like(img, dtype=np.uint8)
result[np.min(img, axis=2) < 127] = 255
return remove_pad(result), True return remove_pad(result), True
@ -620,20 +605,6 @@ def unload_oneformer_ade20k():
model_oneformer_ade20k.unload_model() model_oneformer_ade20k.unload_model()
model_shuffle = None
def shuffle(img, res=512, **kwargs):
img, remove_pad = resize_image_with_pad(img, res)
img = remove_pad(img)
global model_shuffle
if model_shuffle is None:
from annotator.shuffle import ContentShuffleDetector
model_shuffle = ContentShuffleDetector()
result = model_shuffle(img)
return result, True
def recolor_luminance(img, res=512, thr_a=1.0, **kwargs): def recolor_luminance(img, res=512, thr_a=1.0, **kwargs):
result = cv2.cvtColor(HWC3(img), cv2.COLOR_BGR2LAB) result = cv2.cvtColor(HWC3(img), cv2.COLOR_BGR2LAB)
result = result[:, :, 0].astype(np.float32) / 255.0 result = result[:, :, 0].astype(np.float32) / 255.0
@ -706,25 +677,6 @@ def unload_te_hed():
if model_te_hed is not None: if model_te_hed is not None:
model_te_hed.unload_model() model_te_hed.unload_model()
model_normal_dsine = None
def normal_dsine(img, res=512, thr_a=60.0,thr_b=5, **kwargs):
global model_normal_dsine
if model_normal_dsine is None:
from annotator.normaldsine import NormalDsineDetector
model_normal_dsine = NormalDsineDetector()
result = model_normal_dsine(img, new_fov=float(thr_a), iterations=int(thr_b), resulotion=res)
return result, True
def unload_normal_dsine():
global model_normal_dsine
if model_normal_dsine is not None:
model_normal_dsine.unload_model()
class InsightFaceModel: class InsightFaceModel:
def __init__(self, face_analysis_model_name: str = "buffalo_l"): def __init__(self, face_analysis_model_name: str = "buffalo_l"):
self.model = None self.model = None
@ -775,7 +727,7 @@ class InsightFaceModel:
def run_model(self, img: np.ndarray, **kwargs) -> Tuple[torch.Tensor, bool]: def run_model(self, img: np.ndarray, **kwargs) -> Tuple[torch.Tensor, bool]:
self.load_model() self.load_model()
assert img.shape[2] == 3, f"Expect 3 channels, but get {img.shape} channels" img = img[:, :, :3] # Drop alpha channel if there is one.
faces = self.model.get(cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) faces = self.model.get(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
face = InsightFaceModel.pick_largest_face(faces) face = InsightFaceModel.pick_largest_face(faces)
return torch.from_numpy(face.normed_embedding).unsqueeze(0), False return torch.from_numpy(face.normed_embedding).unsqueeze(0), False
@ -886,487 +838,3 @@ class HandRefinerModel:
g_hand_refiner_model = HandRefinerModel() g_hand_refiner_model = HandRefinerModel()
model_free_preprocessors = [
"reference_only",
"reference_adain",
"reference_adain+attn",
"revision_clipvision",
"revision_ignore_prompt"
]
no_control_mode_preprocessors = [
"revision_clipvision",
"revision_ignore_prompt",
"clip_vision",
"ip-adapter_clip_sd15",
"ip-adapter_clip_sdxl",
"ip-adapter_clip_sdxl_plus_vith",
"t2ia_style_clipvision",
"ip-adapter_face_id",
"ip-adapter_face_id_plus",
]
flag_preprocessor_resolution = "Preprocessor Resolution"
preprocessor_sliders_config = {
"none": [],
"inpaint": [],
"inpaint_only": [],
"revision_clipvision": [
None,
{
"name": "Noise Augmentation",
"value": 0.0,
"min": 0.0,
"max": 1.0
},
],
"revision_ignore_prompt": [
None,
{
"name": "Noise Augmentation",
"value": 0.0,
"min": 0.0,
"max": 1.0
},
],
"canny": [
{
"name": flag_preprocessor_resolution,
"value": 512,
"min": 64,
"max": 2048
},
{
"name": "Canny Low Threshold",
"value": 100,
"min": 1,
"max": 255
},
{
"name": "Canny High Threshold",
"value": 200,
"min": 1,
"max": 255
},
],
"mlsd": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
},
{
"name": "MLSD Value Threshold",
"min": 0.01,
"max": 2.0,
"value": 0.1,
"step": 0.01
},
{
"name": "MLSD Distance Threshold",
"min": 0.01,
"max": 20.0,
"value": 0.1,
"step": 0.01
}
],
"hed": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
}
],
"scribble_hed": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
}
],
"hed_safe": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
}
],
"openpose": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
}
],
"openpose_full": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
}
],
"dw_openpose_full": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
}
],
"animal_openpose": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
}
],
"segmentation": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
}
],
"depth": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
}
],
"depth_leres": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
},
{
"name": "Remove Near %",
"min": 0,
"max": 100,
"value": 0,
"step": 0.1,
},
{
"name": "Remove Background %",
"min": 0,
"max": 100,
"value": 0,
"step": 0.1,
}
],
"depth_leres++": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
},
{
"name": "Remove Near %",
"min": 0,
"max": 100,
"value": 0,
"step": 0.1,
},
{
"name": "Remove Background %",
"min": 0,
"max": 100,
"value": 0,
"step": 0.1,
}
],
"normal_map": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
},
{
"name": "Normal Background Threshold",
"min": 0.0,
"max": 1.0,
"value": 0.4,
"step": 0.01
}
],
"threshold": [
{
"name": flag_preprocessor_resolution,
"value": 512,
"min": 64,
"max": 2048
},
{
"name": "Binarization Threshold",
"min": 0,
"max": 255,
"value": 127
}
],
"scribble_xdog": [
{
"name": flag_preprocessor_resolution,
"value": 512,
"min": 64,
"max": 2048
},
{
"name": "XDoG Threshold",
"min": 1,
"max": 64,
"value": 32,
}
],
"blur_gaussian": [
{
"name": flag_preprocessor_resolution,
"value": 512,
"min": 64,
"max": 2048
},
{
"name": "Sigma",
"min": 0.01,
"max": 64.0,
"value": 9.0,
}
],
"tile_resample": [
None,
{
"name": "Down Sampling Rate",
"value": 1.0,
"min": 1.0,
"max": 8.0,
"step": 0.01
}
],
"tile_colorfix": [
None,
{
"name": "Variation",
"value": 8.0,
"min": 3.0,
"max": 32.0,
"step": 1.0
}
],
"tile_colorfix+sharp": [
None,
{
"name": "Variation",
"value": 8.0,
"min": 3.0,
"max": 32.0,
"step": 1.0
},
{
"name": "Sharpness",
"value": 1.0,
"min": 0.0,
"max": 2.0,
"step": 0.01
}
],
"reference_only": [
None,
{
"name": r'Style Fidelity (only for "Balanced" mode)',
"value": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01
}
],
"reference_adain": [
None,
{
"name": r'Style Fidelity (only for "Balanced" mode)',
"value": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01
}
],
"reference_adain+attn": [
None,
{
"name": r'Style Fidelity (only for "Balanced" mode)',
"value": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01
}
],
"inpaint_only+lama": [],
"color": [
{
"name": flag_preprocessor_resolution,
"value": 512,
"min": 64,
"max": 2048,
}
],
"mediapipe_face": [
{
"name": flag_preprocessor_resolution,
"value": 512,
"min": 64,
"max": 2048,
},
{
"name": "Max Faces",
"value": 1,
"min": 1,
"max": 10,
"step": 1
},
{
"name": "Min Face Confidence",
"value": 0.5,
"min": 0.01,
"max": 1.0,
"step": 0.01
}
],
"recolor_luminance": [
None,
{
"name": "Gamma Correction",
"value": 1.0,
"min": 0.1,
"max": 2.0,
"step": 0.001
}
],
"recolor_intensity": [
None,
{
"name": "Gamma Correction",
"value": 1.0,
"min": 0.1,
"max": 2.0,
"step": 0.001
}
],
"anime_face_segment": [
{
"name": flag_preprocessor_resolution,
"value": 512,
"min": 64,
"max": 2048
}
],
"densepose": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
}
],
"densepose_parula": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
}
],
"depth_hand_refiner": [
{
"name": flag_preprocessor_resolution,
"value": 512,
"min": 64,
"max": 2048
}
],
"te_hed": [
{
"name": flag_preprocessor_resolution,
"value": 512,
"min": 64,
"max": 2048
},
{
"name": "Safe Steps",
"min": 0,
"max": 10,
"value": 2,
"step": 1,
},
],
"normal_dsine": [
{
"name": flag_preprocessor_resolution,
"min": 64,
"max": 2048,
"value": 512
},
{
"name": "Fov",
"min": 0.0,
"max": 360.0,
"value": 60.0,
"step": 0.1,
},
{
"name": "Iterations",
"min": 1,
"max": 20,
"value": 5,
"step": 1,
},
],
}
# Control-type tab name -> preprocessor module selected by default when that
# tab is activated in the UI. "none" means no preprocessor is preselected.
preprocessor_filters = {
    "All": "none",
    "Canny": "canny",
    "Depth": "depth_midas",
    "NormalMap": "normal_bae",
    "OpenPose": "openpose_full",
    "MLSD": "mlsd",
    "Lineart": "lineart_standard (from white bg & black line)",
    "SoftEdge": "softedge_pidinet",
    "Scribble/Sketch": "scribble_pidinet",
    "Segmentation": "seg_ofade20k",
    "Shuffle": "shuffle",
    "Tile/Blur": "tile_resample",
    "Inpaint": "inpaint_only",
    "InstructP2P": "none",
    "Reference": "reference_only",
    "Recolor": "recolor_luminance",
    "Revision": "revision_clipvision",
    "T2I-Adapter": "none",
    "IP-Adapter": "ip-adapter-auto",
    "Instant_ID": "instant_id",
    "SparseCtrl": "none",
}
# Lowercase control-type names mapped to the lowercase aliases accepted when
# matching a user-supplied filter string. All keys and values must stay
# lowercase for case-insensitive matching to work.
preprocessor_filters_aliases = {
    "instructp2p": ["ip2p"],
    "segmentation": ["seg"],
    "normalmap": ["normal"],
    "t2i-adapter": ["t2i_adapter", "t2iadapter", "t2ia"],
    "ip-adapter": ["ip_adapter", "ipadapter"],
    "scribble/sketch": ["scribble", "sketch"],
    "tile/blur": ["tile", "blur"],
    "openpose": ["openpose", "densepose"],
}

View File

@ -0,0 +1,184 @@
"""Preprocessors that do not need to run a torch model."""
import cv2
import numpy as np
from ..supported_preprocessor import Preprocessor, PreprocessorParameter
from ..utils import resize_image_with_pad
from annotator.util import HWC3
class PreprocessorNone(Preprocessor):
    """Identity preprocessor: forwards the input image untouched."""

    def __init__(self):
        super().__init__(name="none")
        self.sorting_priority = 10

    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        input_mask=None,
        **kwargs
    ):
        # Pass-through: no resizing, no transformation of any kind.
        return input_image
class PreprocessorCanny(Preprocessor):
    """Canny edge detection (model-free).

    slider_1/slider_2 are the low/high hysteresis thresholds passed straight
    to ``cv2.Canny``.
    """

    def __init__(self):
        super().__init__(name="canny")
        self.tags = ["Canny"]
        self.slider_1 = PreprocessorParameter(
            label="Low Threshold",
            minimum=1,
            maximum=255,
            step=1,
            value=100,
        )
        self.slider_2 = PreprocessorParameter(
            label="High Threshold",
            minimum=1,
            maximum=255,
            step=1,
            value=200,
        )
        self.sorting_priority = 100
        self.use_soft_projection_in_hr_fix = True

    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        **kwargs
    ):
        """Return a 3-channel edge map at the requested resolution."""
        padded, remove_pad = resize_image_with_pad(input_image, resolution)
        edges = cv2.Canny(padded, int(slider_1), int(slider_2))
        return remove_pad(cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB))
class PreprocessorInvert(Preprocessor):
    """Invert a white-background/black-line image.

    Intended for line-style inputs (lineart, scribble, canny, mlsd, ...)
    drawn dark-on-light.
    """

    def __init__(self):
        super().__init__(name="invert")
        self._label = "invert (from white bg & black line)"
        self.tags = [
            "Canny",
            "Lineart",
            "Scribble",
            "Sketch",
            "MLSD",
        ]
        # Inversion is resolution-independent, so hide the resolution slider.
        self.slider_resolution = PreprocessorParameter(visible=False)
        self.sorting_priority = 20

    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        **kwargs
    ):
        # Expand to 3 channels, then flip intensities; no resize is done.
        return 255 - HWC3(input_image)
class PreprocessorBlurGaussian(Preprocessor):
    """Gaussian blur preprocessor for the Tile/Blur control types.

    slider_1 is the Gaussian sigma. The ported code accidentally reused the
    resolution-slider bounds (64..2048, default 512) for sigma; this restores
    the legacy ``blur_gaussian`` "Sigma" range (0.01..64.0, default 9.0) from
    ``preprocessor_sliders_config``.
    """

    def __init__(self):
        super().__init__(name="blur_gaussian")
        # Sigma range mirrors the legacy slider config for "blur_gaussian".
        self.slider_1 = PreprocessorParameter(
            label="Sigma", minimum=0.01, maximum=64.0, value=9.0
        )
        self.tags = ["Tile", "Blur"]

    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        input_mask=None,
        **kwargs
    ):
        img, remove_pad = resize_image_with_pad(input_image, resolution)
        img = remove_pad(img)
        # Kernel size (0, 0) tells OpenCV to derive the kernel from sigma.
        result = cv2.GaussianBlur(img, (0, 0), float(slider_1))
        return result
class PreprocessorScribbleXdog(Preprocessor):
    """Extended Difference-of-Gaussians (XDoG) scribble extraction.

    slider_1 thresholds the inverted DoG response into a binary sketch.
    """

    def __init__(self):
        super().__init__(name="scribble_xdog")
        self.slider_1 = PreprocessorParameter(
            label="XDoG Threshold", minimum=1, maximum=64, value=32
        )
        self.tags = [
            "Scribble",
            "Sketch",
            "SparseCtrl",
        ]

    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        input_mask=None,
        **kwargs
    ):
        image, remove_pad = resize_image_with_pad(input_image, resolution)
        # Two Gaussian blurs with different sigmas form the DoG pair.
        fine = cv2.GaussianBlur(image.astype(np.float32), (0, 0), 0.5)
        coarse = cv2.GaussianBlur(image.astype(np.float32), (0, 0), 5.0)
        # Per-pixel channel minimum of the DoG, inverted into [0, 255].
        dog = (255 - np.min(coarse - fine, axis=2)).clip(0, 255).astype(np.uint8)
        # Binarize: pixels whose response clears the threshold become white.
        sketch = np.zeros_like(image, dtype=np.uint8)
        sketch[2 * (255 - dog) > slider_1] = 255
        return remove_pad(sketch)
class PreprocessorShuffle(Preprocessor):
    """Content-shuffle preprocessor; results are intentionally not cached."""

    def __init__(self):
        super().__init__(name="shuffle")
        self.tags = ["Shuffle"]
        self.model_shuffle = None
        # Fix res to 512.
        self.slider_resolution = PreprocessorParameter(value=512, visible=False)

    def cached_call(self, *args, **kwargs):
        """No cache for shuffle, as each call depends on different numpy seed."""
        return self(*args, **kwargs)

    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        input_mask=None,
        **kwargs
    ):
        padded, remove_pad = resize_image_with_pad(input_image, resolution)
        image = remove_pad(padded)
        if self.model_shuffle is None:
            # Lazy import keeps module import cheap until shuffle is used.
            from annotator.shuffle import ContentShuffleDetector

            self.model_shuffle = ContentShuffleDetector()
        return self.model_shuffle(image)
# Register the model-free preprocessors with the global Preprocessor registry
# at import time so the UI and API can discover them.
Preprocessor.add_supported_preprocessor(PreprocessorNone())
Preprocessor.add_supported_preprocessor(PreprocessorCanny())
Preprocessor.add_supported_preprocessor(PreprocessorInvert())
Preprocessor.add_supported_preprocessor(PreprocessorBlurGaussian())
Preprocessor.add_supported_preprocessor(PreprocessorScribbleXdog())
Preprocessor.add_supported_preprocessor(PreprocessorShuffle())

View File

@ -0,0 +1,48 @@
from ..supported_preprocessor import Preprocessor, PreprocessorParameter
class PreprocessorNormalDsine(Preprocessor):
    """DSINE surface-normal estimation preprocessor.

    slider_1 is the camera field of view in degrees; slider_2 is the number of
    refinement iterations passed to the DSINE detector.
    """

    def __init__(self):
        super().__init__(name="normal_dsine")
        self.tags = ["NormalMap"]
        self.slider_1 = PreprocessorParameter(
            minimum=0,
            maximum=360,
            step=0.1,
            value=60,
            label="Fov",
        )
        self.slider_2 = PreprocessorParameter(
            minimum=1,
            maximum=20,
            step=1,
            value=5,
            label="Iterations",
        )
        # Lazily constructed NormalDsineDetector; created on first call.
        self.model = None

    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        **kwargs
    ):
        if self.model is None:
            # Lazy import avoids loading DSINE dependencies until needed.
            from annotator.normaldsine import NormalDsineDetector

            self.model = NormalDsineDetector()
        result = self.model(
            input_image,
            new_fov=float(slider_1),
            iterations=int(slider_2),
            # NOTE(review): "resulotion" is misspelled but must match the
            # keyword parameter name of NormalDsineDetector.__call__ — confirm
            # against the detector before renaming either side.
            resulotion=resolution,
        )
        # Unload after each call — presumably to release model memory; see
        # NormalDsineDetector.unload_model for the actual behavior.
        self.model.unload_model()
        return result


Preprocessor.add_supported_preprocessor(PreprocessorNormalDsine())

View File

@ -0,0 +1,189 @@
from abc import ABC, abstractmethod
from typing import List, ClassVar, Dict, Optional, Set
from dataclasses import dataclass, field
from modules import shared
from scripts.logging import logger
from scripts.utils import ndarray_lru_cache
# Max number of cached preprocessor results; 0 (the default when the command
# line option is absent) disables the ndarray LRU cache entirely.
CACHE_SIZE = getattr(shared.cmd_opts, "controlnet_preprocessor_cache_size", 0)
@dataclass
class PreprocessorParameter:
    """
    Class representing a slider parameter for a preprocessor.

    Attributes:
        label (str): The label for the parameter.
        minimum (float): The minimum value of the parameter. Default is 0.0.
        maximum (float): The maximum value of the parameter. Default is 1.0.
        step (float): The step size for the parameter. Default is 0.01.
        value (float): The initial value of the parameter. Default is 0.5.
        visible (bool): Whether the parameter is visible or not. Default is True.
    """

    # NOTE: the docstring previously claimed `visible` defaults to False,
    # contradicting the field default below; the docstring is now correct.
    label: str = "EMPTY_LABEL"
    minimum: float = 0.0
    maximum: float = 1.0
    step: float = 0.01
    value: float = 0.5
    visible: bool = True

    @property
    def gradio_update_kwargs(self) -> dict:
        """Keyword arguments for updating the corresponding gradio slider."""
        return dict(
            minimum=self.minimum,
            maximum=self.maximum,
            step=self.step,
            label=self.label,
            value=self.value,
            visible=self.visible,
        )

    @property
    def api_json(self) -> dict:
        """JSON-serializable slider description exposed via the API."""
        return dict(
            name=self.label,
            value=self.value,
            min=self.minimum,
            max=self.maximum,
            step=self.step,
        )
@dataclass
class Preprocessor(ABC):
    """
    Class representing a preprocessor.

    Attributes:
        name (str): The name of the preprocessor.
        tags (List[str]): The tags associated with the preprocessor.
        slider_resolution (PreprocessorParameter): The parameter representing the resolution of the slider.
        slider_1 (PreprocessorParameter): The first parameter of the slider.
        slider_2 (PreprocessorParameter): The second parameter of the slider.
        slider_3 (PreprocessorParameter): The third parameter of the slider.
        show_control_mode (bool): Whether to show the control mode or not.
        do_not_need_model (bool): Whether the preprocessor needs a model or not.
        sorting_priority (int): The sorting priority of the preprocessor.
        corp_image_with_a1111_mask_when_in_img2img_inpaint_tab (bool): Whether to crop the image with a1111 mask when in img2img inpaint tab or not.
        fill_mask_with_one_when_resize_and_fill (bool): Whether to fill the mask with one when resizing and filling or not.
        use_soft_projection_in_hr_fix (bool): Whether to use soft projection in hr fix or not.
        expand_mask_when_resize_and_fill (bool): Whether to expand the mask when resizing and filling or not.
    """

    # Annotated names are real dataclass fields.
    name: str
    _label: Optional[str] = None  # display-name override; see `label` property
    tags: List[str] = field(default_factory=list)
    # The attributes below are intentionally NOT annotated, so `dataclass`
    # treats them as plain class-level defaults; subclasses overwrite them
    # per-instance inside __init__.
    slider_resolution = PreprocessorParameter(
        label="Resolution",
        minimum=64,
        maximum=2048,
        value=512,
        step=8,
        visible=True,
    )
    slider_1 = PreprocessorParameter(visible=False)
    slider_2 = PreprocessorParameter(visible=False)
    slider_3 = PreprocessorParameter(visible=False)
    returns_image: bool = True
    show_control_mode = True
    do_not_need_model = False
    sorting_priority = 0  # higher goes to top in the list
    corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = True
    fill_mask_with_one_when_resize_and_fill = False
    use_soft_projection_in_hr_fix = False
    expand_mask_when_resize_and_fill = False

    # Global registries shared by all subclasses: display label -> instance
    # and canonical name -> instance.
    all_processors: ClassVar[Dict[str, "Preprocessor"]] = {}
    all_processors_by_name: ClassVar[Dict[str, "Preprocessor"]] = {}

    @property
    def label(self) -> str:
        """Display name on UI."""
        return self._label if self._label is not None else self.name

    @classmethod
    def add_supported_preprocessor(cls, p: "Preprocessor"):
        """Register `p` under both its label and its name; duplicates assert."""
        assert p.label not in cls.all_processors, f"{p.label} already registered!"
        cls.all_processors[p.label] = p
        assert p.name not in cls.all_processors_by_name, f"{p.name} already registered!"
        cls.all_processors_by_name[p.name] = p
        logger.debug(f"{p.name} registered. Total preprocessors ({len(cls.all_processors)}).")

    @classmethod
    def get_preprocessor(cls, name: str) -> Optional["Preprocessor"]:
        """Look up by display label first, falling back to canonical name."""
        return cls.all_processors.get(name, cls.all_processors_by_name.get(name, None))

    @classmethod
    def get_sorted_preprocessors(cls) -> List["Preprocessor"]:
        """Return all preprocessors with "none" pinned first, the rest ordered
        by sorting_priority (descending), then label."""
        preprocessors = [p for k, p in cls.all_processors.items() if k != "none"]
        # NOTE(review): the zero-padded-string key mis-sorts negative
        # priorities and priorities wider than 8 digits — none appear in this
        # codebase, but confirm before relying on such values.
        return [cls.all_processors["none"]] + sorted(
            preprocessors,
            key=lambda x: str(x.sorting_priority).zfill(8) + x.label,
            reverse=True,
        )

    @classmethod
    def get_all_preprocessor_tags(cls):
        """All distinct tags across registered preprocessors, "All" first."""
        tags = set()
        for _, p in cls.all_processors.items():
            tags.update(set(p.tags))
        return ["All"] + sorted(list(tags))

    @classmethod
    def get_filtered_preprocessors(cls, tag: str) -> List["Preprocessor"]:
        # NOTE(review): for tag == "All" this returns the registry *dict*, not
        # a list as annotated — callers appear to handle both; confirm before
        # normalizing the return type.
        if tag == "All":
            return cls.all_processors
        return [
            p
            for p in cls.get_sorted_preprocessors()
            if tag in p.tags or p.label == "none"
        ]

    @classmethod
    def get_default_preprocessor(cls, tag: str) -> "Preprocessor":
        """Default preprocessor for a tag: the first real match, else "none"."""
        ps = cls.get_filtered_preprocessors(tag)
        assert len(ps) > 0
        # ps[0] is always "none"; prefer the first actual match when present.
        return ps[0] if len(ps) == 1 else ps[1]

    @classmethod
    def tag_to_filters(cls, tag: str) -> Set[str]:
        """Expand a UI tag into the set of lowercase aliases used to match
        model/preprocessor names."""
        filters_aliases = {
            "instructp2p": ["ip2p"],
            "segmentation": ["seg"],
            "normalmap": ["normal"],
            "t2i-adapter": ["t2i_adapter", "t2iadapter", "t2ia"],
            "ip-adapter": ["ip_adapter", "ipadapter"],
            "openpose": ["openpose", "densepose"],
            "instant-id": ["instant_id", "instantid"],
        }
        tag = tag.lower()
        return set([tag] + filters_aliases.get(tag, []))

    @ndarray_lru_cache(max_size=CACHE_SIZE)
    def cached_call(self, *args, **kwargs):
        """LRU-cached invocation; subclasses with nondeterministic output
        (e.g. shuffle) override this to bypass the cache."""
        logger.debug(f"Calling preprocessor {self.name} outside of cache.")
        return self(*args, **kwargs)

    def __hash__(self):
        # Identity for caching/equality is the unique preprocessor name.
        return hash(self.name)

    def __eq__(self, other):
        # NOTE(review): compares hash values, so `other` must be hashable and
        # hash collisions would compare equal — confirm this is intentional.
        return self.__hash__() == other.__hash__()

    @abstractmethod
    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        input_mask=None,
        **kwargs,
    ):
        """Run the preprocessor; implemented by each concrete subclass."""
        pass

View File

@ -179,3 +179,28 @@ def align_dim_latent(x: int) -> int:
Stable diffusion 1:8 ratio for latent/pixel, i.e., Stable diffusion 1:8 ratio for latent/pixel, i.e.,
1 latent unit == 8 pixel unit.""" 1 latent unit == 8 pixel unit."""
return (x // 8) * 8 return (x // 8) * 8
def pad64(x):
    """Number of extra pixels needed to round *x* up to a multiple of 64."""
    return (64 - int(x) % 64) % 64
def safer_memory(x):
    # Fix many MAC/AMD problems: hand back a C-contiguous private copy so the
    # caller never aliases the original buffer.
    return np.ascontiguousarray(x).copy()
def resize_image_with_pad(img, resolution):
    """Resize *img* so its shorter side equals *resolution*, then edge-pad
    each spatial dimension up to the next multiple of 64.

    Returns (padded_image, remove_pad) where remove_pad(x) crops x back to
    the un-padded resized size and returns a contiguous copy.
    """
    raw_h, raw_w, _ = img.shape
    scale = float(resolution) / float(min(raw_h, raw_w))
    target_h = int(np.round(float(raw_h) * scale))
    target_w = int(np.round(float(raw_w) * scale))
    # Cubic when enlarging, area when shrinking (standard quality trade-off).
    interp = cv2.INTER_CUBIC if scale > 1 else cv2.INTER_AREA
    resized = cv2.resize(img, (target_w, target_h), interpolation=interp)
    padded = np.pad(
        resized,
        [[0, pad64(target_h)], [0, pad64(target_w)], [0, 0]],
        mode='edge',
    )

    def remove_pad(x):
        return safer_memory(x[:target_h, :target_w])

    return safer_memory(padded), remove_pad

View File

@ -4,8 +4,9 @@ import numpy as np
from modules import scripts, shared from modules import scripts, shared
try: try:
from scripts.global_state import update_cn_models, cn_models_names, cn_preprocessor_modules from scripts.global_state import update_cn_models, cn_models_names
from scripts.external_code import ResizeMode, ControlMode from scripts.external_code import ResizeMode, ControlMode
from scripts.supported_preprocessor import Preprocessor
except (ImportError, NameError): except (ImportError, NameError):
import_error = True import_error = True
@ -408,7 +409,7 @@ def add_axis_options(xyz_grid):
return [e.value for e in ResizeMode] return [e.value for e in ResizeMode]
def choices_preprocessor(): def choices_preprocessor():
return list(cn_preprocessor_modules) return list(Preprocessor.all_processors.keys())
def make_excluded_list(): def make_excluded_list():
pattern = re.compile(r"\[(\w+)\]") pattern = re.compile(r"\[(\w+)\]")

View File

@ -1,4 +1,3 @@
from typing import Any, Dict, List
import unittest import unittest
from PIL import Image from PIL import Image
import numpy as np import numpy as np
@ -8,7 +7,7 @@ import importlib
utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils") utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils")
from scripts import external_code, processor from scripts import external_code
from scripts.controlnet import prepare_mask, Script, set_numpy_seed from scripts.controlnet import prepare_mask, Script, set_numpy_seed
from modules import processing from modules import processing
@ -122,25 +121,6 @@ class TestScript(unittest.TestCase):
[[100, 200, 50], [150, 75, 225], [30, 120, 180]], dtype=np.uint8 [[100, 200, 50], [150, 75, 225], [30, 120, 180]], dtype=np.uint8
) )
def test_bound_check_params(self):
def param_required(module: str, param: str) -> bool:
configs = processor.preprocessor_sliders_config[module]
config_index = ("processor_res", "threshold_a", "threshold_b").index(param)
return config_index < len(configs) and configs[config_index] is not None
for module in processor.preprocessor_sliders_config.keys():
for param in ("processor_res", "threshold_a", "threshold_b"):
with self.subTest(param=param, module=module):
unit = external_code.ControlNetUnit(
module=module,
**{param: -100},
)
unit.bound_check_params()
if param_required(module, param):
self.assertGreaterEqual(getattr(unit, param), 0)
else:
self.assertEqual(getattr(unit, param), -100)
def test_choose_input_image(self): def test_choose_input_image(self):
with self.subTest(name="no image"): with self.subTest(name="no image"):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):

View File

@ -1,67 +0,0 @@
import importlib
utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils")
from scripts.global_state import select_control_type, ui_preprocessor_keys
from scripts.enums import StableDiffusionVersion
dummy_value = "dummy"
cn_models = {
"None": dummy_value,
"canny_sd15": dummy_value,
"canny_sdxl": dummy_value,
}
# Tests for the select_control_type function
class TestSelectControlType:
def test_all_control_type(self):
result = select_control_type("All", cn_models=cn_models)
assert result == (
[ui_preprocessor_keys, list(cn_models.keys()), "none", "None"]
), "Expected all preprocessors and models"
def test_sd_version(self):
(_, filtered_model_list, _, default_model) = select_control_type(
"Canny", sd_version=StableDiffusionVersion.UNKNOWN, cn_models=cn_models
)
assert filtered_model_list == [
"None",
"canny_sd15",
"canny_sdxl",
], "UNKNOWN sd version should match all models"
assert default_model == "canny_sd15"
(_, filtered_model_list, _, default_model) = select_control_type(
"Canny", sd_version=StableDiffusionVersion.SD1x, cn_models=cn_models
)
assert filtered_model_list == [
"None",
"canny_sd15",
], "sd1x version should only sd1x"
assert default_model == "canny_sd15"
(_, filtered_model_list, _, default_model) = select_control_type(
"Canny", sd_version=StableDiffusionVersion.SDXL, cn_models=cn_models
)
assert filtered_model_list == [
"None",
"canny_sdxl",
], "sdxl version should only sdxl"
assert default_model == "canny_sdxl"
def test_invert_preprocessor(self):
for control_type in ("Canny", "Lineart", "Scribble/Sketch", "MLSD"):
filtered_preprocessor_list, _, _, _ = select_control_type(
control_type, cn_models=cn_models
)
assert any(
"invert" in module.lower() for module in filtered_preprocessor_list
)
def test_no_module_available(self):
(_, filtered_model_list, _, default_model) = select_control_type(
"Depth", cn_models=cn_models
)
assert filtered_model_list == ["None"]
assert default_model == "None"

View File

@ -1,24 +0,0 @@
import unittest
import importlib
import requests
utils = importlib.import_module(
'extensions.sd-webui-controlnet.tests.utils', 'utils')
from scripts.processor import preprocessor_filters
class TestControlTypes(unittest.TestCase):
def test_fetching_control_types(self):
response = requests.get(utils.BASE_URL + "/controlnet/control_types")
self.assertEqual(response.status_code, 200)
result = response.json()
self.assertIn('control_types', result)
for control_type in preprocessor_filters:
self.assertIn(control_type, result['control_types'])
if __name__ == "__main__":
unittest.main()

View File

@ -90,9 +90,11 @@ def test_invalid_param(gen_type, param_name):
unit_overrides={param_name: -1}, unit_overrides={param_name: -1},
input_image=girl_img, input_image=girl_img,
).exec() ).exec()
assert log_context.is_in_console_logs([ assert log_context.is_in_console_logs(
[
f"[canny.{param_name}] Invalid value(-1), using default value", f"[canny.{param_name}] Invalid value(-1), using default value",
]) ]
)
@pytest.mark.parametrize("save_map", [True, False]) @pytest.mark.parametrize("save_map", [True, False])
@ -285,3 +287,20 @@ def test_lama_outpaint():
"resize_mode": "Resize and Fill", # OUTER_FIT "resize_mode": "Resize and Fill", # OUTER_FIT
}, },
).exec() ).exec()
@disable_in_cq
def test_ip_adapter_auto():
    """E2E: the 'ip-adapter-auto' module should resolve to a concrete
    IP-Adapter preprocessor for the selected model, which the server reports
    in its console log as 'ip-adapter-auto => ip-adapter_clip_h'."""
    with console_log_context() as log_context:
        assert APITestTemplate(
            "txt2img_ip_adapter_auto",
            "txt2img",
            payload_overrides={},
            unit_overrides={
                "image": girl_img,
                "model": get_model("ip-adapter_sd15"),
                "module": "ip-adapter-auto",
            },
        ).exec()
        assert log_context.is_in_console_logs(["ip-adapter-auto => ip-adapter_clip_h"])

View File

@ -0,0 +1,172 @@
import pytest
import requests
from .template import APITestTemplate
# Canonical (non-alias) preprocessor module names that the
# /controlnet/module_list endpoint must expose at minimum.
expected_module_names = {
    "animal_openpose",
    "anime_face_segment",
    "blur_gaussian",
    "canny",
    "clip_vision",
    "color",
    "densepose",
    "densepose_parula",
    "depth",
    "depth_anything",
    "depth_hand_refiner",
    "depth_leres",
    "depth_leres++",
    "depth_zoe",
    "dw_openpose_full",
    "hed",
    "hed_safe",
    "inpaint",
    "inpaint_only",
    "inpaint_only+lama",
    "instant_id_face_embedding",
    "instant_id_face_keypoints",
    "invert",
    "ip-adapter-auto",
    "ip-adapter_clip_sd15",
    "ip-adapter_clip_sdxl",
    "ip-adapter_clip_sdxl_plus_vith",
    "ip-adapter_face_id",
    "ip-adapter_face_id_plus",
    "lineart",
    "lineart_anime",
    "lineart_anime_denoise",
    "lineart_coarse",
    "lineart_standard",
    "mediapipe_face",
    "mlsd",
    "none",
    "normal_bae",
    "normal_dsine",
    "normal_map",
    "oneformer_ade20k",
    "oneformer_coco",
    "openpose",
    "openpose_face",
    "openpose_faceonly",
    "openpose_full",
    "openpose_hand",
    "pidinet",
    "pidinet_safe",
    "pidinet_scribble",
    "pidinet_sketch",
    "recolor_intensity",
    "recolor_luminance",
    "reference_adain",
    "reference_adain+attn",
    "reference_only",
    "revision_clipvision",
    "revision_ignore_prompt",
    "scribble_hed",
    "scribble_xdog",
    "segmentation",
    "shuffle",
    "te_hed",
    "threshold",
    "tile_colorfix",
    "tile_colorfix+sharp",
    "tile_resample",
}

# Display name (label) variants expected when alias_names=true.
# NOTE(review): "pruple" below mirrors the label string emitted by the
# preprocessor source — keep the typo in sync with that definition.
expected_module_alias = {
    "animal_openpose",
    "blur_gaussian",
    "canny",
    "densepose (pruple bg & purple torso)",
    "densepose_parula (black bg & blue torso)",
    "depth_anything",
    "depth_hand_refiner",
    "depth_leres",
    "depth_leres++",
    "depth_midas",
    "depth_zoe",
    "dw_openpose_full",
    "inpaint_global_harmonious",
    "inpaint_only",
    "inpaint_only+lama",
    "instant_id_face_embedding",
    "instant_id_face_keypoints",
    "invert (from white bg & black line)",
    "ip-adapter-auto",
    "ip-adapter_clip_g",
    "ip-adapter_clip_h",
    "ip-adapter_clip_sdxl_plus_vith",
    "ip-adapter_face_id",
    "ip-adapter_face_id_plus",
    "lineart_anime",
    "lineart_anime_denoise",
    "lineart_coarse",
    "lineart_realistic",
    "lineart_standard (from white bg & black line)",
    "mediapipe_face",
    "mlsd",
    "none",
    "normal_bae",
    "normal_dsine",
    "normal_midas",
    "openpose",
    "openpose_face",
    "openpose_faceonly",
    "openpose_full",
    "openpose_hand",
    "recolor_intensity",
    "recolor_luminance",
    "reference_adain",
    "reference_adain+attn",
    "reference_only",
    "revision_clipvision",
    "revision_ignore_prompt",
    "scribble_hed",
    "scribble_pidinet",
    "scribble_xdog",
    "seg_anime_face",
    "seg_ofade20k",
    "seg_ofcoco",
    "seg_ufade20k",
    "shuffle",
    "softedge_hed",
    "softedge_hedsafe",
    "softedge_pidinet",
    "softedge_pidisafe",
    "softedge_teed",
    "t2ia_color_grid",
    "t2ia_sketch_pidi",
    "t2ia_style_clipvision",
    "threshold",
    "tile_colorfix",
    "tile_colorfix+sharp",
    "tile_resample",
}
@pytest.mark.parametrize("alias", ("true", "false"))
def test_module_list(alias):
json_resp = requests.get(
APITestTemplate.BASE_URL + f"controlnet/module_list?alias_names={alias}"
).json()
module_list = json_resp["module_list"]
module_detail: dict = json_resp["module_detail"]
expected_list = expected_module_alias if alias == "true" else expected_module_names
assert set(module_list).issuperset(expected_list), expected_list - set(module_list)
assert set(module_list) == set(module_detail.keys())
assert module_detail["canny"] == dict(
model_free=False,
sliders=[
{
"name": "Resolution",
"value": 512,
"min": 64,
"max": 2048,
"step": 8,
},
{"name": "Low Threshold", "value": 100, "min": 1, "max": 255, "step": 1},
{"name": "High Threshold", "value": 200, "min": 1, "max": 255, "step": 1},
],
)