parent
807a883ed6
commit
a5f9d59717
|
|
@ -1,7 +1,9 @@
|
|||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from einops import rearrange
|
||||
from modules import devices
|
||||
from annotator.annotator_path import models_path
|
||||
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
|
||||
|
|
@ -126,11 +128,20 @@ class ClipVisionDetector:
|
|||
if self.model is not None:
|
||||
self.model.to('meta')
|
||||
|
||||
def __call__(self, input_image):
|
||||
def __call__(self, input_image: np.ndarray):
|
||||
assert isinstance(input_image, np.ndarray)
|
||||
with torch.no_grad():
|
||||
mask = None
|
||||
input_image = cv2.resize(input_image, (224, 224), interpolation=cv2.INTER_AREA)
|
||||
if input_image.shape[2] == 4: # Has alpha channel.
|
||||
mask = 255 - input_image[:, :, 3:4] # Invert mask
|
||||
input_image = input_image[:, :, :3]
|
||||
feat = self.processor(images=input_image, return_tensors="pt")
|
||||
feat['pixel_values'] = feat['pixel_values'].to(self.device)
|
||||
# Apply CLIP mask.
|
||||
if mask is not None:
|
||||
mask_tensor = torch.from_numpy(mask).to(self.device).float() / 255.0
|
||||
feat['pixel_values'] *= rearrange(mask_tensor, "h w c -> 1 c h w")
|
||||
result = self.model(**feat, output_hidden_states=True)
|
||||
result['hidden_states'] = [v.to(self.device) for v in result['hidden_states']]
|
||||
result = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in result.items()}
|
||||
|
|
|
|||
|
|
@ -226,6 +226,18 @@ class ControlNetUnit:
|
|||
def is_animate_diff_batch(self) -> bool:
|
||||
return getattr(self, "animatediff_batch", False)
|
||||
|
||||
@property
def uses_clip(self) -> bool:
    """Whether this unit uses a CLIP preprocessor.

    Returns:
        True when ``self.module`` contains "ip-adapter" but not "faceid",
        or when it is exactly one of "clip_vision", "revision_clipvision"
        or "revision_ignore_prompt". False otherwise, including when the
        module is unset (``None`` or an empty string).
    """
    # An unset module cannot name a CLIP preprocessor; normalising it here
    # keeps the substring checks below from raising TypeError on None.
    module = self.module or ""
    if "ip-adapter" in module and "faceid" not in module:
        return True
    return module in ("clip_vision", "revision_clipvision", "revision_ignore_prompt")
|
||||
|
||||
@property
def is_inpaint(self) -> bool:
    """Whether this unit's module name contains "inpaint".

    Returns:
        True when the substring "inpaint" occurs in ``self.module``.
        False otherwise, including when the module is unset (``None`` or
        an empty string).
    """
    module = self.module
    # bool(module) short-circuits on None/"" so the membership test never
    # runs against a non-string.
    return bool(module) and "inpaint" in module
|
||||
|
||||
|
||||
def to_base64_nparray(encoding: str):
|
||||
"""
|
||||
|
|
@ -349,9 +361,9 @@ def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNe
|
|||
mask = unit['mask']
|
||||
del unit['mask']
|
||||
|
||||
if "image_mask" in unit:
|
||||
mask = unit["image_mask"]
|
||||
del unit["image_mask"]
|
||||
if "mask_image" in unit:
|
||||
mask = unit["mask_image"]
|
||||
del unit["mask_image"]
|
||||
|
||||
if 'image' in unit and not isinstance(unit['image'], dict):
|
||||
unit['image'] = {'image': unit['image'], 'mask': mask} if mask is not None else unit['image'] if unit[
|
||||
|
|
|
|||
|
|
@ -647,7 +647,7 @@ class Script(scripts.Script, metaclass=(
|
|||
for idx in range(len(input_image)):
|
||||
while len(image[idx]['mask'].shape) < 3:
|
||||
image[idx]['mask'] = image[idx]['mask'][..., np.newaxis]
|
||||
if 'inpaint' in unit.module:
|
||||
if unit.is_inpaint or unit.uses_clip:
|
||||
color = HWC3(image[idx]["image"])
|
||||
alpha = image[idx]['mask'][:, :, 0:1]
|
||||
input_image[idx] = np.concatenate([color, alpha], axis=2)
|
||||
|
|
@ -656,8 +656,8 @@ class Script(scripts.Script, metaclass=(
|
|||
if 'mask' in image and image['mask'] is not None:
|
||||
while len(image['mask'].shape) < 3:
|
||||
image['mask'] = image['mask'][..., np.newaxis]
|
||||
if 'inpaint' in unit.module:
|
||||
logger.info("using inpaint as input")
|
||||
if unit.is_inpaint or unit.uses_clip:
|
||||
logger.info("using mask")
|
||||
color = HWC3(image['image'])
|
||||
alpha = image['mask'][:, :, 0:1]
|
||||
input_image = np.concatenate([color, alpha], axis=2)
|
||||
|
|
@ -681,7 +681,7 @@ class Script(scripts.Script, metaclass=(
|
|||
resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)
|
||||
|
||||
a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None)
|
||||
if 'inpaint' in unit.module:
|
||||
if unit.is_inpaint:
|
||||
if a1111_mask_image is not None:
|
||||
a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
|
||||
assert a1111_mask.ndim == 2
|
||||
|
|
|
|||
|
|
@ -383,7 +383,6 @@ clip_encoder = {
|
|||
|
||||
|
||||
def clip(img, res=512, config='clip_vitl', low_vram=False, **kwargs):
|
||||
img = HWC3(img)
|
||||
global clip_encoder
|
||||
if clip_encoder[config] is None:
|
||||
from annotator.clipvision import ClipVisionDetector
|
||||
|
|
|
|||
|
|
@ -0,0 +1,55 @@
|
|||
from .template import (
|
||||
APITestTemplate,
|
||||
girl_img,
|
||||
mask_img,
|
||||
disable_in_cq,
|
||||
get_model,
|
||||
)
|
||||
|
||||
|
||||
@disable_in_cq
def test_clip_mask_txt2img_control():
    """Control group: txt2img with the IP-Adapter unit and no mask."""
    # The unit carries only the reference image; no mask key is supplied.
    unit = {
        "module": "ip-adapter-auto",
        "model": get_model("ip-adapter_sd15"),
        "image": girl_img,
    }
    template = APITestTemplate(
        "test_clip_mask_txt2img_control",
        "txt2img",
        payload_overrides={},
        unit_overrides=unit,
    )
    assert template.exec()
|
||||
|
||||
|
||||
@disable_in_cq
def test_clip_mask_txt2img_experiment():
    """Experiment group: txt2img with the IP-Adapter unit plus a mask."""
    # Same unit as the control test, with "mask_image" added alongside the
    # reference image.
    unit = {
        "module": "ip-adapter-auto",
        "model": get_model("ip-adapter_sd15"),
        "image": girl_img,
        "mask_image": mask_img,
    }
    template = APITestTemplate(
        "test_clip_mask_txt2img_experiment",
        "txt2img",
        payload_overrides={},
        unit_overrides=unit,
    )
    assert template.exec()
|
||||
|
||||
|
||||
@disable_in_cq
def test_clip_mask_img2img():
    """img2img inpaint request; the CLIP mask should not take effect here."""
    # The image and mask travel in the img2img payload itself; the unit
    # supplies no image of its own.
    payload = {
        "init_images": [girl_img],
        "mask": mask_img,
    }
    unit = {
        "module": "ip-adapter-auto",
        "model": get_model("ip-adapter_sd15"),
    }
    template = APITestTemplate(
        "test_clip_mask_img2img",
        "img2img",
        payload_overrides=payload,
        unit_overrides=unit,
    )
    assert template.exec()
|
||||
Loading…
Reference in New Issue