diff --git a/annotator/clipvision/__init__.py b/annotator/clipvision/__init__.py
index 6865501..b066438 100644
--- a/annotator/clipvision/__init__.py
+++ b/annotator/clipvision/__init__.py
@@ -1,7 +1,9 @@
 import os
 import cv2
 import torch
+import numpy as np
+from einops import rearrange
 from modules import devices
 from annotator.annotator_path import models_path
 from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor

@@ -126,11 +128,20 @@ class ClipVisionDetector:
         if self.model is not None:
             self.model.to('meta')

-    def __call__(self, input_image):
+    def __call__(self, input_image: np.ndarray):
+        assert isinstance(input_image, np.ndarray)
         with torch.no_grad():
+            mask = None
             input_image = cv2.resize(input_image, (224, 224), interpolation=cv2.INTER_AREA)
+            if input_image.shape[2] == 4:  # Has alpha channel.
+                mask = 255 - input_image[:, :, 3:4]  # Invert mask.
+                input_image = input_image[:, :, :3]
             feat = self.processor(images=input_image, return_tensors="pt")
             feat['pixel_values'] = feat['pixel_values'].to(self.device)
+            # Apply CLIP mask.
+            if mask is not None:
+                mask_tensor = torch.from_numpy(mask).to(self.device).float() / 255.0
+                feat['pixel_values'] *= rearrange(mask_tensor, "h w c -> 1 c h w")
             result = self.model(**feat, output_hidden_states=True)
             result['hidden_states'] = [v.to(self.device) for v in result['hidden_states']]
             result = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in result.items()}
diff --git a/internal_controlnet/external_code.py b/internal_controlnet/external_code.py
index aad9d09..21a9306 100644
--- a/internal_controlnet/external_code.py
+++ b/internal_controlnet/external_code.py
@@ -226,6 +226,18 @@ class ControlNetUnit:
     def is_animate_diff_batch(self) -> bool:
         return getattr(self, "animatediff_batch", False)

+    @property
+    def uses_clip(self) -> bool:
+        """Whether this unit uses clip preprocessor."""
+        return any((
+            ("ip-adapter" in self.module and "faceid" not in self.module),
+            self.module in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"),
+        ))
+
+    @property
+    def is_inpaint(self) -> bool:
+        return "inpaint" in self.module
+

 def to_base64_nparray(encoding: str):
     """
@@ -349,9 +361,9 @@ def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNe
         mask = unit['mask']
         del unit['mask']

-    if "image_mask" in unit:
-        mask = unit["image_mask"]
-        del unit["image_mask"]
+    if "mask_image" in unit:
+        mask = unit["mask_image"]
+        del unit["mask_image"]
     if 'image' in unit and not isinstance(unit['image'], dict):
         unit['image'] = {'image': unit['image'], 'mask': mask} if mask is not None else unit['image']
     if unit[
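(Aside: the masking added to `ClipVisionDetector.__call__` above reduces to a handful of tensor ops. Below is a minimal standalone sketch of that logic; the helper name `apply_clip_mask` is hypothetical and not part of this diff.)

```python
import numpy as np
import torch
from einops import rearrange


def apply_clip_mask(rgba: np.ndarray, pixel_values: torch.Tensor) -> torch.Tensor:
    """Hypothetical standalone version of the masking logic in the diff above.

    rgba:         HxWx4 uint8 array, already resized to 224x224.
    pixel_values: 1x3x224x224 tensor as produced by CLIPImageProcessor.
    """
    assert rgba.ndim == 3 and rgba.shape[2] == 4
    # Invert the alpha channel: regions where the mask is 255 end up
    # multiplied by zero, i.e. excluded from the CLIP embedding.
    mask = 255 - rgba[:, :, 3:4]                           # HxWx1, uint8
    mask_tensor = torch.from_numpy(mask).float() / 255.0   # HxWx1 in [0, 1]
    # "h w c -> 1 c h w" yields a 1x1x224x224 tensor that broadcasts
    # across the three RGB channels of pixel_values.
    return pixel_values * rearrange(mask_tensor, "h w c -> 1 c h w")
```

Note that the multiplication happens after the CLIP processor's normalization, so masked pixels land at zero in normalized space rather than at literal black.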
diff --git a/scripts/controlnet.py b/scripts/controlnet.py
index e38655d..2290857 100644
--- a/scripts/controlnet.py
+++ b/scripts/controlnet.py
@@ -647,7 +647,7 @@ class Script(scripts.Script, metaclass=(
                 for idx in range(len(input_image)):
                     while len(image[idx]['mask'].shape) < 3:
                         image[idx]['mask'] = image[idx]['mask'][..., np.newaxis]
-                    if 'inpaint' in unit.module:
+                    if unit.is_inpaint or unit.uses_clip:
                         color = HWC3(image[idx]["image"])
                         alpha = image[idx]['mask'][:, :, 0:1]
                         input_image[idx] = np.concatenate([color, alpha], axis=2)
@@ -656,8 +656,8 @@ class Script(scripts.Script, metaclass=(
             if 'mask' in image and image['mask'] is not None:
                 while len(image['mask'].shape) < 3:
                     image['mask'] = image['mask'][..., np.newaxis]
-                if 'inpaint' in unit.module:
-                    logger.info("using inpaint as input")
+                if unit.is_inpaint or unit.uses_clip:
+                    logger.info("using mask")
                     color = HWC3(image['image'])
                     alpha = image['mask'][:, :, 0:1]
                     input_image = np.concatenate([color, alpha], axis=2)
@@ -681,7 +681,7 @@ class Script(scripts.Script, metaclass=(
             resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)
             a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None)

-            if 'inpaint' in unit.module:
+            if unit.is_inpaint:
                 if a1111_mask_image is not None:
                     a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
                     assert a1111_mask.ndim == 2
diff --git a/scripts/processor.py b/scripts/processor.py
index 068504f..d551759 100644
--- a/scripts/processor.py
+++ b/scripts/processor.py
@@ -383,7 +383,6 @@ clip_encoder = {


 def clip(img, res=512, config='clip_vitl', low_vram=False, **kwargs):
-    img = HWC3(img)
     global clip_encoder
     if clip_encoder[config] is None:
         from annotator.clipvision import ClipVisionDetector
diff --git a/tests/web_api/clip_mask_test.py b/tests/web_api/clip_mask_test.py
new file mode 100644
index 0000000..7393a08
--- /dev/null
+++ b/tests/web_api/clip_mask_test.py
@@ -0,0 +1,55 @@
+from .template import (
+    APITestTemplate,
+    girl_img,
+    mask_img,
+    disable_in_cq,
+    get_model,
+)
+
+
+@disable_in_cq
+def test_clip_mask_txt2img_control():
+    """No mask control group."""
+    assert APITestTemplate(
+        "test_clip_mask_txt2img_control",
+        "txt2img",
+        payload_overrides={},
+        unit_overrides={
+            "module": "ip-adapter-auto",
+            "model": get_model("ip-adapter_sd15"),
+            "image": girl_img,
+        },
+    ).exec()
+
+
+@disable_in_cq
+def test_clip_mask_txt2img_experiment():
+    """With mask experiment group."""
+    assert APITestTemplate(
+        "test_clip_mask_txt2img_experiment",
+        "txt2img",
+        payload_overrides={},
+        unit_overrides={
+            "module": "ip-adapter-auto",
+            "model": get_model("ip-adapter_sd15"),
+            "image": girl_img,
+            "mask_image": mask_img,
+        },
+    ).exec()
+
+
+@disable_in_cq
+def test_clip_mask_img2img():
+    """CLIP mask should not work in img2img inpaint."""
+    assert APITestTemplate(
+        "test_clip_mask_img2img",
+        "img2img",
+        payload_overrides={
+            "init_images": [girl_img],
+            "mask": mask_img,
+        },
+        unit_overrides={
+            "module": "ip-adapter-auto",
+            "model": get_model("ip-adapter_sd15"),
+        },
+    ).exec()
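(Aside: the experiment-group test above drives the renamed `mask_image` field through the public API via `APITestTemplate`. A raw-HTTP sketch of roughly the same request, assuming a local A1111 server on the default port; `girl_b64` and `mask_b64` are hypothetical stand-ins for the base64-encoded test images, and the model string would normally carry the hash suffix that `get_model` resolves in the tests:)

```python
import requests

girl_b64 = "..."  # hypothetical: base64-encoded reference image
mask_b64 = "..."  # hypothetical: base64-encoded mask image

payload = {
    "prompt": "1girl",
    "steps": 20,
    "alwayson_scripts": {
        "ControlNet": {
            "args": [{
                "module": "ip-adapter-auto",
                "model": "ip-adapter_sd15",  # tests resolve the full title via get_model
                "image": girl_b64,
                # Renamed key handled by to_processing_unit above; the mask is
                # merged into the image's alpha channel before preprocessing.
                "mask_image": mask_b64,
            }]
        }
    },
}
requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload).raise_for_status()
```

The control-group request is identical except that it omits `mask_image`.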