# automatic/modules/api/control.py
#
# 307 lines
# 17 KiB
# Python

from typing import Optional
from threading import Lock
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
from modules import errors, shared, processing_helpers
from modules.logger import log
from modules.api import models, helpers
from modules.control import run
# Runs once at import time, before any model or handler in this module is defined.
# NOTE(review): presumably installs the project's error/traceback handlers -- confirm in modules.errors.
errors.install()
class ItemControl(BaseModel):
    # Request schema for one entry of the `control` list on a control request.
    # APIControl.prepare_control turns each entry into a control Unit; fields marked
    # "(... only)" in their description apply to that unit type only.
    process: str = Field(title="Preprocessor", default="", description="Preprocessor name (e.g. 'Canny', 'OpenPose', 'Depth Anything'). Use /sdapi/v1/preprocessors to list available options.")
    model: str = Field(title="Control model", default="", description="Control model filename or path. Use /sdapi/v2/control-models to list available models.")
    strength: float = Field(title="Strength", default=1.0, description="How strongly the control model influences generation (0.0-2.0)")
    start: float = Field(title="Start", default=0.0, description="Step fraction at which control begins (0.0-1.0)")
    end: float = Field(title="End", default=1.0, description="Step fraction at which control ends (0.0-1.0)")
    # Declared Optional because the default is None and callers may send an explicit null,
    # matching the sibling `image` and `unit_type` fields.
    override: Optional[str] = Field(title="Override image", default=None, description="Base64-encoded pre-processed control image that bypasses the preprocessor. Takes priority over 'image'.")
    unit_type: Optional[str] = Field(title="Unit type", default=None, description="Control unit type: 'controlnet', 't2i adapter', 'xs', 'lite', 'reference', or 'ip'. Defaults to the request-level unit_type.")
    mode: str = Field(title="Mode", default="default", description="Control mode for Union/ProMax models. Use /sdapi/v2/control-modes to list valid modes per model.")
    guess: bool = Field(title="Guess mode", default=False, description="Enable guess mode, which removes the need for a prompt (ControlNet only)")
    factor: float = Field(title="Adapter factor", default=1.0, description="Conditioning scale factor (T2I Adapter only)")
    attention: str = Field(title="Attention type", default="Attention", description="Attention mechanism type (Reference units only)")
    fidelity: float = Field(title="Fidelity", default=0.5, description="Style fidelity level 0.0-1.0 (Reference units only)")
    query_weight: float = Field(title="Query weight", default=1.0, description="Attention query weight (Reference units only)")
    adain_weight: float = Field(title="AdaIN weight", default=1.0, description="Adaptive instance normalization weight (Reference units only)")
    process_params: Optional[dict] = Field(title="Preprocessor params", default=None, description="Override preprocessor defaults, e.g. {'low_threshold': 50, 'high_threshold': 150} for Canny. Keys must match the preprocessor's parameters.")
    image: Optional[str] = Field(title="Image", default=None, description="Base64-encoded control input image. Alias for 'override' -- if both are set, 'override' takes priority.")
class ItemXYZ(BaseModel):
    # Request schema for the optional `xyz` parameter-grid sweep.
    # APIControl.prepare_xyz_grid flattens these fields into the positional
    # script_args of the "xyz grid" script.

    # --- axis definitions: a parameter name plus its comma-separated values ---
    x_type: str = Field(
        default='',
        title="X axis type",
        description="Parameter name for X axis variation (e.g. 'Seed', 'Steps', 'CFG Scale')",
    )
    x_values: str = Field(
        default='',
        title="X axis values",
        description="Comma-separated values for X axis",
    )
    y_type: str = Field(
        default='',
        title="Y axis type",
        description="Parameter name for Y axis variation",
    )
    y_values: str = Field(
        default='',
        title="Y axis values",
        description="Comma-separated values for Y axis",
    )
    z_type: str = Field(
        default='',
        title="Z axis type",
        description="Parameter name for Z axis variation",
    )
    z_values: str = Field(
        default='',
        title="Z axis values",
        description="Comma-separated values for Z axis",
    )

    # --- output options ---
    draw_legend: bool = Field(
        default=True,
        title="Draw legend",
        description="Draw axis labels on the output grid image",
    )
    include_grid: bool = Field(
        default=True,
        title="Include grid",
        description="Include the combined grid image in output",
    )
    include_subgrids: bool = Field(
        default=False,
        title="Include subgrids",
        description="Include intermediate sub-grid images when using 3 axes",
    )
    include_images: bool = Field(
        default=False,
        title="Include images",
        description="Include individual cell images in output alongside the grid",
    )
    include_time: bool = Field(
        default=False,
        title="Include time",
        description="Show generation time per cell in the legend",
    )
    include_text: bool = Field(
        default=False,
        title="Include text",
        description="Show generation parameters as text overlay",
    )
# API-only request fields added on top of the parameters taken from run.control_run.
# Entries flagged "exclude" are consumed by the APIControl.prepare_* helpers rather than
# being forwarded as-is.
_REQ_CONTROL_EXTRA_FIELDS = [
    {"key": "sampler_name", "type": str, "default": "Default"},
    {"key": "script_name", "type": Optional[str], "default": None},
    {"key": "script_args", "type": list, "default": []},
    {"key": "send_images", "type": bool, "default": True},
    {"key": "save_images", "type": bool, "default": False},
    {"key": "alwayson_scripts", "type": dict, "default": {}},
    {"key": "ip_adapter", "type": Optional[list[models.ItemIPAdapter]], "default": None, "exclude": True},
    {"key": "face", "type": Optional[models.ItemFace], "default": None, "exclude": True},
    {"key": "control", "type": Optional[list[ItemControl]], "default": [], "exclude": True},
    {"key": "xyz", "type": Optional[ItemXYZ], "default": None, "exclude": True},
    {"key": "extra", "type": Optional[dict], "default": {}, "exclude": True},
    {"key": "init_control", "type": Optional[list], "default": None, "exclude": True},
]

# Request model for the control endpoint, generated from the signature of run.control_run.
ReqControl = models.create_model_from_signature(
    func=run.control_run,
    model_name="StableDiffusionProcessingControl",
    additional_fields=_REQ_CONTROL_EXTRA_FIELDS,
)

# Attach a fallback config only when the generated model does not already carry one.
if not hasattr(ReqControl, "__config__"):
    ReqControl.__config__ = models.DummyConfig
class ResControl(BaseModel):
    # Response schema returned by APIControl.post_control.
    # `images` and `processed` default to None, so they are declared Optional to keep the
    # annotation and the default consistent.
    images: Optional[list[str]] = Field(default=None, title="Images", description="Base64-encoded generated output images")
    processed: Optional[list[str]] = Field(default=None, title="Processed", description="Base64-encoded preprocessor output images (control maps)")
    params: dict = Field(default={}, title="Settings", description="Echo of the request parameters used for generation")
    info: str = Field(default="", title="Info", description="Generation info string with seed, sampler, and pipeline details")
class APIControl:
    """Handler for the control generation API endpoint.

    Keeps the shared queue lock used to serialize generation and remembers the control
    units built for the previous request so that matching units can be reused.
    """

    def __init__(self, queue_lock: Lock):
        """Store the shared queue lock and start with an empty unit cache."""
        self.queue_lock = queue_lock
        self.default_script_arg = []
        self.units = []

    def sanitize_args(self, args: BaseModel) -> dict:
        """Convert the prepared request into the keyword arguments for `run.control_run`.

        Works on `vars(args)`, i.e. the object's own attribute dict, so the removals
        below also affect `args` itself. API-only fields are dropped and the script
        fields are renamed to their `override_*` counterparts.

        Returns:
            The attribute dict of `args` after the removals and renames.
        """
        args = vars(args)
        dropped = (
            'sampler_name',
            'alwayson_scripts',
            'face',
            'face_id',
            'ip_adapter',
            'save_images',
            'init_control',
            'mask_blur',
            'inpaint_full_res',
            'inpaint_full_res_padding',
            'inpainting_mask_invert',
        )
        for key in dropped:
            args.pop(key, None)
        args['override_script_name'] = args.pop('script_name', None)
        args['override_script_args'] = args.pop('script_args', None)
        return args

    def sanitize_b64(self, request):
        """Shrink large string script arguments on `request` before it is echoed back.

        Any string of 1000 or more characters inside a script argument list is replaced
        in place by a `<str N>` placeholder, and a non-empty `override_script_args`
        attribute is removed from the request.
        """
        def sanitize_str(args: list):
            # Replace long strings (typically base64 payloads) with a length marker
            for idx, value in enumerate(args):
                if isinstance(value, str) and len(value) >= 1000:
                    args[idx] = f"<str {len(value)}>"

        alwayson_scripts = getattr(request, "alwayson_scripts", None)
        if alwayson_scripts:
            for script_obj in alwayson_scripts.values():
                if script_obj and "args" in script_obj and script_obj["args"]:
                    sanitize_str(script_obj["args"])
        script_args = getattr(request, "script_args", None)
        if script_args:
            sanitize_str(script_args)
        if getattr(request, 'override_script_args', None):
            # The request is an object, not a dict: it has no .pop(), so the attribute
            # is deleted the same way the prepare_* helpers delete consumed fields.
            delattr(request, 'override_script_args')

    def prepare_face_module(self, req):
        """Translate the `face` request field into a "face" script invocation.

        Does nothing when no face settings are given, when a script is already
        selected, or when "face" is already present in `alwayson_scripts`.
        """
        face = getattr(req, "face", None)
        if not face or req.script_name:
            return
        if req.alwayson_scripts and "face" in req.alwayson_scripts:
            return
        req.script_name = "face"
        # Positional order of the face script arguments
        req.script_args = [
            face.mode,
            face.source_images,
            face.ip_model,
            face.ip_override_sampler,
            face.ip_cache_model,
            face.ip_strength,
            face.ip_structure,
            face.id_strength,
            face.id_conditioning,
            face.id_cache,
            face.pm_trigger,
            face.pm_strength,
            face.pm_start,
            face.fs_cache,
        ]
        del req.face

    def prepare_xyz_grid(self, req):
        """Translate the `xyz` request field into an "xyz grid" script invocation.

        Overrides any previously selected `script_name`/`script_args` when `xyz` is set.
        """
        xyz = getattr(req, "xyz", None)
        if not xyz:
            return
        req.script_name = "xyz grid"
        req.script_args = [
            xyz.x_type, xyz.x_values, '',
            xyz.y_type, xyz.y_values, '',
            xyz.z_type, xyz.z_values, '',
            False, # csv_mode
            xyz.draw_legend,
            False, # no_fixed_seeds
            xyz.include_grid, xyz.include_subgrids, xyz.include_images,
            xyz.include_time, xyz.include_text,
        ]
        del req.xyz

    def prepare_ip_adapter(self, request):
        """Collect the `ip_adapter` request entries into parallel argument lists.

        Entries without images are skipped. Images and masks are decoded from base64.

        Returns:
            A dict of `ip_adapter_*` lists, or an empty dict when the request has no
            IP adapter entries. The `ip_adapter` attribute is removed once consumed.
        """
        adapters = getattr(request, "ip_adapter", None)
        if not adapters:
            return {}
        args = {
            'ip_adapter_names': [],
            'ip_adapter_scales': [],
            'ip_adapter_crops': [],
            'ip_adapter_starts': [],
            'ip_adapter_ends': [],
            'ip_adapter_images': [],
            'ip_adapter_masks': [],
        }
        for ipadapter in adapters:
            if not ipadapter.images:
                continue
            args['ip_adapter_names'].append(ipadapter.adapter)
            args['ip_adapter_scales'].append(ipadapter.scale)
            args['ip_adapter_starts'].append(ipadapter.start)
            args['ip_adapter_ends'].append(ipadapter.end)
            args['ip_adapter_crops'].append(ipadapter.crop)
            args['ip_adapter_images'].append([helpers.decode_base64_to_image(x) for x in ipadapter.images])
            if ipadapter.masks:
                args['ip_adapter_masks'].append([helpers.decode_base64_to_image(x) for x in ipadapter.masks])
        del request.ip_adapter
        return args

    def prepare_control(self, req):
        """Build `req.units` from the `control` entries of the request.

        A unit cached from the previous request is reused when the entry at the same
        position has the same preprocessor, model and unit type; otherwise a new Unit is
        created. Entries with an unknown unit type are logged and skipped. The resulting
        list replaces the cache and `req.control` is removed from the request.
        """
        from modules.control.unit import Unit, unit_types
        req.units = []
        default_type = req.unit_type if req.unit_type is not None else 'controlnet'
        # `control` is Optional on the request model, so an explicit null must not crash
        controls = req.control or []
        # Set top-level unit_type from first unit (control_run filters units by this)
        if controls:
            first_type = controls[0].unit_type if controls[0].unit_type is not None else default_type
            req.unit_type = first_type
        for i, u in enumerate(controls):
            ut = u.unit_type if u.unit_type is not None else default_type
            if ut not in unit_types:
                log.error(f'Control unknown unit type: type={ut} available={unit_types}')
                continue
            reused = (
                len(self.units) > i
                and self.units[i].process_id == u.process
                and self.units[i].model_id == u.model
                and self.units[i].type == ut
            )
            if reused:
                unit = self.units[i]
                unit.enabled = True
                unit.strength = u.strength
                unit.start = u.start
                unit.end = u.end
            else:
                unit = Unit(
                    enabled=True,
                    unit_type=ut,
                    model_id=u.model,
                    process_id=u.process,
                    strength=u.strength,
                    start=u.start,
                    end=u.end,
                )
            # Extended per-unit properties
            unit.guess = u.guess
            unit.factor = u.factor
            unit.attention = u.attention
            unit.fidelity = u.fidelity
            unit.query_weight = u.query_weight
            unit.adain_weight = u.adain_weight
            unit.process_params = u.process_params or {}
            unit.update_choices(u.model)
            # Keep mode as string — run.py expects str and converts to index
            if u.mode != "default" and u.mode in unit.choices:
                unit.mode = u.mode
            else:
                unit.mode = unit.choices[0] if unit.choices else "default"
            # Always clear process.override so init_units() re-sets it from unit.override
            # This ensures the preprocessor runs fresh every generation (not reusing stale cached results)
            if unit.process is not None:
                unit.process.override = None
            # Override image: use 'override' field, fall back to 'image' alias
            override_b64 = u.override if u.override is not None else u.image
            if override_b64 is not None:
                unit.override = helpers.decode_base64_to_image(override_b64)
            elif reused:
                # A reused unit may still hold the override image of the previous request;
                # drop it so a request without an override does not inherit a stale one.
                unit.override = None
            req.units.append(unit)
        self.units = req.units
        del req.control

    def post_control(self, req: ReqControl):
        """Run the control-guided generation pipeline with one or more control units.
        Supports unit types: **controlnet**, **t2i adapter**, **xs**, **lite**, **reference**, and **ip**.
        Each unit in the `control` list pairs a preprocessor with a control model. The preprocessor
        transforms the input image into a control signal (e.g. edge map, depth map, pose skeleton),
        which the control model uses to guide diffusion.
        Set `process_params` on individual control units to override preprocessor defaults
        (e.g. Canny thresholds, pose confidence) without changing the global processor settings.
        Optional modules: `ip_adapter` for image-prompt conditioning, `face` for face-driven
        generation (FaceID/FaceSwap/PhotoMaker/InstantID), and `xyz` for parameter grid sweeps.
        Returns generated images, preprocessor output maps, the echoed request parameters, and
        a generation info string.
        """
        # Keep the raw control entries: prepare_control deletes req.control and the
        # response echoes them back under `units` instead of the live Unit objects.
        requested = req.control
        self.prepare_face_module(req)
        self.prepare_control(req)
        self.prepare_xyz_grid(req)

        # Merge init_control images into inits
        init_control = getattr(req, "init_control", None)
        decoded_inits = [helpers.decode_base64_to_image(x) for x in req.inits] if req.inits else None
        if init_control:
            extra_inits = [helpers.decode_base64_to_image(x) for x in init_control]
            decoded_inits = (decoded_inits or []) + extra_inits

        # Extract excluded fields before copy (Pydantic exclude=True drops them from copy)
        extra = getattr(req, "extra", {}) or {}

        # prepare args
        args = req.copy(update={
            "sampler_index": processing_helpers.get_sampler_index(req.sampler_name),
            "is_generator": True,
            "inputs": [helpers.decode_base64_to_image(x) for x in req.inputs] if req.inputs else None,
            "inits": decoded_inits,
            "mask": helpers.decode_base64_to_image(req.mask) if req.mask else None,
        })
        args = self.sanitize_args(args)
        args['extra'] = extra
        send_images = args.pop('send_images', True)

        output_images = []
        output_processed = []
        output_info = ''

        # run
        with self.queue_lock:
            jobid = shared.state.begin('API-CTL', api=True)
            try:
                extra_p_args = {
                    'do_not_save_grid': not req.save_images,
                    'do_not_save_samples': not req.save_images,
                    **self.prepare_ip_adapter(req),
                }
                # Forward inpainting fields
                for field in ('mask_blur', 'inpaint_full_res', 'inpaint_full_res_padding', 'inpainting_mask_invert'):
                    val = getattr(req, field, None)
                    if val is not None:
                        extra_p_args[field] = val
                run.control_set(extra_p_args)
                # is_generator=True was set above, so the results are consumed by iteration
                for item in run.control_run(**args):
                    if isinstance(item, str):
                        output_info += item
                    elif item is not None and len(item) > 0 and (isinstance(item[0], list) or item[0] is None):
                        # (images, processed, info) result; trailing elements may be absent
                        output_images += item[0] if item[0] is not None else []
                        if len(item) > 1 and item[1] is not None:
                            output_processed.append(item[1])
                        if len(item) > 2 and item[2] is not None:
                            output_info += item[2]
            finally:
                # End the job even when generation raises so the shared state is not left
                # marked as busy; the exception still propagates to the caller.
                shared.state.end(jobid, api=False)

        # return
        b64images = list(map(helpers.encode_pil_to_base64, output_images)) if send_images else []
        b64processed = list(map(helpers.encode_pil_to_base64, output_processed)) if send_images else []
        self.sanitize_b64(req)
        req.units = requested
        return ResControl(images=b64images, processed=b64processed, params=vars(req), info=output_info)