sd-webui-controlnet/scripts/api.py

230 lines
9.5 KiB
Python

from typing import Union
import numpy as np
from fastapi import FastAPI, Body
from PIL import Image
import copy
import pydantic
import sys
import gradio as gr
from modules.api.models import *
from modules.api import api
from scripts import external_code, global_state
def encode_to_base64(image):
    """Return ``image`` as a base64 string suitable for an API response.

    Strings are assumed to be already encoded and are returned unchanged.
    PIL images are encoded with ``api.encode_pil_to_base64``; numpy arrays go
    through ``encode_np_to_base64``. Any other value yields ``""``.
    """
    if isinstance(image, str):
        return image
    # isinstance rather than ``type(...) is``: PIL returns subclasses of
    # Image.Image (e.g. from Image.open), which must be encoded as well.
    if isinstance(image, Image.Image):
        return api.encode_pil_to_base64(image)
    if isinstance(image, np.ndarray):
        return encode_np_to_base64(image)
    return ""
def encode_np_to_base64(image):
    """Encode a numpy image array as base64 by way of a PIL image."""
    return api.encode_pil_to_base64(Image.fromarray(image))
# Prefix of the legacy top-level request fields (e.g. ``controlnet_model``).
cn_root_field_prefix = 'controlnet_'

# Fields describing one ControlNet unit: name -> (annotation, pydantic Field).
# They build ``ControlNetUnitRequest`` and, combined with
# ``cn_root_field_prefix``, the deprecated top-level request fields.
cn_fields = dict(
    input_image=(str, Field(default="", title='ControlNet Input Image')),
    mask=(str, Field(default="", title='ControlNet Input Mask')),
    module=(str, Field(default="none", title='Controlnet Module')),
    model=(str, Field(default="None", title='Controlnet Model')),
    weight=(float, Field(default=1.0, title='Controlnet Weight')),
    # Accepts either an int or a str.
    resize_mode=(Union[int, str], Field(default="Crop and Resize", title='Controlnet Resize Mode')),
    lowvram=(bool, Field(default=False, title='Controlnet Low VRAM')),
    processor_res=(int, Field(default=64, title='Controlnet Processor Res')),
    threshold_a=(float, Field(default=64, title='Controlnet Threshold a')),
    threshold_b=(float, Field(default=64, title='Controlnet Threshold b')),
    # Legacy strength value; a value below 1.0 is used as guidance_end.
    guidance=(float, Field(default=1.0, title='ControlNet Guidance Strength')),
    guidance_start=(float, Field(default=0.0, title='ControlNet Guidance Start')),
    guidance_end=(float, Field(default=1.0, title='ControlNet Guidance End')),
    guessmode=(bool, Field(default=True, title="Guess Mode")),
)
def get_deprecated_cn_field(field_name: str, field):
    """Build the deprecated top-level ``controlnet_<name>`` variant of a unit field.

    Args:
        field_name: key of ``cn_fields`` (without the ``controlnet_`` prefix).
        field: the ``(annotation, FieldInfo)`` pair for that key.

    Returns:
        ``(prefixed_name, (annotation, FieldInfo))`` for ``pydantic.create_model``.
        The FieldInfo is a copy whose default is ``None`` (so an omitted legacy
        field can be told apart from a supplied one) and whose ``extra`` carries
        ``_deprecated=True`` (read by the request model's schema hook). For
        ``input_image`` and ``mask`` the annotation becomes ``List[...]``.
    """
    field_type, field_info = field
    field_info = copy.copy(field_info)
    field_info.default = None
    # copy.copy is shallow: the copy's ``extra`` dict is the very same object as
    # the original's. Rebind a fresh dict so the marker does not leak back into
    # the shared cn_fields entry (and from there into ControlNetUnitRequest).
    field_info.extra = {**field_info.extra, '_deprecated': True}
    if field_name in ('input_image', 'mask'):
        field_type = List[field_type]
    return f'{cn_root_field_prefix}{field_name}', (field_type, field_info)
def get_deprecated_field_default(field_name: str):
    """Return the fallback value for an unset deprecated field.

    ``input_image`` and ``mask`` fall back to a new empty list; every other
    field falls back to the default declared in ``cn_fields``.
    """
    if field_name not in ('input_image', 'mask'):
        _, field_info = cn_fields[field_name]
        return field_info.default
    return []
# Request model for a single ControlNet unit, generated from ``cn_fields``.
ControlNetUnitRequest = pydantic.create_model(
    'ControlNetUnitRequest',
    **cn_fields,
)
def create_controlnet_request_model(p_api_class):
    """Derive a ControlNet-aware request model from a webui processing API model.

    The generated model (named ``ControlNet<p_api_class name>``) keeps every
    field of ``p_api_class`` and adds a ``controlnet_units`` list plus one
    deprecated ``controlnet_<name>`` field per ``cn_fields`` entry. Its schema
    hook hides properties flagged ``_deprecated`` and shows ``docs_default`` as
    the documented default.
    """
    class RequestModel(p_api_class):
        class Config(p_api_class.__config__):
            @staticmethod
            def schema_extra(schema: dict, _):
                visible_props = {}
                for prop_name, prop_schema in schema.get('properties', {}).items():
                    if not prop_schema.get('_deprecated', False):
                        visible_props[prop_name] = prop_schema
                    # Show a populated example instead of the real (empty) default.
                    if prop_schema.get('docs_default', None) is not None:
                        prop_schema['default'] = prop_schema['docs_default']
                if visible_props:
                    schema['properties'] = visible_props

    units_field = (
        List[ControlNetUnitRequest],
        Field(default=[], docs_default=[ControlNetUnitRequest()], description="ControlNet Processing Units"),
    )
    legacy_fields = dict(get_deprecated_cn_field(name, spec) for name, spec in cn_fields.items())
    return pydantic.create_model(
        f'ControlNet{p_api_class.__name__}',
        __base__=RequestModel,
        controlnet_units=units_field,
        **legacy_fields,
    )
# ControlNet variants of the stock txt2img / img2img request models (built in this order).
ControlNetTxt2ImgRequest, ControlNetImg2ImgRequest = map(
    create_controlnet_request_model,
    (StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI),
)
class ApiHijack(api.Api):
    """``api.Api`` subclass that also serves the deprecated ControlNet routes.

    Adds POST ``/controlnet/txt2img`` and ``/controlnet/img2img``, which accept
    the ControlNet request models and delegate to the stock webui handlers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        routes = (
            ("/controlnet/txt2img", self.controlnet_txt2img, TextToImageResponse),
            ("/controlnet/img2img", self.controlnet_img2img, ImageToImageResponse),
        )
        for path, endpoint, response_model in routes:
            self.add_api_route(path, endpoint, methods=["POST"], response_model=response_model)

    def controlnet_txt2img(self, txt2img_request: ControlNetTxt2ImgRequest):
        """Handle the deprecated txt2img route via ``text2imgapi``."""
        return self.controlnet_any2img(
            any2img_request=txt2img_request,
            original_callback=ApiHijack.text2imgapi,
            is_img2img=False,
        )

    def controlnet_img2img(self, img2img_request: ControlNetImg2ImgRequest):
        """Handle the deprecated img2img route via ``img2imgapi``."""
        return self.controlnet_any2img(
            any2img_request=img2img_request,
            original_callback=ApiHijack.img2imgapi,
            is_img2img=True,
        )

    def controlnet_any2img(self, any2img_request, original_callback, is_img2img):
        """Shared body of the deprecated routes.

        Folds legacy fields into ``controlnet_units``, passes the units to the
        ``ControlNet`` entry of ``alwayson_scripts``, runs ``original_callback``
        and echoes the units plus the caller's original ``alwayson_scripts`` in
        ``result.parameters``.
        """
        warn_deprecated_route(is_img2img)
        request = nest_deprecated_cn_fields(any2img_request)
        # Snapshot the caller-supplied scripts before injecting the ControlNet entry.
        caller_scripts = dict(request.alwayson_scripts)
        units = request.controlnet_units
        request.alwayson_scripts.update({'ControlNet': {'args': [to_api_cn_unit(u) for u in units]}})
        # Strip the extra field before handing the request to the stock handler.
        delattr(request, 'controlnet_units')
        result = original_callback(self, request)
        result.parameters['controlnet_units'] = units
        result.parameters['alwayson_scripts'] = caller_scripts
        return result


# Swap in the subclass so code that resolves ``api.Api`` from here on gets the extra routes.
api.Api = ApiHijack
def nest_deprecated_cn_fields(any2img_request):
    """Move legacy top-level ``controlnet_*`` fields into a ControlNet unit.

    Every ``controlnet_<name>`` attribute except ``controlnet_units`` is removed
    from a shallow copy of the request. If at least one of them is not ``None``,
    a ``ControlNetUnitRequest`` is built from them (unset ones take
    ``get_deprecated_field_default``; the ``input_image``/``mask`` lists are
    reduced to their first entry or ``""``) and inserted at index 0 of
    ``controlnet_units``. Returns the copy.

    NOTE(review): for pydantic v1 models ``copy.copy`` may share the underlying
    attribute dict and ``controlnet_units`` list with the original; confirm if
    the caller's object must stay untouched.
    """
    legacy = {}
    for attr_name, attr_value in vars(any2img_request).items():
        if attr_name != 'controlnet_units' and attr_name.startswith(cn_root_field_prefix):
            legacy[attr_name] = attr_value

    request = copy.copy(any2img_request)
    for attr_name in legacy:
        delattr(request, attr_name)

    if not any(value is not None for value in legacy.values()):
        return request

    unit_kwargs = {}
    for attr_name, attr_value in legacy.items():
        key = attr_name[len(cn_root_field_prefix):]
        unit_kwargs[key] = get_deprecated_field_default(key) if attr_value is None else attr_value

    for key in ('input_image', 'mask'):
        images = unit_kwargs[key]
        unit_kwargs[key] = images[0] if images else ""

    request.controlnet_units.insert(0, ControlNetUnitRequest(**unit_kwargs))
    return request
def to_api_cn_unit(unit_request: ControlNetUnitRequest) -> external_code.ControlNetUnit:
    """Convert an API ``ControlNetUnitRequest`` into an ``external_code.ControlNetUnit``.

    ``input_image`` and ``mask`` are decoded with
    ``external_code.to_base64_nparray`` when non-empty. When both are present
    the unit image is the ``(image, mask)`` pair; a mask without an image is
    dropped. A legacy ``guidance`` value below 1.0 overrides ``guidance_end``.
    ``unit_request`` itself is not modified.
    """
    input_image = external_code.to_base64_nparray(unit_request.input_image) if unit_request.input_image else None
    mask = external_code.to_base64_nparray(unit_request.mask) if unit_request.mask else None
    if input_image is not None and mask is not None:
        input_image = (input_image, mask)

    # Resolve the legacy override locally instead of writing it back onto
    # unit_request: the same request objects are echoed to the client in
    # result.parameters['controlnet_units'] by ApiHijack.controlnet_any2img.
    guidance_end = unit_request.guidance if unit_request.guidance < 1.0 else unit_request.guidance_end

    return external_code.ControlNetUnit(
        module=unit_request.module,
        model=unit_request.model,
        weight=unit_request.weight,
        image=input_image,
        resize_mode=unit_request.resize_mode,
        low_vram=unit_request.lowvram,
        processor_res=unit_request.processor_res,
        threshold_a=unit_request.threshold_a,
        threshold_b=unit_request.threshold_b,
        guidance_start=unit_request.guidance_start,
        guidance_end=guidance_end,
        guess_mode=unit_request.guessmode,
    )
def warn_deprecated_route(is_img2img):
    """Print a two-line deprecation notice for the legacy route to stderr."""
    route = 'img2img' if is_img2img else 'txt2img'
    prefix = '[ControlNet] warning: '
    messages = (
        f"using deprecated '/controlnet/{route}' route",
        f"consider using the '/sdapi/v1/{route}' route with the 'alwayson_scripts' json property instead",
    )
    for message in messages:
        print(prefix + message, file=sys.stderr)
def controlnet_api(_: gr.Blocks, app: FastAPI):
    """Register the ControlNet utility routes on ``app``.

    Adds GET ``/controlnet/version``, ``/controlnet/model_list`` and
    ``/controlnet/module_list`` plus POST ``/controlnet/detect``. The first
    argument (the gradio Blocks) is unused.
    """
    @app.get("/controlnet/version")
    async def version():
        return {"version": external_code.get_api_version()}

    @app.get("/controlnet/model_list")
    async def model_list():
        up_to_date_model_list = external_code.get_models(update=True)
        print(up_to_date_model_list)
        return {"model_list": up_to_date_model_list}

    @app.get("/controlnet/module_list")
    async def module_list():
        _module_list = external_code.get_modules()
        print(_module_list)
        return {"module_list": _module_list}

    @app.post("/controlnet/detect")
    async def detect(
        controlnet_module: str = Body("none", title='Controlnet Module'),
        controlnet_input_images: List[str] = Body([], title='Controlnet Input Images'),
        controlnet_processor_res: int = Body(512, title='Controlnet Processor Resolution'),
        controlnet_threshold_a: float = Body(64, title='Controlnet Threshold a'),
        controlnet_threshold_b: float = Body(64, title='Controlnet Threshold b')
    ):
        """Run a preprocessor module over base64 input images; return base64 results."""
        # Accept preprocessor aliases as well as canonical module names.
        controlnet_module = global_state.reverse_preprocessor_aliases.get(controlnet_module, controlnet_module)
        if controlnet_module not in global_state.cn_preprocessor_modules:
            return {"images": [], "info": "Module not available"}
        if not controlnet_input_images:
            return {"images": [], "info": "No image selected"}

        print(f"Detecting {len(controlnet_input_images)} images with the {controlnet_module} module.")
        results = []
        processor_module = global_state.cn_preprocessor_modules[controlnet_module]
        try:
            for input_image in controlnet_input_images:
                img = external_code.to_base64_nparray(input_image)
                results.append(processor_module(img, res=controlnet_processor_res, thr_a=controlnet_threshold_a, thr_b=controlnet_threshold_b)[0])
        finally:
            # Unload the preprocessor even when decoding or detection raises,
            # so a failed request does not leave it resident.
            global_state.cn_preprocessor_unloadable.get(controlnet_module, lambda: None)()

        results64 = list(map(encode_to_base64, results))
        return {"images": results64, "info": "Success"}
# Attach the ControlNet routes once the webui app has started. Registration is
# best-effort so a failure does not break loading the extension, but it is
# reported instead of being silently swallowed. ``Exception`` (not a bare
# ``except``) lets KeyboardInterrupt / SystemExit propagate.
try:
    import modules.script_callbacks as script_callbacks

    script_callbacks.on_app_started(controlnet_api)
except Exception as e:
    print(f"[ControlNet] warning: failed to register API routes: {e}", file=sys.stderr)