From bbcae309d1c0921c212bfb9719169db8826b795f Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Mon, 15 Apr 2024 15:08:51 -0400 Subject: [PATCH] :sparkles: Save/Use CLIP output in API (#2590) * Add API support * Add assertion * Add test * nit --- internal_controlnet/external_code.py | 27 ++- scripts/api.py | 46 +++-- scripts/controlnet.py | 189 +++++++++++-------- scripts/controlnet_ui/controlnet_ui_group.py | 1 + tests/web_api/detect_test.py | 6 - tests/web_api/ipadapter_clip_api.py | 42 +++++ 6 files changed, 206 insertions(+), 105 deletions(-) create mode 100644 tests/web_api/ipadapter_clip_api.py diff --git a/internal_controlnet/external_code.py b/internal_controlnet/external_code.py index f7fa09c..35b1b1d 100644 --- a/internal_controlnet/external_code.py +++ b/internal_controlnet/external_code.py @@ -1,9 +1,14 @@ +import base64 +import io from dataclasses import dataclass from enum import Enum from copy import copy from typing import List, Any, Optional, Union, Tuple, Dict +import torch import numpy as np + from modules import scripts, processing, shared +from modules.safe import unsafe_torch_load from scripts import global_state from scripts.processor import preprocessor_sliders_config, model_free_preprocessors from scripts.logging import logger @@ -193,6 +198,12 @@ class ControlNetUnit: # even advanced_weighting is set. advanced_weighting: Optional[List[float]] = None + # The tensor input for ipadapter. When this field is set in the API, + # the base64 string will be interpreted by torch.load to reconstruct ipadapter + # preprocessor output. NOTE(review): decoding goes through unsafe_torch_load (presumably unrestricted unpickling -- confirm), so accept this field from trusted callers only. + # Currently the option is only accessible in API calls. + ipadapter_input: Optional[List[Any]] = None + def __eq__(self, other): if not isinstance(other, ControlNetUnit): return False @@ -215,8 +226,10 @@ class ControlNetUnit: return [ "image", "enabled", - # Note: "advanced_weighting" is excluded as it is an API-only field. + # API-only fields. "advanced_weighting", + "ipadapter_input", + # End of API-only fields. 
# Note: "inpaint_crop_image" is img2img inpaint only flag, which does not # provide much information when restoring the unit. "inpaint_crop_input_image", @@ -375,6 +388,7 @@ def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNe if isinstance(unit, dict): unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()} + # Handle mask mask = None if 'mask' in unit: mask = unit['mask'] @@ -388,6 +402,17 @@ def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNe unit['image'] = {'image': unit['image'], 'mask': mask} if mask is not None else unit['image'] if unit[ 'image'] else None + # Parse ipadapter_input + if "ipadapter_input" in unit: + def decode_base64(b: str) -> torch.Tensor: + decoded_bytes = base64.b64decode(b) + return unsafe_torch_load(io.BytesIO(decoded_bytes)) + + if isinstance(unit["ipadapter_input"], str): + unit["ipadapter_input"] = [unit["ipadapter_input"]] + + unit["ipadapter_input"] = [decode_base64(b) for b in unit["ipadapter_input"]] + if 'guess_mode' in unit: logger.warning('Guess Mode is removed since 1.1.136. 
Please use Control Mode instead.') diff --git a/scripts/api.py b/scripts/api.py index 58838d8..f387469 100644 --- a/scripts/api.py +++ b/scripts/api.py @@ -1,5 +1,7 @@ from typing import List, Optional - +import base64 +import io +import torch import numpy as np from fastapi import FastAPI, Body from fastapi.exceptions import HTTPException @@ -36,6 +38,14 @@ def encode_np_to_base64(image): return api.encode_pil_to_base64(pil) +def encode_tensor_to_base64(obj: torch.Tensor) -> str: + """Serialize the tensor data to base64 string.""" + buffer = io.BytesIO() + torch.save(obj, buffer) + buffer.seek(0) # Rewind the buffer + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def controlnet_api(_: gr.Blocks, app: FastAPI): @app.get("/controlnet/version") async def version(): @@ -108,6 +118,9 @@ def controlnet_api(_: gr.Blocks, app: FastAPI): if controlnet_module not in cached_cn_preprocessors: raise HTTPException(status_code=422, detail="Module not available") + if controlnet_module in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"): + raise HTTPException(status_code=422, detail="Module not supported") + if len(controlnet_input_images) == 0: raise HTTPException(status_code=422, detail="No image selected") @@ -139,30 +152,31 @@ def controlnet_api(_: gr.Blocks, app: FastAPI): self.value = json_dict json_acceptor = JsonAcceptor() - - results.append( - processor_module( - img, - res=unit.processor_res, - thr_a=unit.threshold_a, - thr_b=unit.threshold_b, - json_pose_callback=json_acceptor.accept, - low_vram=low_vram, - )[0] + detected_map, is_image = processor_module( + img, + res=unit.processor_res, + thr_a=unit.threshold_a, + thr_b=unit.threshold_b, + json_pose_callback=json_acceptor.accept, + low_vram=low_vram, ) + results.append(detected_map) if "openpose" in controlnet_module: assert json_acceptor.value is not None poses.append(json_acceptor.value) global_state.cn_preprocessor_unloadable.get(controlnet_module, lambda: None)() - results64 = 
list(map(encode_to_base64, results)) - res = {"images": results64, "info": "Success"} - if poses: - res["poses"] = poses - + res = {"info": "Success"} + if is_image: + res["images"] = [encode_to_base64(r) for r in results] + if poses: + res["poses"] = poses + else: + res["tensor"] = [encode_tensor_to_base64(r) for r in results] return res + class Person(BaseModel): pose_keypoints_2d: List[float] hand_right_keypoints_2d: Optional[List[float]] diff --git a/scripts/controlnet.py b/scripts/controlnet.py index 7ccf236..3c7b910 100644 --- a/scripts/controlnet.py +++ b/scripts/controlnet.py @@ -215,6 +215,103 @@ def get_pytorch_control(x: np.ndarray) -> torch.Tensor: return y +def get_control( + p: StableDiffusionProcessing, + unit: external_code.ControlNetUnit, + idx: int, + control_model_type: ControlModelType, + preprocessor, +): + """Get input for a ControlNet unit.""" + if unit.is_animate_diff_batch: + unit = add_animate_diff_batch_input(p, unit) + + high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False) + h, w, hr_y, hr_x = Script.get_target_dimensions(p) + input_image, resize_mode = Script.choose_input_image(p, unit, idx) + if isinstance(input_image, list): + assert unit.accepts_multiple_inputs() or unit.is_animate_diff_batch + input_images = input_image + else: # Following operations are only for single input image. 
+ input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode) + input_image = np.ascontiguousarray(input_image.copy()).copy() # safe numpy + if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT: + # inpaint_only+lama is special and requires the outpaint fix + _, input_image = Script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x) + input_images = [input_image] + + if unit.pixel_perfect: + unit.processor_res = external_code.pixel_perfect_resolution( + input_images[0], + target_H=h, + target_W=w, + resize_mode=resize_mode, + ) + # Preprocessor result may depend on numpy random operations, use the + # random seed in `StableDiffusionProcessing` to make the + # preprocessor result reproducible. + # Currently the following preprocessors use numpy random: + # - shuffle + seed = set_numpy_seed(p) + logger.debug(f"Use numpy seed {seed}.") + logger.info(f"Using preprocessor: {unit.module}") + logger.info(f'preprocessor resolution = {unit.processor_res}') + + detected_maps = [] + def store_detected_map(detected_map, module: str) -> None: + if unit.save_detected_map: + detected_maps.append((detected_map, module)) + + def preprocess_input_image(input_image: np.ndarray): + """ Preprocess single input image. 
""" + detected_map, is_image = preprocessor( + input_image, + res=unit.processor_res, + thr_a=unit.threshold_a, + thr_b=unit.threshold_b, + low_vram=( + ("clip" in unit.module or unit.module == "ip-adapter_face_id_plus") and + shared.opts.data.get("controlnet_clip_detector_on_cpu", False) + ), + ) + if high_res_fix: + if is_image: + hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x) + store_detected_map(hr_detected_map, unit.module) + else: + hr_control = detected_map + else: + hr_control = None + + if is_image: + control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w) + store_detected_map(detected_map, unit.module) + else: + control = detected_map + store_detected_map(input_image, unit.module) + + if control_model_type == ControlModelType.T2I_StyleAdapter: + control = control['last_hidden_state'] + + if control_model_type == ControlModelType.ReVision: + control = control['image_embeds'] + + if is_image and unit.is_animate_diff_batch: # AnimateDiff save VRAM + control = control.cpu() + if hr_control is not None: + hr_control = hr_control.cpu() + + return control, hr_control + + def optional_tqdm(iterable, use_tqdm=unit.is_animate_diff_batch): + from tqdm import tqdm + return tqdm(iterable) if use_tqdm else iterable + + controls, hr_controls = list(zip(*[preprocess_input_image(img) for img in optional_tqdm(input_images)])) + assert len(controls) == len(hr_controls) + return controls, hr_controls, detected_maps + + class Script(scripts.Script, metaclass=( utils.TimeMeta if logger.level == logging.DEBUG else type)): @@ -852,7 +949,6 @@ class Script(scripts.Script, metaclass=( self.latest_model_hash = p.sd_model.sd_model_hash high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False) - h, w, hr_y, hr_x = Script.get_target_dimensions(p) for idx, unit in enumerate(self.enabled_units): unit.bound_check_params() @@ -887,92 +983,21 @@ class 
Script(scripts.Script, metaclass=( bind_control_lora(unet, control_lora) p.controlnet_control_loras.append(control_lora) - if unit.is_animate_diff_batch: - unit = add_animate_diff_batch_input(p, unit) - input_image, resize_mode = Script.choose_input_image(p, unit, idx) - cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None) - if isinstance(input_image, list): - assert unit.accepts_multiple_inputs() or unit.is_animate_diff_batch - input_images = input_image - else: # Following operations are only for single input image. - input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode) - input_image = np.ascontiguousarray(input_image.copy()).copy() # safe numpy - if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT: - # inpaint_only+lama is special and required outpaint fix - _, input_image = Script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x) - input_images = [input_image] + if unit.ipadapter_input is not None: + # Use ipadapter_input from API call. + assert control_model_type == ControlModelType.IPAdapter + controls = unit.ipadapter_input + hr_controls = unit.ipadapter_input + else: + controls, hr_controls, additional_maps = get_control( + p, unit, idx, control_model_type, self.preprocessor[unit.module]) + detected_maps.extend(additional_maps) - if unit.pixel_perfect: - unit.processor_res = external_code.pixel_perfect_resolution( - input_images[0], - target_H=h, - target_W=w, - resize_mode=resize_mode, - ) - # Preprocessor result may depend on numpy random operations, use the - # random seed in `StableDiffusionProcessing` to make the - # preprocessor result reproducable. 
- # Currently following preprocessors use numpy random: - # - shuffle - seed = set_numpy_seed(p) - logger.debug(f"Use numpy seed {seed}.") - logger.info(f"Using preprocessor: {unit.module}") - logger.info(f'preprocessor resolution = {unit.processor_res}') - - def store_detected_map(detected_map, module: str) -> None: - if unit.save_detected_map: - detected_maps.append((detected_map, module)) - - def preprocess_input_image(input_image: np.ndarray): - """ Preprocess single input image. """ - detected_map, is_image = self.preprocessor[unit.module]( - input_image, - res=unit.processor_res, - thr_a=unit.threshold_a, - thr_b=unit.threshold_b, - low_vram=( - ("clip" in unit.module or unit.module == "ip-adapter_face_id_plus") and - shared.opts.data.get("controlnet_clip_detector_on_cpu", False) - ), - ) - if high_res_fix: - if is_image: - hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x) - store_detected_map(hr_detected_map, unit.module) - else: - hr_control = detected_map - else: - hr_control = None - - if is_image: - control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w) - store_detected_map(detected_map, unit.module) - else: - control = detected_map - store_detected_map(input_image, unit.module) - - if control_model_type == ControlModelType.T2I_StyleAdapter: - control = control['last_hidden_state'] - - if control_model_type == ControlModelType.ReVision: - control = control['image_embeds'] - - if is_image and unit.is_animate_diff_batch: # AnimateDiff save VRAM - control = control.cpu() - if hr_control is not None: - hr_control = hr_control.cpu() - - return control, hr_control - - def optional_tqdm(iterable, use_tqdm=unit.is_animate_diff_batch): - from tqdm import tqdm - return tqdm(iterable) if use_tqdm else iterable - - controls, hr_controls = list(zip(*[preprocess_input_image(img) for img in optional_tqdm(input_images)])) if len(controls) == len(hr_controls) == 1 and 
control_model_type not in [ControlModelType.SparseCtrl]: control = controls[0] hr_control = hr_controls[0] elif unit.is_animate_diff_batch or control_model_type in [ControlModelType.SparseCtrl]: + cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None) def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx): if unit.accepts_multiple_inputs(): ip_adapter_image_emb_cond = [] diff --git a/scripts/controlnet_ui/controlnet_ui_group.py b/scripts/controlnet_ui/controlnet_ui_group.py index dbd7016..b06c2d0 100644 --- a/scripts/controlnet_ui/controlnet_ui_group.py +++ b/scripts/controlnet_ui/controlnet_ui_group.py @@ -306,6 +306,7 @@ class ControlNetUiGroup(object): # API-only fields self.advanced_weighting = gr.State(None) + self.ipadapter_input = gr.State(None) ControlNetUiGroup.all_ui_groups.append(self) diff --git a/tests/web_api/detect_test.py b/tests/web_api/detect_test.py index 15f3a29..86bf295 100644 --- a/tests/web_api/detect_test.py +++ b/tests/web_api/detect_test.py @@ -42,12 +42,6 @@ def detect_template(payload, output_name: str, status: int = 200): # FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[clip_vision] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589ADD1210> # FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[revision_clipvision] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AFB00E0> # FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[revision_ignore_prompt] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AF3C9A0> -# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_clip_sd15] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AF5B740> -# FAILED 
extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_clip_sdxl_plus_vith] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AF3D0D0> -# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_clip_sdxl] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x00000158FF7753F0> -# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_face_id] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589B0414E0> -# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[ip-adapter_face_id_plus] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AEE3100> -# FAILED extensions/sd-webui-controlnet/tests/web_api/detect_test.py::test_detect_all_modules[instant_id_face_embedding] - PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001589AFF6CF0> # TODO: file issue on these failures. 
diff --git a/tests/web_api/ipadapter_clip_api.py b/tests/web_api/ipadapter_clip_api.py new file mode 100644 index 0000000..27251e3 --- /dev/null +++ b/tests/web_api/ipadapter_clip_api.py @@ -0,0 +1,42 @@ +import requests + +from .template import ( + APITestTemplate, + realistic_girl_face_img, + disable_in_cq, + get_model, +) + + +def detect_template(payload, status: int = 200): + url = APITestTemplate.BASE_URL + "controlnet/detect" + resp = requests.post(url, json=payload) + assert resp.status_code == status + if status != 200: + return + + resp_json = resp.json() + assert "tensor" in resp_json + assert len(resp_json["tensor"]) == len(payload["controlnet_input_images"]) + return resp_json + + +@disable_in_cq +def test_ipadapter_clip_api(): + """Use previously saved CLIP output in ipadapter run.""" + resp = detect_template( + dict( + controlnet_input_images=[realistic_girl_face_img], + controlnet_module="ip-adapter_clip_h", + ) + ) + ipadapter_input = resp["tensor"] + APITestTemplate( + "test_ipadapter_clip_api", + "txt2img", + payload_overrides={}, + unit_overrides={ + "ipadapter_input": ipadapter_input, + "model": get_model("ip-adapter_sd15"), + }, + ).exec()