commit 11a9046821
@@ -23,6 +23,8 @@ class AttnProcessor(nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        *args,
+        **kwargs,
     ):
         residual = hidden_states
 
@@ -109,6 +111,8 @@ class IPAttnProcessor(nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        *args,
+        **kwargs,
     ):
         residual = hidden_states
 
@@ -205,6 +209,8 @@ class AttnProcessor2_0(torch.nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        *args,
+        **kwargs,
     ):
         residual = hidden_states
 
@@ -308,6 +314,8 @@ class IPAttnProcessor2_0(torch.nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        *args,
+        **kwargs,
     ):
         residual = hidden_states
 
@@ -413,7 +421,7 @@ class CNAttnProcessor:
     def __init__(self, num_tokens=4):
         self.num_tokens = num_tokens
 
-    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs,):
         residual = hidden_states
 
         if attn.spatial_norm is not None:
@@ -487,6 +495,8 @@ class CNAttnProcessor2_0:
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        *args,
+        **kwargs,
     ):
         residual = hidden_states
 
@@ -36,6 +36,8 @@ class LoRAAttnProcessor(nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        *args,
+        **kwargs,
     ):
         residual = hidden_states
 
@@ -130,6 +132,8 @@ class LoRAIPAttnProcessor(nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        *args,
+        **kwargs,
     ):
         residual = hidden_states
 
@@ -236,6 +240,8 @@ class LoRAAttnProcessor2_0(nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        *args,
+        **kwargs,
     ):
         residual = hidden_states
 
@@ -335,7 +341,7 @@ class LoRAIPAttnProcessor2_0(nn.Module):
         self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
 
     def __call__(
-        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
+        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None, *args, **kwargs,
     ):
         residual = hidden_states
 
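Every hunk above makes the same change: each processor's __call__ now ends in *args, **kwargs, presumably so the processors tolerate callers that pass extra arguments they do not consume. A minimal sketch of the pattern follows; the class names and the 'scale' keyword are hypothetical stand-ins for whatever a caller might supply, not code from this repository.

class StrictProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        return hidden_states  # stand-in for the real attention computation

class TolerantProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs):
        return hidden_states  # extra positional/keyword args are absorbed and ignored

TolerantProcessor()(None, "h", scale=1.0)  # fine: 'scale' lands in **kwargs
try:
    StrictProcessor()(None, "h", scale=1.0)
except TypeError as err:
    print(err)  # __call__() got an unexpected keyword argument 'scale'

Absorbing unknown arguments this way keeps the processors drop-in compatible when the calling signature grows, at the cost of silently ignoring parameters rather than failing fast.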