Save/Use CLIP output in API (#2590)

* Add API support

* Add assertion

* Add test

* nit
pull/2758/head
Chenlei Hu 2024-04-15 15:08:51 -04:00 committed by GitHub
parent eca5b0acfd
commit bbcae309d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 206 additions and 105 deletions

View File

@ -1,9 +1,14 @@
import base64
import io
from dataclasses import dataclass
from enum import Enum
from copy import copy
from typing import List, Any, Optional, Union, Tuple, Dict
import torch
import numpy as np
from modules import scripts, processing, shared
from modules.safe import unsafe_torch_load
from scripts import global_state
from scripts.processor import preprocessor_sliders_config, model_free_preprocessors
from scripts.logging import logger
@ -193,6 +198,12 @@ class ControlNetUnit:
# even advanced_weighting is set.
advanced_weighting: Optional[List[float]] = None
# The tensor input for ipadapter. When this field is set in the API,
# the base64 string will be interpreted by torch.load to reconstruct the ipadapter
# preprocessor output.
# Currently the option is only accessible in API calls.
ipadapter_input: Optional[List[Any]] = None
def __eq__(self, other):
if not isinstance(other, ControlNetUnit):
return False
@ -215,8 +226,10 @@ class ControlNetUnit:
return [
"image",
"enabled",
# Note: "advanced_weighting" is excluded as it is an API-only field.
# API-only fields.
"advanced_weighting",
"ipadapter_input",
# End of API-only fields.
# Note: "inpaint_crop_image" is img2img inpaint only flag, which does not
# provide much information when restoring the unit.
"inpaint_crop_input_image",
@ -375,6 +388,7 @@ def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNe
if isinstance(unit, dict):
unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()}
# Handle mask
mask = None
if 'mask' in unit:
mask = unit['mask']
@ -388,6 +402,17 @@ def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNe
unit['image'] = {'image': unit['image'], 'mask': mask} if mask is not None else unit['image'] if unit[
'image'] else None
# Parse ipadapter_input
if "ipadapter_input" in unit:
def decode_base64(b: str) -> torch.Tensor:
decoded_bytes = base64.b64decode(b)
return unsafe_torch_load(io.BytesIO(decoded_bytes))
if isinstance(unit["ipadapter_input"], str):
unit["ipadapter_input"] = [unit["ipadapter_input"]]
unit["ipadapter_input"] = [decode_base64(b) for b in unit["ipadapter_input"]]
if 'guess_mode' in unit:
logger.warning('Guess Mode is removed since 1.1.136. Please use Control Mode instead.')

View File

@ -1,5 +1,7 @@
from typing import List, Optional
import base64
import io
import torch
import numpy as np
from fastapi import FastAPI, Body
from fastapi.exceptions import HTTPException
@ -36,6 +38,14 @@ def encode_np_to_base64(image):
return api.encode_pil_to_base64(pil)
def encode_tensor_to_base64(obj: torch.Tensor) -> str:
    """Serialize a tensor with torch.save and return it as a base64 string.

    The inverse operation is base64-decoding the string and feeding the
    bytes to torch.load.
    """
    with io.BytesIO() as buffer:
        torch.save(obj, buffer)
        raw_bytes = buffer.getvalue()
    return base64.b64encode(raw_bytes).decode("utf-8")
def controlnet_api(_: gr.Blocks, app: FastAPI):
@app.get("/controlnet/version")
async def version():
@ -108,6 +118,9 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
if controlnet_module not in cached_cn_preprocessors:
raise HTTPException(status_code=422, detail="Module not available")
if controlnet_module in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"):
raise HTTPException(status_code=422, detail="Module not supported")
if len(controlnet_input_images) == 0:
raise HTTPException(status_code=422, detail="No image selected")
@ -139,30 +152,31 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
self.value = json_dict
json_acceptor = JsonAcceptor()
results.append(
processor_module(
img,
res=unit.processor_res,
thr_a=unit.threshold_a,
thr_b=unit.threshold_b,
json_pose_callback=json_acceptor.accept,
low_vram=low_vram,
)[0]
detected_map, is_image = processor_module(
img,
res=unit.processor_res,
thr_a=unit.threshold_a,
thr_b=unit.threshold_b,
json_pose_callback=json_acceptor.accept,
low_vram=low_vram,
)
results.append(detected_map)
if "openpose" in controlnet_module:
assert json_acceptor.value is not None
poses.append(json_acceptor.value)
global_state.cn_preprocessor_unloadable.get(controlnet_module, lambda: None)()
results64 = list(map(encode_to_base64, results))
res = {"images": results64, "info": "Success"}
if poses:
res["poses"] = poses
res = {"info": "Success"}
if is_image:
res["images"] = [encode_to_base64(r) for r in results]
if poses:
res["poses"] = poses
else:
res["tensor"] = [encode_tensor_to_base64(r) for r in results]
return res
class Person(BaseModel):
pose_keypoints_2d: List[float]
hand_right_keypoints_2d: Optional[List[float]]

View File

@ -215,6 +215,103 @@ def get_pytorch_control(x: np.ndarray) -> torch.Tensor:
return y
def get_control(
    p: StableDiffusionProcessing,
    unit: external_code.ControlNetUnit,
    idx: int,
    control_model_type: ControlModelType,
    preprocessor,
):
    """Get input for a ControlNet unit.

    Chooses the unit's input image(s), runs the unit's preprocessor on each
    of them (optionally at hires-fix resolution as well), and collects the
    detected maps for later display.

    Args:
        p: The current processing job; used for target dimensions, seed and
            the hires-fix flag.
        unit: The ControlNet unit whose input is being prepared. NOTE: this
            function may mutate `unit.processor_res` when pixel-perfect mode
            is enabled.
        idx: Index of the unit; forwarded to `Script.choose_input_image`.
        control_model_type: Determines post-processing of the preprocessor
            output (e.g. T2I_StyleAdapter / ReVision unwrap a dict entry).
        preprocessor: Callable performing the actual detection. Returns a
            `(detected_map, is_image)` pair; when `is_image` is falsy the
            map is used as the control directly (e.g. an embedding tensor).

    Returns:
        Tuple `(controls, hr_controls, detected_maps)` where `controls` and
        `hr_controls` are equal-length tuples (one entry per input image;
        `hr_controls` entries are None when hires fix is off) and
        `detected_maps` is a list of `(map, module_name)` pairs.
    """
    if unit.is_animate_diff_batch:
        unit = add_animate_diff_batch_input(p, unit)

    high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)
    h, w, hr_y, hr_x = Script.get_target_dimensions(p)
    input_image, resize_mode = Script.choose_input_image(p, unit, idx)
    if isinstance(input_image, list):
        # Multiple input images are only valid for units that accept them
        # (e.g. ip-adapter) or AnimateDiff batches.
        assert unit.accepts_multiple_inputs() or unit.is_animate_diff_batch
        input_images = input_image
    else:  # Following operations are only for single input image.
        input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode)
        input_image = np.ascontiguousarray(input_image.copy()).copy()  # safe numpy
        if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT:
            # inpaint_only+lama is special and requires an outpaint fix
            _, input_image = Script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x)
        input_images = [input_image]

    if unit.pixel_perfect:
        # Mutates the unit: derive the preprocessor resolution from the
        # first input image and the target render size.
        unit.processor_res = external_code.pixel_perfect_resolution(
            input_images[0],
            target_H=h,
            target_W=w,
            resize_mode=resize_mode,
        )
    # Preprocessor result may depend on numpy random operations, use the
    # random seed in `StableDiffusionProcessing` to make the
    # preprocessor result reproducible.
    # Currently following preprocessors use numpy random:
    # - shuffle
    seed = set_numpy_seed(p)
    logger.debug(f"Use numpy seed {seed}.")
    logger.info(f"Using preprocessor: {unit.module}")
    logger.info(f'preprocessor resolution = {unit.processor_res}')

    detected_maps = []

    def store_detected_map(detected_map, module: str) -> None:
        # Only keep maps for display when the unit asks for it.
        if unit.save_detected_map:
            detected_maps.append((detected_map, module))

    def preprocess_input_image(input_image: np.ndarray):
        """ Preprocess single input image. """
        detected_map, is_image = preprocessor(
            input_image,
            res=unit.processor_res,
            thr_a=unit.threshold_a,
            thr_b=unit.threshold_b,
            low_vram=(
                # CLIP-based detectors can optionally run on CPU to save VRAM.
                ("clip" in unit.module or unit.module == "ip-adapter_face_id_plus") and
                shared.opts.data.get("controlnet_clip_detector_on_cpu", False)
            ),
        )
        if high_res_fix:
            if is_image:
                hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x)
                store_detected_map(hr_detected_map, unit.module)
            else:
                # Non-image output (e.g. an embedding) is resolution-agnostic;
                # reuse it as the hires control.
                hr_control = detected_map
        else:
            hr_control = None

        if is_image:
            control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w)
            store_detected_map(detected_map, unit.module)
        else:
            control = detected_map
            # No visualizable map; store the raw input image instead.
            store_detected_map(input_image, unit.module)

        # Some control model types only consume a single entry of the
        # preprocessor's dict output.
        if control_model_type == ControlModelType.T2I_StyleAdapter:
            control = control['last_hidden_state']
        if control_model_type == ControlModelType.ReVision:
            control = control['image_embeds']
        if is_image and unit.is_animate_diff_batch:  # AnimateDiff save VRAM
            control = control.cpu()
            if hr_control is not None:
                hr_control = hr_control.cpu()
        return control, hr_control

    def optional_tqdm(iterable, use_tqdm=unit.is_animate_diff_batch):
        # Only show a progress bar for (potentially long) AnimateDiff batches.
        from tqdm import tqdm
        return tqdm(iterable) if use_tqdm else iterable

    controls, hr_controls = list(zip(*[preprocess_input_image(img) for img in optional_tqdm(input_images)]))
    assert len(controls) == len(hr_controls)
    return controls, hr_controls, detected_maps
class Script(scripts.Script, metaclass=(
utils.TimeMeta if logger.level == logging.DEBUG else type)):
@ -852,7 +949,6 @@ class Script(scripts.Script, metaclass=(
self.latest_model_hash = p.sd_model.sd_model_hash
high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)
h, w, hr_y, hr_x = Script.get_target_dimensions(p)
for idx, unit in enumerate(self.enabled_units):
unit.bound_check_params()
@ -887,92 +983,21 @@ class Script(scripts.Script, metaclass=(
bind_control_lora(unet, control_lora)
p.controlnet_control_loras.append(control_lora)
if unit.is_animate_diff_batch:
unit = add_animate_diff_batch_input(p, unit)
input_image, resize_mode = Script.choose_input_image(p, unit, idx)
cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None)
if isinstance(input_image, list):
assert unit.accepts_multiple_inputs() or unit.is_animate_diff_batch
input_images = input_image
else: # Following operations are only for single input image.
input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode)
input_image = np.ascontiguousarray(input_image.copy()).copy() # safe numpy
if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT:
# inpaint_only+lama is special and requires an outpaint fix
_, input_image = Script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x)
input_images = [input_image]
if unit.ipadapter_input is not None:
# Use ipadapter_input from API call.
assert control_model_type == ControlModelType.IPAdapter
controls = unit.ipadapter_input
hr_controls = unit.ipadapter_input
else:
controls, hr_controls, additional_maps = get_control(
p, unit, idx, control_model_type, self.preprocessor[unit.module])
detected_maps.extend(additional_maps)
if unit.pixel_perfect:
unit.processor_res = external_code.pixel_perfect_resolution(
input_images[0],
target_H=h,
target_W=w,
resize_mode=resize_mode,
)
# Preprocessor result may depend on numpy random operations, use the
# random seed in `StableDiffusionProcessing` to make the
# preprocessor result reproducible.
# Currently following preprocessors use numpy random:
# - shuffle
seed = set_numpy_seed(p)
logger.debug(f"Use numpy seed {seed}.")
logger.info(f"Using preprocessor: {unit.module}")
logger.info(f'preprocessor resolution = {unit.processor_res}')
def store_detected_map(detected_map, module: str) -> None:
if unit.save_detected_map:
detected_maps.append((detected_map, module))
def preprocess_input_image(input_image: np.ndarray):
""" Preprocess single input image. """
detected_map, is_image = self.preprocessor[unit.module](
input_image,
res=unit.processor_res,
thr_a=unit.threshold_a,
thr_b=unit.threshold_b,
low_vram=(
("clip" in unit.module or unit.module == "ip-adapter_face_id_plus") and
shared.opts.data.get("controlnet_clip_detector_on_cpu", False)
),
)
if high_res_fix:
if is_image:
hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x)
store_detected_map(hr_detected_map, unit.module)
else:
hr_control = detected_map
else:
hr_control = None
if is_image:
control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w)
store_detected_map(detected_map, unit.module)
else:
control = detected_map
store_detected_map(input_image, unit.module)
if control_model_type == ControlModelType.T2I_StyleAdapter:
control = control['last_hidden_state']
if control_model_type == ControlModelType.ReVision:
control = control['image_embeds']
if is_image and unit.is_animate_diff_batch: # AnimateDiff save VRAM
control = control.cpu()
if hr_control is not None:
hr_control = hr_control.cpu()
return control, hr_control
def optional_tqdm(iterable, use_tqdm=unit.is_animate_diff_batch):
from tqdm import tqdm
return tqdm(iterable) if use_tqdm else iterable
controls, hr_controls = list(zip(*[preprocess_input_image(img) for img in optional_tqdm(input_images)]))
if len(controls) == len(hr_controls) == 1 and control_model_type not in [ControlModelType.SparseCtrl]:
control = controls[0]
hr_control = hr_controls[0]
elif unit.is_animate_diff_batch or control_model_type in [ControlModelType.SparseCtrl]:
cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None)
def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx):
if unit.accepts_multiple_inputs():
ip_adapter_image_emb_cond = []

View File

@ -306,6 +306,7 @@ class ControlNetUiGroup(object):
# API-only fields
self.advanced_weighting = gr.State(None)
self.ipadapter_input = gr.State(None)
ControlNetUiGroup.all_ui_groups.append(self)

View File

@ -42,12 +42,6 @@ def detect_template(payload, output_name: str, status: int = 200):
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[clip_vision] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589ADD1210>
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[revision_clipvision] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AFB00E0>
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[revision_ignore_prompt] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AF3C9A0>
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_clip_sd15] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AF5B740>
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_clip_sdxl_plus_vith] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AF3D0D0>
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_clip_sdxl] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x00000158FF7753F0>
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_face_id] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589B0414E0>
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_face_id_plus] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AEE3100>
# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[instant_id_face_embedding] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AFF6CF0>
# TODO: file issue on these failures.

View File

@ -0,0 +1,42 @@
import requests
from .template import (
APITestTemplate,
realistic_girl_face_img,
disable_in_cq,
get_model,
)
def detect_template(payload, status: int = 200):
    """POST `payload` to the /controlnet/detect endpoint and check the response.

    Asserts the HTTP status matches `status`. On a 200 response, also
    verifies that a "tensor" entry is present with one element per input
    image, and returns the parsed JSON body; otherwise returns None.
    """
    url = APITestTemplate.BASE_URL + "controlnet/detect"
    response = requests.post(url, json=payload)
    assert response.status_code == status
    if status == 200:
        body = response.json()
        assert "tensor" in body
        assert len(body["tensor"]) == len(payload["controlnet_input_images"])
        return body
    return None
@disable_in_cq
def test_ipadapter_clip_api():
    """Use previously saved CLIP output in ipadapter run."""
    # First run the preprocessor via the detect endpoint to obtain the
    # serialized CLIP tensors.
    detect_resp = detect_template(
        {
            "controlnet_input_images": [realistic_girl_face_img],
            "controlnet_module": "ip-adapter_clip_h",
        }
    )
    clip_tensors = detect_resp["tensor"]
    # Then feed those tensors back through the API as ipadapter_input.
    template = APITestTemplate(
        "test_ipadapter_clip_api",
        "txt2img",
        payload_overrides={},
        unit_overrides={
            "ipadapter_input": clip_tensors,
            "model": get_model("ip-adapter_sd15"),
        },
    )
    template.exec()