More robust ipadapter detection (#2936)

* More robust ipadapter detection

* Remove dep
pull/2938/head
Chenlei Hu 2024-05-27 11:26:28 -04:00 committed by GitHub
parent ba3984a022
commit 97c8598c8b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 13 deletions

View File

@ -228,21 +228,12 @@ class ControlNetUnit(BaseModel):
animatediff_batch: bool = False
batch_modifiers: list = []
batch_image_files: list = []
batch_keyframe_idx: Optional[str|list] = None
batch_keyframe_idx: Optional[str | list] = None
@property
def accepts_multiple_inputs(self) -> bool:
"""This unit can accept multiple input images."""
return self.module in (
"ip-adapter-auto",
"ip-adapter_clip_sdxl",
"ip-adapter_clip_sdxl_plus_vith",
"ip-adapter_clip_sd15",
"ip-adapter_face_id",
"ip-adapter_face_id_plus",
"ip-adapter_pulid",
"instant_id_face_embedding",
)
return self.is_ipadapter
@property
def is_animate_diff_batch(self) -> bool:
@ -263,6 +254,13 @@ class ControlNetUnit(BaseModel):
def is_inpaint(self) -> bool:
return "inpaint" in self.module
@property
def is_ipadapter(self) -> bool:
p = ControlNetUnit.cls_get_preprocessor(self.module)
if p is None:
return False
return "IP-Adapter" in p.tags
def get_actual_preprocessors(self) -> List[Any]:
p = ControlNetUnit.cls_get_preprocessor(self.module)
# Map "ip-adapter-auto" to actual preprocessor.

View File

@ -948,7 +948,7 @@ class Script(scripts.Script, metaclass=(
elif unit.is_animate_diff_batch or control_model_type in [ControlModelType.SparseCtrl]:
cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None)
def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx):
if unit.accepts_multiple_inputs:
if unit.is_ipadapter:
ip_adapter_image_emb_cond = []
model_net.ipadapter.image_proj_model.to(torch.float32) # noqa
for c in cc:
@ -975,7 +975,7 @@ class Script(scripts.Script, metaclass=(
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):
logger.info(f"\t{frame_idx}: {frame_path}")
c = SparseCtrl.create_cond_mask(cn_ad_keyframe_idx, c, p.batch_size).cpu()
elif unit.accepts_multiple_inputs:
elif unit.is_ipadapter:
# ip-adapter should do prompt travel
logger.info("IP-Adapter: control prompts will be traveled in the following way:")
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):