Add clip mask support (#2721)

* Add clip mask support

* nit

* Add tests
Chenlei Hu 2024-03-31 16:13:51 +00:00 committed by GitHub
parent 807a883ed6
commit a5f9d59717
5 changed files with 86 additions and 9 deletions

View File

@@ -1,7 +1,9 @@
 import os
+import cv2
 import torch
 import numpy as np
+from einops import rearrange
 
 from modules import devices
 from annotator.annotator_path import models_path
 from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
@@ -126,11 +128,20 @@ class ClipVisionDetector:
         if self.model is not None:
             self.model.to('meta')
 
-    def __call__(self, input_image):
+    def __call__(self, input_image: np.ndarray):
+        assert isinstance(input_image, np.ndarray)
         with torch.no_grad():
+            mask = None
             input_image = cv2.resize(input_image, (224, 224), interpolation=cv2.INTER_AREA)
+            if input_image.shape[2] == 4:  # Has alpha channel.
+                mask = 255 - input_image[:, :, 3:4]  # Invert mask
+                input_image = input_image[:, :, :3]
             feat = self.processor(images=input_image, return_tensors="pt")
             feat['pixel_values'] = feat['pixel_values'].to(self.device)
+            # Apply CLIP mask.
+            if mask is not None:
+                mask_tensor = torch.from_numpy(mask).to(self.device).float() / 255.0
+                feat['pixel_values'] *= rearrange(mask_tensor, "h w c -> 1 c h w")
             result = self.model(**feat, output_hidden_states=True)
             result['hidden_states'] = [v.to(self.device) for v in result['hidden_states']]
             result = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in result.items()}
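For reference, the masking applied above reduces to one broadcasted multiply on the processed pixel tensor: the alpha channel carries the user mask, it is inverted so painted pixels map to 0, normalized to [0, 1], and multiplied into the NCHW tensor before the CLIP vision encoder runs. A minimal standalone sketch, assuming a 224x224 RGBA numpy input (illustrative only, not the actual ClipVisionDetector class):

import numpy as np
import torch
from einops import rearrange

# Hypothetical 224x224 RGBA input (H, W, 4), uint8; the alpha channel carries the
# user-drawn mask (255 = painted / masked out, 0 = keep).
rgba = np.random.randint(0, 256, (224, 224, 4), dtype=np.uint8)

# Split off the mask and invert it so painted regions end up multiplied by 0.
mask = 255 - rgba[:, :, 3:4]   # (224, 224, 1)
rgb = rgba[:, :, :3]           # (224, 224, 3), what the CLIP image processor actually sees

# Stand-in for the CLIPImageProcessor output: a (1, 3, 224, 224) float tensor.
pixel_values = torch.rand(1, 3, 224, 224)

# Normalize the mask to [0, 1], move it to NCHW, and zero out the masked pixels.
mask_tensor = torch.from_numpy(mask).float() / 255.0
pixel_values = pixel_values * rearrange(mask_tensor, "h w c -> 1 c h w")
assert pixel_values.shape == (1, 3, 224, 224)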

View File

@@ -226,6 +226,18 @@ class ControlNetUnit:
     def is_animate_diff_batch(self) -> bool:
         return getattr(self, "animatediff_batch", False)
 
+    @property
+    def uses_clip(self) -> bool:
+        """Whether this unit uses clip preprocessor."""
+        return any((
+            ("ip-adapter" in self.module and "faceid" not in self.module),
+            self.module in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"),
+        ))
+
+    @property
+    def is_inpaint(self) -> bool:
+        return "inpaint" in self.module
+
 
 def to_base64_nparray(encoding: str):
     """
@@ -349,9 +361,9 @@ def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNe
         mask = unit['mask']
         del unit['mask']
 
-    if "image_mask" in unit:
-        mask = unit["image_mask"]
-        del unit["image_mask"]
+    if "mask_image" in unit:
+        mask = unit["mask_image"]
+        del unit["mask_image"]
 
     if 'image' in unit and not isinstance(unit['image'], dict):
         unit['image'] = {'image': unit['image'], 'mask': mask} if mask is not None else unit['image'] if unit[
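For API callers, the practical consequence of the hunk above is that a unit's mask is read from either the 'mask' or the 'mask_image' key of the raw dict, with later keys taking precedence, and the key is removed once consumed. A rough standalone sketch of that normalization (pop_unit_mask is a hypothetical helper, not the extension's actual to_processing_unit):

from typing import Any, Dict, Optional

def pop_unit_mask(unit: Dict[str, Any]) -> Optional[Any]:
    """Mirror of the alias handling above: later keys win, consumed keys are removed."""
    mask = None
    for key in ("mask", "mask_image"):
        if key in unit:
            mask = unit.pop(key)
    return mask

raw_unit = {"module": "ip-adapter-auto", "image": "<base64 image>", "mask_image": "<base64 mask>"}
mask = pop_unit_mask(raw_unit)
assert mask == "<base64 mask>" and "mask_image" not in raw_unit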

View File

@@ -647,7 +647,7 @@ class Script(scripts.Script, metaclass=(
             for idx in range(len(input_image)):
                 while len(image[idx]['mask'].shape) < 3:
                     image[idx]['mask'] = image[idx]['mask'][..., np.newaxis]
-                if 'inpaint' in unit.module:
+                if unit.is_inpaint or unit.uses_clip:
                     color = HWC3(image[idx]["image"])
                     alpha = image[idx]['mask'][:, :, 0:1]
                     input_image[idx] = np.concatenate([color, alpha], axis=2)
@@ -656,8 +656,8 @@ class Script(scripts.Script, metaclass=(
             if 'mask' in image and image['mask'] is not None:
                 while len(image['mask'].shape) < 3:
                     image['mask'] = image['mask'][..., np.newaxis]
-                if 'inpaint' in unit.module:
-                    logger.info("using inpaint as input")
+                if unit.is_inpaint or unit.uses_clip:
+                    logger.info("using mask")
                     color = HWC3(image['image'])
                     alpha = image['mask'][:, :, 0:1]
                     input_image = np.concatenate([color, alpha], axis=2)
@@ -681,7 +681,7 @@ class Script(scripts.Script, metaclass=(
         resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)
 
         a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None)
-        if 'inpaint' in unit.module:
+        if unit.is_inpaint:
            if a1111_mask_image is not None:
                a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
                assert a1111_mask.ndim == 2
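The net effect of the hunks at 647 and 656 is that, for CLIP-based units as well as inpaint ones, the user mask is packed into a fourth channel of the unit image before it reaches the preprocessor (where the ClipVisionDetector change above splits it back out). A minimal numpy sketch of that packing, with illustrative shapes and a stand-in for HWC3:

import numpy as np

color = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for HWC3(image['image'])
mask = np.zeros((512, 512), dtype=np.uint8)      # user-drawn mask, 255 where painted
mask[128:384, 128:384] = 255

# Mirror of the packing above: promote the mask to HxWx1, keep its first channel,
# and append it to the RGB image as an alpha channel.
while mask.ndim < 3:
    mask = mask[..., np.newaxis]
alpha = mask[:, :, 0:1]
rgba = np.concatenate([color, alpha], axis=2)
assert rgba.shape == (512, 512, 4)               # later split apart by the CLIP annotator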

View File

@@ -383,7 +383,6 @@ clip_encoder = {
 
 
 def clip(img, res=512, config='clip_vitl', low_vram=False, **kwargs):
-    img = HWC3(img)
     global clip_encoder
     if clip_encoder[config] is None:
         from annotator.clipvision import ClipVisionDetector

View File

@@ -0,0 +1,55 @@
+from .template import (
+    APITestTemplate,
+    girl_img,
+    mask_img,
+    disable_in_cq,
+    get_model,
+)
+
+
+@disable_in_cq
+def test_clip_mask_txt2img_control():
+    """No mask control group."""
+    assert APITestTemplate(
+        "test_clip_mask_txt2img_control",
+        "txt2img",
+        payload_overrides={},
+        unit_overrides={
+            "module": "ip-adapter-auto",
+            "model": get_model("ip-adapter_sd15"),
+            "image": girl_img,
+        },
+    ).exec()
+
+
+@disable_in_cq
+def test_clip_mask_txt2img_experiment():
+    """With mask experiment group."""
+    assert APITestTemplate(
+        "test_clip_mask_txt2img_experiment",
+        "txt2img",
+        payload_overrides={},
+        unit_overrides={
+            "module": "ip-adapter-auto",
+            "model": get_model("ip-adapter_sd15"),
+            "image": girl_img,
+            "mask_image": mask_img,
+        },
+    ).exec()
+
+
+@disable_in_cq
+def test_clip_mask_img2img():
+    """CLIP mask should not work in img2img inpaint."""
+    assert APITestTemplate(
+        "test_clip_mask_img2img",
+        "img2img",
+        payload_overrides={
+            "init_images": [girl_img],
+            "mask": mask_img,
+        },
+        unit_overrides={
+            "module": "ip-adapter-auto",
+            "model": get_model("ip-adapter_sd15"),
+        },
+    ).exec()
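For manual verification outside the test harness, the experiment-group request can be reproduced against a local webui instance. A hedged sketch follows; the endpoint path, port, alwayson_scripts layout, file paths, and model name are assumptions based on the standard A1111 + ControlNet API and should be adjusted to your installation:

import base64
import requests

def b64(path: str) -> str:
    # Read an image file and return its base64 encoding, as the API expects.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode()

payload = {
    "prompt": "a girl",
    "steps": 20,
    "alwayson_scripts": {
        "ControlNet": {
            "args": [{
                "module": "ip-adapter-auto",
                "model": "ip-adapter_sd15 [hash]",  # placeholder; use the name your install reports
                "image": b64("girl.png"),           # hypothetical file paths
                "mask_image": b64("mask.png"),      # the CLIP mask input exercised by this commit
            }]
        }
    },
}
resp = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
resp.raise_for_status()
print(len(resp.json()["images"]))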