for the new diffuser version

pull/338/head
Hu Ye 2024-04-16 18:45:47 +08:00 committed by GitHub
parent 685b550ed2
commit 8e1a977274
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 11 additions and 1 deletions

View File

@ -23,6 +23,8 @@ class AttnProcessor(nn.Module):
encoder_hidden_states=None,
attention_mask=None,
temb=None,
*args,
**kwargs,
):
residual = hidden_states
@ -109,6 +111,8 @@ class IPAttnProcessor(nn.Module):
encoder_hidden_states=None,
attention_mask=None,
temb=None,
*args,
**kwargs,
):
residual = hidden_states
@ -205,6 +209,8 @@ class AttnProcessor2_0(torch.nn.Module):
encoder_hidden_states=None,
attention_mask=None,
temb=None,
*args,
**kwargs,
):
residual = hidden_states
@ -308,6 +314,8 @@ class IPAttnProcessor2_0(torch.nn.Module):
encoder_hidden_states=None,
attention_mask=None,
temb=None,
*args,
**kwargs,
):
residual = hidden_states
@ -413,7 +421,7 @@ class CNAttnProcessor:
def __init__(self, num_tokens=4):
self.num_tokens = num_tokens
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs,):
residual = hidden_states
if attn.spatial_norm is not None:
@ -487,6 +495,8 @@ class CNAttnProcessor2_0:
encoder_hidden_states=None,
attention_mask=None,
temb=None,
*args,
**kwargs,
):
residual = hidden_states