From 8e1a9772744c8dfdba90b82fd9dc76694e9b2c0c Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 16 Apr 2024 18:45:47 +0800 Subject: [PATCH] for the new diffuser version --- ip_adapter/attention_processor.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py index 07fd6d8..93592bb 100644 --- a/ip_adapter/attention_processor.py +++ b/ip_adapter/attention_processor.py @@ -23,6 +23,8 @@ class AttnProcessor(nn.Module): encoder_hidden_states=None, attention_mask=None, temb=None, + *args, + **kwargs, ): residual = hidden_states @@ -109,6 +111,8 @@ class IPAttnProcessor(nn.Module): encoder_hidden_states=None, attention_mask=None, temb=None, + *args, + **kwargs, ): residual = hidden_states @@ -205,6 +209,8 @@ class AttnProcessor2_0(torch.nn.Module): encoder_hidden_states=None, attention_mask=None, temb=None, + *args, + **kwargs, ): residual = hidden_states @@ -308,6 +314,8 @@ class IPAttnProcessor2_0(torch.nn.Module): encoder_hidden_states=None, attention_mask=None, temb=None, + *args, + **kwargs, ): residual = hidden_states @@ -413,7 +421,7 @@ class CNAttnProcessor: def __init__(self, num_tokens=4): self.num_tokens = num_tokens - def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs,): residual = hidden_states if attn.spatial_norm is not None: @@ -487,6 +495,8 @@ class CNAttnProcessor2_0: encoder_hidden_states=None, attention_mask=None, temb=None, + *args, + **kwargs, ): residual = hidden_states