✨ Save/Use CLIP output in API (#2590)
* Add API support * Add assertion * Add test * nitpull/2758/head
parent
eca5b0acfd
commit
bbcae309d1
|
|
@ -1,9 +1,14 @@
|
|||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from copy import copy
|
||||
from typing import List, Any, Optional, Union, Tuple, Dict
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from modules import scripts, processing, shared
|
||||
from modules.safe import unsafe_torch_load
|
||||
from scripts import global_state
|
||||
from scripts.processor import preprocessor_sliders_config, model_free_preprocessors
|
||||
from scripts.logging import logger
|
||||
|
|
@ -193,6 +198,12 @@ class ControlNetUnit:
|
|||
# even advanced_weighting is set.
|
||||
advanced_weighting: Optional[List[float]] = None
|
||||
|
||||
# The tensor input for ipadapter. When this field is set in the API,
|
||||
# the base64 string will be interpreted by torch.load to reconstruct the ipadapter
|
||||
# preprocessor output.
|
||||
# Currently the option is only accessible in API calls.
|
||||
ipadapter_input: Optional[List[Any]] = None
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, ControlNetUnit):
|
||||
return False
|
||||
|
|
@ -215,8 +226,10 @@ class ControlNetUnit:
|
|||
return [
|
||||
"image",
|
||||
"enabled",
|
||||
# Note: "advanced_weighting" is excluded as it is an API-only field.
|
||||
# API-only fields.
|
||||
"advanced_weighting",
|
||||
"ipadapter_input",
|
||||
# End of API-only fields.
|
||||
# Note: "inpaint_crop_input_image" is an img2img-inpaint-only flag, which does not
|
||||
# provide much information when restoring the unit.
|
||||
"inpaint_crop_input_image",
|
||||
|
|
@ -375,6 +388,7 @@ def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNe
|
|||
if isinstance(unit, dict):
|
||||
unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()}
|
||||
|
||||
# Handle mask
|
||||
mask = None
|
||||
if 'mask' in unit:
|
||||
mask = unit['mask']
|
||||
|
|
@ -388,6 +402,17 @@ def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNe
|
|||
unit['image'] = {'image': unit['image'], 'mask': mask} if mask is not None else unit['image'] if unit[
|
||||
'image'] else None
|
||||
|
||||
# Parse ipadapter_input
|
||||
if "ipadapter_input" in unit:
|
||||
def decode_base64(b: str) -> torch.Tensor:
    """Rebuild a preprocessor output from its base64-encoded payload.

    Args:
        b: Base64 text whose decoded bytes are handed to
            ``unsafe_torch_load``.

    Returns:
        Whatever object ``unsafe_torch_load`` reconstructs from the bytes
        (expected to be the ipadapter preprocessor output).
    """
    # SECURITY(review): `b` is taken from the "ipadapter_input" entry of the
    # unit dict, i.e. from the API request body. `unsafe_torch_load` is
    # presumably the unrestricted, pickle-based torch.load -- confirm against
    # modules.safe. If so, a crafted payload can execute arbitrary code on
    # the server. Consider a restricted loader (e.g. torch.load with
    # weights_only=True) or limiting this field to trusted callers.
    decoded_bytes = base64.b64decode(b)
    return unsafe_torch_load(io.BytesIO(decoded_bytes))
|
||||
|
||||
if isinstance(unit["ipadapter_input"], str):
|
||||
unit["ipadapter_input"] = [unit["ipadapter_input"]]
|
||||
|
||||
unit["ipadapter_input"] = [decode_base64(b) for b in unit["ipadapter_input"]]
|
||||
|
||||
if 'guess_mode' in unit:
|
||||
logger.warning('Guess Mode is removed since 1.1.136. Please use Control Mode instead.')
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from typing import List, Optional
|
||||
|
||||
import base64
|
||||
import io
|
||||
import torch
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, Body
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
|
@ -36,6 +38,14 @@ def encode_np_to_base64(image):
|
|||
return api.encode_pil_to_base64(pil)
|
||||
|
||||
|
||||
def encode_tensor_to_base64(obj: torch.Tensor) -> str:
    """Serialize the tensor data to a base64 string.

    The object is written with ``torch.save`` into an in-memory buffer and
    the buffer's bytes are base64-encoded.

    Args:
        obj: The object to serialize.

    Returns:
        The base64 encoding of the serialized bytes, as ``str``.
    """
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    # getvalue() returns the entire buffer regardless of the current stream
    # position, so no rewind is needed before reading the bytes back.
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def controlnet_api(_: gr.Blocks, app: FastAPI):
|
||||
@app.get("/controlnet/version")
|
||||
async def version():
|
||||
|
|
@ -108,6 +118,9 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
|
|||
if controlnet_module not in cached_cn_preprocessors:
|
||||
raise HTTPException(status_code=422, detail="Module not available")
|
||||
|
||||
if controlnet_module in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"):
|
||||
raise HTTPException(status_code=422, detail="Module not supported")
|
||||
|
||||
if len(controlnet_input_images) == 0:
|
||||
raise HTTPException(status_code=422, detail="No image selected")
|
||||
|
||||
|
|
@ -139,30 +152,31 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
|
|||
self.value = json_dict
|
||||
|
||||
json_acceptor = JsonAcceptor()
|
||||
|
||||
results.append(
|
||||
processor_module(
|
||||
img,
|
||||
res=unit.processor_res,
|
||||
thr_a=unit.threshold_a,
|
||||
thr_b=unit.threshold_b,
|
||||
json_pose_callback=json_acceptor.accept,
|
||||
low_vram=low_vram,
|
||||
)[0]
|
||||
detected_map, is_image = processor_module(
|
||||
img,
|
||||
res=unit.processor_res,
|
||||
thr_a=unit.threshold_a,
|
||||
thr_b=unit.threshold_b,
|
||||
json_pose_callback=json_acceptor.accept,
|
||||
low_vram=low_vram,
|
||||
)
|
||||
results.append(detected_map)
|
||||
|
||||
if "openpose" in controlnet_module:
|
||||
assert json_acceptor.value is not None
|
||||
poses.append(json_acceptor.value)
|
||||
|
||||
global_state.cn_preprocessor_unloadable.get(controlnet_module, lambda: None)()
|
||||
results64 = list(map(encode_to_base64, results))
|
||||
res = {"images": results64, "info": "Success"}
|
||||
if poses:
|
||||
res["poses"] = poses
|
||||
|
||||
res = {"info": "Success"}
|
||||
if is_image:
|
||||
res["images"] = [encode_to_base64(r) for r in results]
|
||||
if poses:
|
||||
res["poses"] = poses
|
||||
else:
|
||||
res["tensor"] = [encode_tensor_to_base64(r) for r in results]
|
||||
return res
|
||||
|
||||
|
||||
class Person(BaseModel):
|
||||
pose_keypoints_2d: List[float]
|
||||
hand_right_keypoints_2d: Optional[List[float]]
|
||||
|
|
|
|||
|
|
@ -215,6 +215,103 @@ def get_pytorch_control(x: np.ndarray) -> torch.Tensor:
|
|||
return y
|
||||
|
||||
|
||||
def get_control(
    p: StableDiffusionProcessing,
    unit: external_code.ControlNetUnit,
    idx: int,
    control_model_type: ControlModelType,
    preprocessor,
):
    """Build the control inputs for a single ControlNet unit.

    The unit's input image(s) are selected and each one is passed through
    ``preprocessor``.

    Args:
        p: The processing object of the current generation.
        unit: The ControlNet unit being prepared. ``unit.processor_res`` is
            overwritten when ``unit.pixel_perfect`` is set.
        idx: Index of the unit, forwarded to ``Script.choose_input_image``.
        control_model_type: Type of the control model. For the T2I style
            adapter and ReVision types a single entry of the preprocessor
            output is used as the control.
        preprocessor: Callable that returns ``(detected_map, is_image)`` for
            one input image.

    Returns:
        ``(controls, hr_controls, detected_maps)``. ``controls`` and
        ``hr_controls`` are tuples with one entry per input image; the
        ``hr_controls`` entries are ``None`` when high-res fix is off.
        ``detected_maps`` is a list of ``(map, module)`` pairs that is only
        filled when ``unit.save_detected_map`` is set.
    """
    if unit.is_animate_diff_batch:
        unit = add_animate_diff_batch_input(p, unit)

    high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)
    h, w, hr_y, hr_x = Script.get_target_dimensions(p)
    input_image, resize_mode = Script.choose_input_image(p, unit, idx)

    if isinstance(input_image, list):
        assert unit.accepts_multiple_inputs() or unit.is_animate_diff_batch
        input_images = input_image
    else:
        # The steps below only apply to a single input image.
        single = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode)
        single = np.ascontiguousarray(single.copy()).copy()  # safe numpy
        needs_outpaint_fix = (
            unit.module == 'inpaint_only+lama'
            and resize_mode == external_code.ResizeMode.OUTER_FIT
        )
        if needs_outpaint_fix:
            # inpaint_only+lama is special and requires the outpaint fix.
            _, single = Script.detectmap_proc(single, unit.module, resize_mode, hr_y, hr_x)
        input_images = [single]

    if unit.pixel_perfect:
        unit.processor_res = external_code.pixel_perfect_resolution(
            input_images[0],
            target_H=h,
            target_W=w,
            resize_mode=resize_mode,
        )

    # Some preprocessors (currently: shuffle) rely on numpy's random state.
    # Seeding it from `StableDiffusionProcessing` keeps the preprocessor
    # result reproducible.
    seed = set_numpy_seed(p)
    logger.debug(f"Use numpy seed {seed}.")
    logger.info(f"Using preprocessor: {unit.module}")
    logger.info(f'preprocessor resolution = {unit.processor_res}')

    detected_maps = []

    def store_detected_map(detected_map, module: str) -> None:
        """Record a detected map, but only if the unit asks for it."""
        if unit.save_detected_map:
            detected_maps.append((detected_map, module))

    def preprocess_input_image(input_image: np.ndarray):
        """Preprocess one input image into ``(control, hr_control)``."""
        run_on_cpu = (
            ("clip" in unit.module or unit.module == "ip-adapter_face_id_plus") and
            shared.opts.data.get("controlnet_clip_detector_on_cpu", False)
        )
        detected_map, is_image = preprocessor(
            input_image,
            res=unit.processor_res,
            thr_a=unit.threshold_a,
            thr_b=unit.threshold_b,
            low_vram=run_on_cpu,
        )

        # The high-res control must be derived from the raw preprocessor
        # output, i.e. before `detected_map` is rebound further down.
        hr_control = None
        if high_res_fix:
            if is_image:
                hr_control, hr_detected_map = Script.detectmap_proc(
                    detected_map, unit.module, resize_mode, hr_y, hr_x)
                store_detected_map(hr_detected_map, unit.module)
            else:
                hr_control = detected_map

        if is_image:
            control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w)
            store_detected_map(detected_map, unit.module)
        else:
            # Non-image output is used as-is; the input image is stored as
            # the detected map instead.
            control = detected_map
            store_detected_map(input_image, unit.module)

        if control_model_type == ControlModelType.T2I_StyleAdapter:
            control = control['last_hidden_state']

        if control_model_type == ControlModelType.ReVision:
            control = control['image_embeds']

        if is_image and unit.is_animate_diff_batch:
            # AnimateDiff: move to CPU to save VRAM.
            control = control.cpu()
            if hr_control is not None:
                hr_control = hr_control.cpu()

        return control, hr_control

    def optional_tqdm(iterable, use_tqdm=unit.is_animate_diff_batch):
        """Wrap ``iterable`` in a tqdm progress bar when ``use_tqdm`` is set."""
        from tqdm import tqdm
        if use_tqdm:
            return tqdm(iterable)
        return iterable

    per_image = [preprocess_input_image(img) for img in optional_tqdm(input_images)]
    controls, hr_controls = zip(*per_image)
    assert len(controls) == len(hr_controls)
    return controls, hr_controls, detected_maps
|
||||
|
||||
|
||||
class Script(scripts.Script, metaclass=(
|
||||
utils.TimeMeta if logger.level == logging.DEBUG else type)):
|
||||
|
||||
|
|
@ -852,7 +949,6 @@ class Script(scripts.Script, metaclass=(
|
|||
|
||||
self.latest_model_hash = p.sd_model.sd_model_hash
|
||||
high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)
|
||||
h, w, hr_y, hr_x = Script.get_target_dimensions(p)
|
||||
|
||||
for idx, unit in enumerate(self.enabled_units):
|
||||
unit.bound_check_params()
|
||||
|
|
@ -887,92 +983,21 @@ class Script(scripts.Script, metaclass=(
|
|||
bind_control_lora(unet, control_lora)
|
||||
p.controlnet_control_loras.append(control_lora)
|
||||
|
||||
if unit.is_animate_diff_batch:
|
||||
unit = add_animate_diff_batch_input(p, unit)
|
||||
input_image, resize_mode = Script.choose_input_image(p, unit, idx)
|
||||
cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None)
|
||||
if isinstance(input_image, list):
|
||||
assert unit.accepts_multiple_inputs() or unit.is_animate_diff_batch
|
||||
input_images = input_image
|
||||
else: # Following operations are only for single input image.
|
||||
input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode)
|
||||
input_image = np.ascontiguousarray(input_image.copy()).copy() # safe numpy
|
||||
if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT:
|
||||
# inpaint_only+lama is special and required outpaint fix
|
||||
_, input_image = Script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x)
|
||||
input_images = [input_image]
|
||||
if unit.ipadapter_input is not None:
|
||||
# Use ipadapter_input from API call.
|
||||
assert control_model_type == ControlModelType.IPAdapter
|
||||
controls = unit.ipadapter_input
|
||||
hr_controls = unit.ipadapter_input
|
||||
else:
|
||||
controls, hr_controls, additional_maps = get_control(
|
||||
p, unit, idx, control_model_type, self.preprocessor[unit.module])
|
||||
detected_maps.extend(additional_maps)
|
||||
|
||||
if unit.pixel_perfect:
|
||||
unit.processor_res = external_code.pixel_perfect_resolution(
|
||||
input_images[0],
|
||||
target_H=h,
|
||||
target_W=w,
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
# Preprocessor result may depend on numpy random operations, use the
|
||||
# random seed in `StableDiffusionProcessing` to make the
|
||||
# preprocessor result reproducible.
|
||||
# Currently following preprocessors use numpy random:
|
||||
# - shuffle
|
||||
seed = set_numpy_seed(p)
|
||||
logger.debug(f"Use numpy seed {seed}.")
|
||||
logger.info(f"Using preprocessor: {unit.module}")
|
||||
logger.info(f'preprocessor resolution = {unit.processor_res}')
|
||||
|
||||
def store_detected_map(detected_map, module: str) -> None:
|
||||
if unit.save_detected_map:
|
||||
detected_maps.append((detected_map, module))
|
||||
|
||||
def preprocess_input_image(input_image: np.ndarray):
|
||||
""" Preprocess single input image. """
|
||||
detected_map, is_image = self.preprocessor[unit.module](
|
||||
input_image,
|
||||
res=unit.processor_res,
|
||||
thr_a=unit.threshold_a,
|
||||
thr_b=unit.threshold_b,
|
||||
low_vram=(
|
||||
("clip" in unit.module or unit.module == "ip-adapter_face_id_plus") and
|
||||
shared.opts.data.get("controlnet_clip_detector_on_cpu", False)
|
||||
),
|
||||
)
|
||||
if high_res_fix:
|
||||
if is_image:
|
||||
hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x)
|
||||
store_detected_map(hr_detected_map, unit.module)
|
||||
else:
|
||||
hr_control = detected_map
|
||||
else:
|
||||
hr_control = None
|
||||
|
||||
if is_image:
|
||||
control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w)
|
||||
store_detected_map(detected_map, unit.module)
|
||||
else:
|
||||
control = detected_map
|
||||
store_detected_map(input_image, unit.module)
|
||||
|
||||
if control_model_type == ControlModelType.T2I_StyleAdapter:
|
||||
control = control['last_hidden_state']
|
||||
|
||||
if control_model_type == ControlModelType.ReVision:
|
||||
control = control['image_embeds']
|
||||
|
||||
if is_image and unit.is_animate_diff_batch: # AnimateDiff save VRAM
|
||||
control = control.cpu()
|
||||
if hr_control is not None:
|
||||
hr_control = hr_control.cpu()
|
||||
|
||||
return control, hr_control
|
||||
|
||||
def optional_tqdm(iterable, use_tqdm=unit.is_animate_diff_batch):
|
||||
from tqdm import tqdm
|
||||
return tqdm(iterable) if use_tqdm else iterable
|
||||
|
||||
controls, hr_controls = list(zip(*[preprocess_input_image(img) for img in optional_tqdm(input_images)]))
|
||||
if len(controls) == len(hr_controls) == 1 and control_model_type not in [ControlModelType.SparseCtrl]:
|
||||
control = controls[0]
|
||||
hr_control = hr_controls[0]
|
||||
elif unit.is_animate_diff_batch or control_model_type in [ControlModelType.SparseCtrl]:
|
||||
cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None)
|
||||
def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx):
|
||||
if unit.accepts_multiple_inputs():
|
||||
ip_adapter_image_emb_cond = []
|
||||
|
|
|
|||
|
|
@ -306,6 +306,7 @@ class ControlNetUiGroup(object):
|
|||
|
||||
# API-only fields
|
||||
self.advanced_weighting = gr.State(None)
|
||||
self.ipadapter_input = gr.State(None)
|
||||
|
||||
ControlNetUiGroup.all_ui_groups.append(self)
|
||||
|
||||
|
|
|
|||
|
|
@ -42,12 +42,6 @@ def detect_template(payload, output_name: str, status: int = 200):
|
|||
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[clip_vision] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589ADD1210>
|
||||
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[revision_clipvision] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AFB00E0>
|
||||
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[revision_ignore_prompt] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AF3C9A0>
|
||||
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_clip_sd15] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AF5B740>
|
||||
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_clip_sdxl_plus_vith] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AF3D0D0>
|
||||
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_clip_sdxl] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x00000158FF7753F0>
|
||||
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_face_id] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589B0414E0>
|
||||
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_face_id_plus] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AEE3100>
|
||||
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[instant_id_face_embedding] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AFF6CF0>
|
||||
|
||||
|
||||
# TODO: file issue on these failures.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
import requests
|
||||
|
||||
from .template import (
|
||||
APITestTemplate,
|
||||
realistic_girl_face_img,
|
||||
disable_in_cq,
|
||||
get_model,
|
||||
)
|
||||
|
||||
|
||||
def detect_template(payload, status: int = 200):
    """POST ``payload`` to the ControlNet detect endpoint and check the reply.

    Args:
        payload: JSON body of the request. For a 200 response it must
            contain ``"controlnet_input_images"``.
        status: HTTP status code the response is expected to have.

    Returns:
        The decoded JSON response when ``status`` is 200, otherwise ``None``.
        A 200 response is asserted to carry one ``"tensor"`` entry per input
        image.
    """
    response = requests.post(
        APITestTemplate.BASE_URL + "controlnet/detect",
        json=payload,
    )
    assert response.status_code == status
    if status == 200:
        body = response.json()
        assert "tensor" in body
        assert len(body["tensor"]) == len(payload["controlnet_input_images"])
        return body
    return None
|
||||
|
||||
|
||||
@disable_in_cq
def test_ipadapter_clip_api():
    """Use previously saved CLIP output in ipadapter run."""
    # Step 1: run the preprocessor through the detect endpoint and keep the
    # serialized tensor output it returns.
    detect_payload = {
        "controlnet_input_images": [realistic_girl_face_img],
        "controlnet_module": "ip-adapter_clip_h",
    }
    detect_resp = detect_template(detect_payload)

    # Step 2: feed that output back in as `ipadapter_input` of a txt2img run.
    unit_overrides = {
        "ipadapter_input": detect_resp["tensor"],
        "model": get_model("ip-adapter_sd15"),
    }
    # NOTE(review): the return value of exec() is ignored here -- confirm
    # whether it reports success and should be asserted.
    APITestTemplate(
        "test_ipadapter_clip_api",
        "txt2img",
        payload_overrides={},
        unit_overrides=unit_overrides,
    ).exec()
|
||||
Loading…
Reference in New Issue