# mirror of https://github.com/vladmandic/automatic
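# TeaCache forward for CogVideoX. TeaCache (Timestep Embedding Aware Cache)
# estimates, from the step-to-step change in the timestep embedding, how much
# the transformer output will change; steps whose accumulated estimated change
# stays below `rel_l1_thresh` reuse cached residuals instead of running the
# transformer blocks, trading a small quality loss for fewer full forwards.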
from typing import Any, Dict, Optional, Union, Tuple

import torch
import numpy as np

from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers, logging
from diffusers.models.modeling_outputs import Transformer2DModelOutput

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def teacache_cog_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: Union[int, float, torch.LongTensor],
    timestep_cond: Optional[torch.Tensor] = None,
    ofs: Optional[Union[int, float, torch.LongTensor]] = None,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
):
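    """CogVideoX transformer forward with TeaCache step skipping.

    Follows `CogVideoXTransformer3DModel.forward` from diffusers, with an added
    TeaCache branch that reuses cached residuals whenever the accumulated,
    rescaled relative L1 change of the timestep embedding stays below
    `self.rel_l1_thresh`. The caller is expected to initialize the TeaCache
    state (`enable_teacache`, `cnt`, `num_steps`, `rel_l1_thresh`, and the
    `previous_*` cache fields) on the transformer instance.
    """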
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.")

    # video latents arrive as (batch, frames, channels, height, width)
    batch_size, num_frames, channels, height, width = hidden_states.shape

    # 1. Time embedding
    timesteps = timestep
    t_emb = self.time_proj(timesteps)

    # `timesteps` does not contain any weights and will always return f32
    # tensors, but `time_embedding` might actually be running in fp16, so we
    # need to cast here. There might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=hidden_states.dtype)
    emb = self.time_embedding(t_emb, timestep_cond)

    if self.ofs_embedding is not None:
        ofs_emb = self.ofs_proj(ofs)
        ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
        ofs_emb = self.ofs_embedding(ofs_emb)
        emb = emb + ofs_emb

    # 2. Patch embedding
    hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
    hidden_states = self.embedding_dropout(hidden_states)

    # `patch_embed` returns text and video tokens concatenated along the
    # sequence dimension; split them back apart
    text_seq_length = encoder_hidden_states.shape[1]
    encoder_hidden_states = hidden_states[:, :text_seq_length]
    hidden_states = hidden_states[:, text_seq_length:]
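
    # 3. TeaCache: decide whether this step can reuse cached residuals. The
    # fitted polynomial rescales the relative L1 change of the timestep
    # embedding into an estimate of the relative change in model output;
    # coefficients differ per model variant.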
    if self.enable_teacache:
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            # always compute the first and last steps in full
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
            else:
                # CogVideoX-5B and CogVideoX1.5-5B
                coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]
            rescale_func = np.poly1d(coefficients)
            rel_l1 = ((emb - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
            self.accumulated_rel_l1_distance += rescale_func(rel_l1)
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = emb
        self.cnt += 1
        if self.cnt == self.num_steps:
            self.cnt = 0
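
    # Either reuse the cached residuals (skipped step) or run the full
    # transformer stack and refresh the residual cache.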
    if self.enable_teacache:
        if not should_calc:
            hidden_states += self.previous_residual
            encoder_hidden_states += self.previous_residual_encoder
        else:
            ori_hidden_states = hidden_states.clone()
            ori_encoder_hidden_states = encoder_hidden_states.clone()
            # 4. Transformer blocks
            for i, block in enumerate(self.transformer_blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        emb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=emb,
                        image_rotary_emb=image_rotary_emb,
                    )

            self.previous_residual = hidden_states - ori_hidden_states
            self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
    else:
        # 4. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

    if not self.config.use_rotary_positional_embeddings:
        # CogVideoX-2B
        hidden_states = self.norm_final(hidden_states)
    else:
        # CogVideoX-5B and CogVideoX1.5-5B
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        hidden_states = self.norm_final(hidden_states)
        hidden_states = hidden_states[:, text_seq_length:]

    # 5. Final block
    hidden_states = self.norm_out(hidden_states, temb=emb)
    hidden_states = self.proj_out(hidden_states)

    # 6. Unpatchify
    p = self.config.patch_size
    p_t = self.config.patch_size_t

    if p_t is None:
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
    else:
        output = hidden_states.reshape(
            batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
        )
        output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
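
    # `output` is back in (batch, frames, channels, height, width) layout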

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
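

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal, hypothetical example of wiring this forward into a loaded
# CogVideoX transformer. The TeaCache attributes below are assumptions about
# what the caller must initialize; they are not a published diffusers API.
#
#   import types
#   from diffusers import CogVideoXPipeline
#
#   pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b")
#   transformer = pipe.transformer
#   transformer.enable_teacache = True
#   transformer.cnt = 0                        # current denoising step index
#   transformer.num_steps = 50                 # total denoising steps
#   transformer.rel_l1_thresh = 0.1            # higher -> more skipping, faster
#   transformer.accumulated_rel_l1_distance = 0
#   transformer.previous_modulated_input = None
#   transformer.previous_residual = None
#   transformer.previous_residual_encoder = None
#   transformer.forward = types.MethodType(teacache_cog_forward, transformer)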