get back encoder_hidden_states only use text in controlnet inference

pull/39/head
陈书睿 2023-09-06 12:19:48 +08:00
parent 23f44cc242
commit 64a09c5996
2 changed files with 7 additions and 1 deletions

View File

@ -440,6 +440,9 @@ class CNAttnProcessor:
end_pos = encoder_hidden_states.shape[1] - self.num_tokens end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
else: # get back text
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
key = attn.to_k(encoder_hidden_states) key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) value = attn.to_v(encoder_hidden_states)
@ -519,6 +522,9 @@ class CNAttnProcessor2_0:
end_pos = encoder_hidden_states.shape[1] - self.num_tokens end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
else: # get back text
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
key = attn.to_k(encoder_hidden_states) key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) value = attn.to_v(encoder_hidden_states)

View File

@ -356,7 +356,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.9" "version": "3.9.16"
} }
}, },
"nbformat": 4, "nbformat": 4,