diff --git a/ip_adapter/ip_adapter.py b/ip_adapter/ip_adapter.py index 3978c6c..7254251 100644 --- a/ip_adapter/ip_adapter.py +++ b/ip_adapter/ip_adapter.py @@ -79,7 +79,7 @@ class IPAdapter: scale=1.0,num_tokens= self.num_tokens).to(self.device, dtype=torch.float16) unet.set_attn_processor(attn_procs) if hasattr(self.pipe, "controlnet"): - self.pipe.controlnet.set_attn_processor(CNAttnProcessor()) + self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens= self.num_tokens)) def load_ip_adapter(self): state_dict = torch.load(self.ip_ckpt, map_location="cpu")