diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py
index 07fd6d8..93592bb 100644
--- a/ip_adapter/attention_processor.py
+++ b/ip_adapter/attention_processor.py
@@ -23,6 +23,8 @@ class AttnProcessor(nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        *args,
+        **kwargs,
     ):
         residual = hidden_states
 
@@ -109,6 +111,8 @@ class IPAttnProcessor(nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        *args,
+        **kwargs,
     ):
         residual = hidden_states
 
@@ -205,6 +209,8 @@ class AttnProcessor2_0(torch.nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        *args,
+        **kwargs,
     ):
         residual = hidden_states
 
@@ -308,6 +314,8 @@ class IPAttnProcessor2_0(torch.nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        *args,
+        **kwargs,
     ):
         residual = hidden_states
 
@@ -413,7 +421,7 @@ class CNAttnProcessor:
     def __init__(self, num_tokens=4):
         self.num_tokens = num_tokens
 
-    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs,):
         residual = hidden_states
 
         if attn.spatial_norm is not None:
@@ -487,6 +495,8 @@ class CNAttnProcessor2_0:
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        *args,
+        **kwargs,
     ):
         residual = hidden_states
 