279 lines
12 KiB
Python
279 lines
12 KiB
Python
from typing import Union
|
|
|
|
import numpy as np
|
|
from fastapi import FastAPI, Body
|
|
from PIL import Image
|
|
import copy
|
|
import pydantic
|
|
import sys
|
|
|
|
import gradio as gr
|
|
|
|
from modules.api.models import *
|
|
from modules.api import api
|
|
|
|
from scripts import external_code
|
|
from scripts.processor import *
|
|
|
|
def encode_to_base64(image):
    """Return ``image`` as a base64 string suitable for an API response.

    Strings are assumed to be already encoded and are returned unchanged.
    PIL images (including format-specific subclasses such as the objects
    returned by ``Image.open``) are encoded with ``api.encode_pil_to_base64``.
    Numpy arrays are encoded through ``encode_np_to_base64``. Any other value
    yields ``""``.
    """
    # isinstance (not ``type(...) is``) so that PIL subclasses like
    # PngImageFile/JpegImageFile are encoded instead of silently becoming "".
    if isinstance(image, str):
        return image
    if isinstance(image, Image.Image):
        return api.encode_pil_to_base64(image)
    if isinstance(image, np.ndarray):
        return encode_np_to_base64(image)
    return ""
|
|
|
|
def encode_np_to_base64(image):
    """Encode a numpy image array as a base64 string.

    The array is wrapped in a PIL image with ``Image.fromarray`` and then
    passed to the WebUI's ``api.encode_pil_to_base64``.
    """
    return api.encode_pil_to_base64(Image.fromarray(image))
|
|
|
|
# Prefix of the legacy flat request fields (``controlnet_<name>``) that are
# still accepted alongside ``controlnet_units``.
cn_root_field_prefix = 'controlnet_'

# Per-unit request fields as ``name -> (type, FieldInfo)`` pairs, in the form
# accepted by ``pydantic.create_model``.
cn_fields = dict(
    # Image inputs; decoded later with external_code.to_base64_nparray
    # (presumably base64 strings -- confirm against external_code).
    input_image=(str, Field(default="", title='ControlNet Input Image')),
    mask=(str, Field(default="", title='ControlNet Input Mask')),
    module=(str, Field(default="none", title='Controlnet Module')),
    model=(str, Field(default="None", title='Controlnet Model')),
    weight=(float, Field(default=1.0, title='Controlnet Weight')),
    resize_mode=(Union[int, str], Field(default="Scale to Fit (Inner Fit)", title='Controlnet Resize Mode')),
    lowvram=(bool, Field(default=False, title='Controlnet Low VRAM')),
    processor_res=(int, Field(default=64, title='Controlnet Processor Res')),
    threshold_a=(float, Field(default=64, title='Controlnet Threshold a')),
    threshold_b=(float, Field(default=64, title='Controlnet Threshold b')),
    # Legacy alias: a value below 1.0 overrides guidance_end (see to_api_cn_unit).
    guidance=(float, Field(default=1.0, title='ControlNet Guidance Strength')),
    guidance_start=(float, Field(default=0.0, title='ControlNet Guidance Start')),
    guidance_end=(float, Field(default=1.0, title='ControlNet Guidance End')),
    guessmode=(bool, Field(default=True, title="Guess Mode")),
    rgbbgr_mode=(bool, Field(default=False, title="RGB to BGR")),
)
|
|
|
|
def get_deprecated_cn_field(field_name: str, field):
    """Build the legacy top-level ``controlnet_<field_name>`` field definition.

    Args:
        field_name: a key of ``cn_fields``.
        field: the matching ``(type, FieldInfo)`` pair from ``cn_fields``.

    Returns:
        ``(prefixed_name, (type, FieldInfo))`` for ``pydantic.create_model``.
        The FieldInfo is a copy whose default is ``None`` and whose ``extra``
        carries ``_deprecated=True``; the request-model ``schema_extra`` hook
        uses that flag to hide the field from the schema. ``input_image`` and
        ``mask`` are widened to ``List[...]``, matching the list-of-images shape
        of the legacy API.
    """
    field_type, field = field
    field = copy.copy(field)
    field.default = None
    # copy.copy is shallow, so the copy still shares its ``extra`` dict with
    # the FieldInfo stored in ``cn_fields``. Mutating that dict in place would
    # also flag the regular unit fields as deprecated (presumably visible in
    # the ControlNetUnitRequest schema). Rebind a fresh dict instead.
    field.extra = {**field.extra, '_deprecated': True}
    if field_name in ('input_image', 'mask'):
        field_type = List[field_type]
    return f'{cn_root_field_prefix}{field_name}', (field_type, field)
|
|
|
|
def get_deprecated_field_default(field_name: str):
    """Return the fallback value for an omitted legacy ``controlnet_*`` field.

    Image-like fields (``input_image``, ``mask``) were lists in the legacy API,
    so they fall back to a fresh empty list. Every other field falls back to
    the default declared for it in ``cn_fields``.
    """
    is_image_field = field_name in ('input_image', 'mask')
    if not is_image_field:
        _, field_info = cn_fields[field_name]
        return field_info.default
    return []
|
|
|
|
# Request model for a single ControlNet unit: one attribute per ``cn_fields`` entry.
ControlNetUnitRequest = pydantic.create_model(
    'ControlNetUnitRequest',
    **cn_fields,
)
|
|
|
|
def create_controlnet_request_model(p_api_class):
    """Derive a ControlNet-aware request model from ``p_api_class``.

    The returned pydantic model is named ``ControlNet<p_api_class.__name__>``
    and adds to the base request:

    * ``controlnet_units``: a list of ``ControlNetUnitRequest`` (default ``[]``);
    * one legacy ``controlnet_<name>`` field per ``cn_fields`` entry (built by
      ``get_deprecated_cn_field``), kept for backward compatibility but removed
      from the generated JSON schema.
    """

    class RequestModel(p_api_class):
        class Config(p_api_class.__config__):
            @staticmethod
            def schema_extra(schema: dict, _):
                """Hide deprecated properties and apply documentation defaults."""
                props = {}
                for k, v in schema.get('properties', {}).items():
                    if not v.get('_deprecated', False):
                        props[k] = v
                    # ``docs_default`` is a helper key only meant to feed the
                    # documented default; pop it so it doesn't leak into the
                    # published schema.
                    docs_default = v.pop('docs_default', None)
                    if docs_default is not None:
                        v['default'] = docs_default
                # Only replace the properties when something survived filtering.
                if props:
                    schema['properties'] = props

    additional_fields = {
        'controlnet_units': (List[ControlNetUnitRequest], Field(default=[], docs_default=[ControlNetUnitRequest()], description="ControlNet Processing Units")),
        **dict(get_deprecated_cn_field(k, v) for k, v in cn_fields.items())
    }

    return pydantic.create_model(
        f'ControlNet{p_api_class.__name__}',
        __base__=RequestModel,
        **additional_fields)
|
|
|
|
# Request models for the legacy /controlnet/txt2img and /controlnet/img2img routes.
ControlNetTxt2ImgRequest, ControlNetImg2ImgRequest = (
    create_controlnet_request_model(StableDiffusionTxt2ImgProcessingAPI),
    create_controlnet_request_model(StableDiffusionImg2ImgProcessingAPI),
)
|
|
|
|
class ApiHijack(api.Api):
    """``api.Api`` subclass that also serves the legacy ControlNet routes.

    Adds POST ``/controlnet/txt2img`` and ``/controlnet/img2img``. Both routes
    translate their ControlNet fields into an ``alwayson_scripts['ControlNet']``
    entry and delegate to the stock ``text2imgapi``/``img2imgapi`` handlers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        legacy_routes = (
            ("/controlnet/txt2img", self.controlnet_txt2img, TextToImageResponse),
            ("/controlnet/img2img", self.controlnet_img2img, ImageToImageResponse),
        )
        for path, endpoint, response_model in legacy_routes:
            self.add_api_route(path, endpoint, methods=["POST"], response_model=response_model)

    def controlnet_txt2img(self, txt2img_request: ControlNetTxt2ImgRequest):
        """Handle the deprecated ``/controlnet/txt2img`` route."""
        return self.controlnet_any2img(
            is_img2img=False,
            any2img_request=txt2img_request,
            original_callback=ApiHijack.text2imgapi,
        )

    def controlnet_img2img(self, img2img_request: ControlNetImg2ImgRequest):
        """Handle the deprecated ``/controlnet/img2img`` route."""
        return self.controlnet_any2img(
            is_img2img=True,
            any2img_request=img2img_request,
            original_callback=ApiHijack.img2imgapi,
        )

    def controlnet_any2img(self, any2img_request, original_callback, is_img2img):
        """Shared implementation of the legacy routes.

        Folds the legacy flat fields into ``controlnet_units`` and injects the
        units as ``alwayson_scripts['ControlNet']['args']``. It then removes
        ``controlnet_units`` from the request before calling
        ``original_callback``. The response's ``parameters`` report the units
        and the ``alwayson_scripts`` the client originally supplied.
        """
        warn_deprecated_route(is_img2img)
        request = nest_deprecated_cn_fields(any2img_request)

        # Snapshot the client's scripts before the ControlNet entry is added.
        submitted_scripts = dict(request.alwayson_scripts)
        units = request.controlnet_units
        request.alwayson_scripts['ControlNet'] = {'args': [to_api_cn_unit(unit) for unit in units]}
        # The stock handler does not know about this extra field.
        delattr(request, 'controlnet_units')

        result = original_callback(self, request)
        result.parameters['controlnet_units'] = units
        result.parameters['alwayson_scripts'] = submitted_scripts
        return result


# Replace the WebUI's Api class so the legacy routes get registered.
api.Api = ApiHijack
|
|
|
|
def nest_deprecated_cn_fields(any2img_request):
    """Fold legacy top-level ``controlnet_*`` fields into ``controlnet_units``.

    Returns a copy of the request with every ``controlnet_<name>`` attribute
    (except ``controlnet_units`` itself) removed. If at least one legacy field
    was set, a ``ControlNetUnitRequest`` built from them is prepended to
    ``controlnet_units``. Unset legacy fields use
    ``get_deprecated_field_default``. Only the first entry of the legacy
    ``input_image``/``mask`` lists is kept.
    """
    legacy = {}
    for name, value in vars(any2img_request).items():
        if name.startswith(cn_root_field_prefix) and name != 'controlnet_units':
            legacy[name] = value

    # NOTE(review): for pydantic v1 models copy.copy may share ``__dict__``
    # with the original request -- confirm whether a real copy is intended.
    request = copy.copy(any2img_request)
    for name in legacy:
        delattr(request, name)

    if not any(value is not None for value in legacy.values()):
        return request

    prefix_len = len(cn_root_field_prefix)
    unit_kwargs = {}
    for name, value in legacy.items():
        key = name[prefix_len:]
        unit_kwargs[key] = get_deprecated_field_default(key) if value is None else value

    # Legacy image fields are lists; a unit takes a single image.
    for key in ('input_image', 'mask'):
        images = unit_kwargs[key]
        unit_kwargs[key] = images[0] if images else ""

    request.controlnet_units.insert(0, ControlNetUnitRequest(**unit_kwargs))
    return request
|
|
|
|
def to_api_cn_unit(unit_request: ControlNetUnitRequest) -> external_code.ControlNetUnit:
    """Convert an API ``ControlNetUnitRequest`` into an ``external_code.ControlNetUnit``.

    ``input_image`` and ``mask`` are decoded with
    ``external_code.to_base64_nparray``; empty strings become ``None``. When
    both are present, the unit's image is the tuple ``(image, mask)``. A mask
    without an input image is ignored.

    The legacy ``guidance`` field, when below 1.0, is used as ``guidance_end``.
    ``unit_request`` itself is not modified.
    """
    input_image = external_code.to_base64_nparray(unit_request.input_image) if unit_request.input_image else None
    mask = external_code.to_base64_nparray(unit_request.mask) if unit_request.mask else None
    if input_image is not None and mask is not None:
        input_image = (input_image, mask)

    # Resolve the legacy alias locally rather than writing it back into
    # ``unit_request``: callers (e.g. the legacy route handler) echo the unit
    # requests back to the client, which should see the submitted values.
    guidance_end = unit_request.guidance if unit_request.guidance < 1.0 else unit_request.guidance_end

    return external_code.ControlNetUnit(
        module=unit_request.module,
        model=unit_request.model,
        weight=unit_request.weight,
        image=input_image,
        resize_mode=unit_request.resize_mode,
        low_vram=unit_request.lowvram,
        processor_res=unit_request.processor_res,
        threshold_a=unit_request.threshold_a,
        threshold_b=unit_request.threshold_b,
        guidance_start=unit_request.guidance_start,
        guidance_end=guidance_end,
        guess_mode=unit_request.guessmode,
        rgbbgr_mode=unit_request.rgbbgr_mode,
    )
|
|
|
|
def warn_deprecated_route(is_img2img):
    """Write a two-line deprecation notice for a legacy ControlNet route to stderr.

    ``is_img2img`` selects which route (img2img or txt2img) the message names.
    """
    route = 'img2img' if is_img2img else 'txt2img'
    prefix = '[ControlNet] warning: '
    notices = (
        f"using deprecated '/controlnet/{route}' route",
        f"consider using the '/sdapi/v1/{route}' route with the 'alwayson_scripts' json property instead",
    )
    for notice in notices:
        print(prefix + notice, file=sys.stderr)
|
|
|
|
def controlnet_api(_: gr.Blocks, app: FastAPI):
    """Register the ControlNet REST endpoints on the FastAPI ``app``.

    Routes:
        GET  /controlnet/version      -> {"version": ...}
        GET  /controlnet/model_list   -> {"model_list": [...]} (refreshed per call)
        GET  /controlnet/module_list  -> {"module_list": [...]}
        POST /controlnet/detect       -> run a preprocessor over input images

    The first (``gr.Blocks``) argument is unused.
    """

    @app.get("/controlnet/version")
    async def version():
        return {"version": external_code.get_api_version()}

    @app.get("/controlnet/model_list")
    async def model_list():
        up_to_date_model_list = external_code.get_models(update=True)
        print(up_to_date_model_list)
        return {"model_list": up_to_date_model_list}

    @app.get("/controlnet/module_list")
    async def module_list():
        _module_list = external_code.get_modules()
        print(_module_list)
        return {"module_list": _module_list}

    @app.post("/controlnet/detect")
    async def detect(
        controlnet_module: str = Body("none", title='Controlnet Module'),
        controlnet_input_images: List[str] = Body([], title='Controlnet Input Images'),
        controlnet_processor_res: int = Body(512, title='Controlnet Processor Resolution'),
        controlnet_threshold_a: float = Body(64, title='Controlnet Threshold a'),
        controlnet_threshold_b: float = Body(64, title='Controlnet Threshold b')
    ):
        """Run ``controlnet_module`` on each input image and return encoded results.

        Returns ``{"images": [...], "info": ...}``, where ``info`` is one of
        "Module not available", "No image selected" or "Success".
        """
        res = controlnet_processor_res
        thr_a = controlnet_threshold_a
        thr_b = controlnet_threshold_b

        # Module name -> callable(img) returning the detected map.
        # "none" is accepted but produces no output.
        preprocessors = {
            "none": None,
            "canny": lambda img: canny(img, res, thr_a, thr_b)[0],
            "depth": lambda img: midas(img, res, np.pi * 2.0)[0],
            "depth_leres": lambda img: leres(img, res, np.pi * 2.0, thr_a, thr_b)[0],
            "fake_scribble": lambda img: fake_scribble(img, res)[0],
            "hed": lambda img: hed(img, res)[0],
            "mlsd": lambda img: mlsd(img, res, thr_a, thr_b)[0],
            "normal_map": lambda img: midas_normal(img, res, np.pi * 2.0, thr_a)[0],
            "openpose": lambda img: openpose(img, res, False)[0],
            "segmentation": lambda img: uniformer(img, res)[0],
            "binary": lambda img: binary(img, res, thr_a)[0],
            "color": lambda img: color(img, res)[0],
        }
        # Modules with an unload hook, called once after processing
        # (presumably frees the preprocessor's model -- see scripts.processor).
        unloaders = {
            "hed": unload_hed,
            "mlsd": unload_mlsd,
            "depth": unload_midas,
            "normal_map": unload_midas,
            "depth_leres": unload_leres,
            "openpose": unload_openpose,
            "segmentation": unload_uniformer,
        }

        if controlnet_module not in preprocessors:
            return {"images": [], "info": "Module not available"}
        if not controlnet_input_images:
            return {"images": [], "info": "No image selected"}

        print(f"Detecting {len(controlnet_input_images)} images with the {controlnet_module} module.")

        run = preprocessors[controlnet_module]
        unload = unloaders.get(controlnet_module)
        results = []
        try:
            for input_image in controlnet_input_images:
                img = external_code.to_base64_nparray(input_image)
                if run is not None:
                    results.append(run(img))
        finally:
            # Unload even when decoding or detection raised, so a failed
            # request does not leave the preprocessor loaded.
            if unload is not None:
                unload()

        return {"images": [encode_to_base64(result) for result in results], "info": "Success"}
|
|
|
|
# Register the API routes with the WebUI on app start. This is best-effort:
# a failure is reported on stderr instead of aborting the module import.
try:
    import modules.script_callbacks as script_callbacks

    script_callbacks.on_app_started(controlnet_api)
except Exception as e:  # not bare: let KeyboardInterrupt/SystemExit propagate
    print(f"[ControlNet] warning: failed to register API routes: {e}", file=sys.stderr)
|