get back encoder_hidden_states only use text in controlnet inference
parent
23f44cc242
commit
64a09c5996
|
|
@ -440,6 +440,9 @@ class CNAttnProcessor:
|
||||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||||
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
|
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
|
||||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||||
|
else: # get back text
|
||||||
|
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||||
|
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
|
||||||
|
|
||||||
key = attn.to_k(encoder_hidden_states)
|
key = attn.to_k(encoder_hidden_states)
|
||||||
value = attn.to_v(encoder_hidden_states)
|
value = attn.to_v(encoder_hidden_states)
|
||||||
|
|
@ -519,6 +522,9 @@ class CNAttnProcessor2_0:
|
||||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||||
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
|
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
|
||||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||||
|
else: # get back text
|
||||||
|
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||||
|
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
|
||||||
|
|
||||||
key = attn.to_k(encoder_hidden_states)
|
key = attn.to_k(encoder_hidden_states)
|
||||||
value = attn.to_v(encoder_hidden_states)
|
value = attn.to_v(encoder_hidden_states)
|
||||||
|
|
|
||||||
|
|
@ -356,7 +356,7 @@
|
||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.9"
|
"version": "3.9.16"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue