add teacache

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/3866/head
Vladimir Mandic 2025-04-05 20:15:48 -04:00
parent fcde9f406d
commit 228da75a2e
15 changed files with 524 additions and 132 deletions

View File

@ -103,6 +103,7 @@
// progressbar.js
"randomId": "readonly",
"requestProgress": "readonly",
"setRefreshInterval": "readonly",
// imageviewer.js
"modalPrevImage": "readonly",
"modalNextImage": "readonly",

View File

@ -3,10 +3,13 @@
## Update for 2025-04-05
- **Features**
- Flux: TeaCache for Flux.1
- Video: add `FasterCache` and `PAB` support to WanAI and LTX models
- ZLUDA: add more GPUs to recognized list
- Pipe: [SoftFill](https://github.com/zacheryvaughn/softfill-pipelines)
- **Caching**
- add `TeaCache` support to *Flux, CogVideoX, Mochi, LTX*
- add `FasterCache` support to *WanAI, LTX* (other video models already supported)
- add `PyramidAttentionBroadcast` support to *WanAI, LTX* (other video models already supported)
- **Other**
- ZLUDA: add more GPUs to recognized list
selectable in scripts, available for SDXL inpainting models
- LoRA: add option to force-reload LoRA on every generate
- Grid: add max-rows and max-columns options in settings to control grid format

View File

@ -1,4 +1,15 @@
let lastState = {};
let refreshInterval = 10000;
function setRefreshInterval() {
  // Seed the polling interval from user settings (500ms fallback on first call).
  refreshInterval = opts.live_preview_refresh_period || 500;
  log('refreshInterval', document.visibilityState, refreshInterval);
  // Throttle progress polling while the tab is hidden; restore when it returns.
  const onVisibilityChange = () => {
    const base = opts.live_preview_refresh_period || 1000;
    refreshInterval = document.hidden ? Math.max(2500, base) : base;
    log('refreshInterval', document.visibilityState, refreshInterval);
  };
  document.addEventListener('visibilitychange', onVisibilityChange);
}
function pad2(x) {
return x < 10 ? `0${x}` : x;
@ -152,7 +163,8 @@ function requestProgress(id_task, progressEl, galleryEl, atEnd = null, onProgres
done();
};
xhrPost('./internal/progress', { id_task, id_live_preview }, onProgressHandler, onProgressErrorHandler, false, 30000);
const request_id = document.hidden ? -1 : id_live_preview;
xhrPost('./internal/progress', { id_task, request_id }, onProgressHandler, onProgressErrorHandler, false, 30000);
};
debug('livePreview start:', dateStart);
start(id_task, 0);

View File

@ -35,6 +35,7 @@ async function initStartup() {
window.subpath = window.opts.subpath;
window.api = `${window.subpath}/sdapi/v1`;
}
setRefreshInterval();
executeCallbacks(uiReadyCallbacks);
initLogMonitor();
setupExtraNetworks();

View File

@ -225,7 +225,8 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
if shared.opts.teacache_enabled:
from modules import teacache
diffusers.FluxTransformer2DModel.forward = teacache.teacache_forward
shared.log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.FluxTransformer2DModel.__name__}')
diffusers.FluxTransformer2DModel.forward = teacache.teacache_flux_forward
# load overrides if any
if shared.opts.sd_unet != 'Default':

View File

@ -87,17 +87,22 @@ def api_progress(req: ProgressRequest):
id_live_preview = req.id_live_preview
live_preview = None
textinfo = shared.state.textinfo
updated = shared.state.set_current_image()
if not active:
id_live_preview = -1
textinfo = "Queued..." if queued else "Waiting..."
debug_log(f'Preview: job={shared.state.job} active={active} progress={step}/{steps}/{progress} image={shared.state.current_image_sampling_step} request={id_live_preview} last={shared.state.id_live_preview} enabled={shared.opts.live_previews_enable} job={shared.state.preview_job} updated={updated} image={shared.state.current_image} elapsed={elapsed:.3f}')
debug_log(f'Preview: job={shared.state.job} active={active} progress={step}/{steps}/{progress} image={shared.state.current_image_sampling_step} request={id_live_preview} last={shared.state.id_live_preview} enabled={shared.opts.live_previews_enable} job={shared.state.preview_job} elapsed={elapsed:.3f}')
if shared.opts.live_previews_enable and active and (shared.state.id_live_preview != req.id_live_preview) and (shared.state.current_image is not None):
buffered = io.BytesIO()
shared.state.current_image.save(buffered, format='jpeg')
live_preview = f'data:image/jpeg;base64,{base64.b64encode(buffered.getvalue()).decode("ascii")}'
print('HERE1', req.id_live_preview, shared.state.id_live_preview)
if shared.opts.live_previews_enable and active and (req.id_live_preview != -1) and (shared.state.id_live_preview != req.id_live_preview) and (shared.state.current_image is not None):
shared.state.set_current_image()
if shared.state.current_image is not None:
buffered = io.BytesIO()
shared.state.current_image.save(buffered, format='jpeg')
live_preview = f'data:image/jpeg;base64,{base64.b64encode(buffered.getvalue()).decode("ascii")}'
else:
live_preview = None
id_live_preview = shared.state.id_live_preview

View File

@ -1 +1,23 @@
from .teacache_flux import apply_teacache, teacache_forward
from .teacache_flux import teacache_flux_forward
from .teacache_ltx import teacache_ltx_forward
from .teacache_mochi import teacache_mochi_forward
from .teacache_cogvideox import teacache_cog_forward
supported_models = ['Flux', 'CogVideoX', 'Mochi', 'LTX']
def apply_teacache(p):
    """Seed TeaCache state on the loaded model's transformer class for this generation."""
    from modules import shared
    model_name = shared.sd_model.__class__.__name__
    # only architectures with a patched TeaCache forward are supported
    if not any(model_name.startswith(prefix) for prefix in supported_models):
        return
    if not hasattr(shared.sd_model, 'transformer'):
        return
    transformer_cls = shared.sd_model.transformer.__class__
    # a threshold of 0 disables caching; attributes live on the class because the
    # patched forward reads them via `self`
    transformer_cls.enable_teacache = shared.opts.teacache_thresh > 0
    transformer_cls.cnt = 0
    transformer_cls.num_steps = p.steps
    transformer_cls.rel_l1_thresh = shared.opts.teacache_thresh  # upstream guidance: 0.25→1.5x, 0.4→1.8x, 0.6→2.0x, 0.8→2.25x speedup
    transformer_cls.accumulated_rel_l1_distance = 0
    transformer_cls.previous_modulated_input = None
    transformer_cls.previous_residual = None
    shared.log.info(f'Transformers cache: type=teacache cls={model_name} thresh={shared.opts.teacache_thresh}')

View File

@ -0,0 +1,182 @@
from typing import Any, Dict, Optional, Union, Tuple
import torch
import numpy as np
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers, logging
from diffusers.models.modeling_outputs import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def teacache_cog_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: Union[int, float, torch.LongTensor],
    timestep_cond: Optional[torch.Tensor] = None,
    ofs: Optional[Union[int, float, torch.LongTensor]] = None,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
):
    """TeaCache-enabled replacement for `CogVideoXTransformer3DModel.forward`.

    When `self.enable_teacache` is set (state is seeded externally with `cnt`,
    `num_steps`, `rel_l1_thresh`, `previous_modulated_input`, `previous_residual`),
    the relative L1 change of the timestep embedding between steps is accumulated;
    while it stays below `rel_l1_thresh`, the transformer blocks are skipped and the
    residuals cached from the last full pass are re-applied instead.
    """
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )
    batch_size, num_frames, channels, height, width = hidden_states.shape
    # 1. Time embedding
    timesteps = timestep
    t_emb = self.time_proj(timesteps)
    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=hidden_states.dtype)
    emb = self.time_embedding(t_emb, timestep_cond)
    if self.ofs_embedding is not None:
        ofs_emb = self.ofs_proj(ofs)
        ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
        ofs_emb = self.ofs_embedding(ofs_emb)
        emb = emb + ofs_emb
    # 2. Patch embedding
    hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
    hidden_states = self.embedding_dropout(hidden_states)
    text_seq_length = encoder_hidden_states.shape[1]
    encoder_hidden_states = hidden_states[:, :text_seq_length]
    hidden_states = hidden_states[:, text_seq_length:]
    # 3. TeaCache bookkeeping: decide whether this step needs a full transformer pass
    if self.enable_teacache:
        if self.cnt == 0 or self.cnt == self.num_steps-1:
            # first and last steps are always computed fully
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
            else:
                # CogVideoX-5B and CogvideoX1.5-5B
                coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]
            # polynomial rescaling of the relative embedding change (coefficients from upstream TeaCache)
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((emb-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = emb
        self.cnt += 1
        if self.cnt == self.num_steps:
            # wrap the step counter so the next generation starts fresh
            self.cnt = 0
    if self.enable_teacache:
        if not should_calc:
            # skip the blocks: re-apply the residuals cached from the last full pass
            # NOTE(review): `previous_residual_encoder` is only set below on a full
            # pass — assumes a full pass always precedes a skipped step (cnt==0 forces one)
            hidden_states += self.previous_residual
            encoder_hidden_states += self.previous_residual_encoder
        else:
            ori_hidden_states = hidden_states.clone()
            ori_encoder_hidden_states = encoder_hidden_states.clone()
            # 4. Transformer blocks
            for i, block in enumerate(self.transformer_blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward
                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        emb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=emb,
                        image_rotary_emb=image_rotary_emb,
                    )
            # cache the block-stack residuals so skipped steps can re-apply them
            self.previous_residual = hidden_states - ori_hidden_states
            self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
    else:
        # 4. Transformer blocks (TeaCache disabled: plain forward, identical to upstream diffusers)
        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )
    if not self.config.use_rotary_positional_embeddings:
        # CogVideoX-2B
        hidden_states = self.norm_final(hidden_states)
    else:
        # CogVideoX-5B and CogvideoX1.5-5B
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        hidden_states = self.norm_final(hidden_states)
        hidden_states = hidden_states[:, text_seq_length:]
    # 5. Final block
    hidden_states = self.norm_out(hidden_states, temb=emb)
    hidden_states = self.proj_out(hidden_states)
    # 6. Unpatchify
    p = self.config.patch_size
    p_t = self.config.patch_size_t
    if p_t is None:
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
    else:
        output = hidden_states.reshape(
            batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
        )
        output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)
    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)

View File

@ -1,14 +1,14 @@
from typing import Any, Dict, Optional, Union
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
import torch
import numpy as np
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def teacache_forward(
def teacache_flux_forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
@ -306,17 +306,3 @@ def teacache_forward(
return (output,)
return Transformer2DModelOutput(sample=output)
def apply_teacache(p):
from modules import shared
if not shared.native or not shared.opts.teacache_enabled or not shared.sd_model.__class__.__name__.startswith('Flux'):
return
shared.sd_model.transformer.__class__.enable_teacache = shared.opts.teacache_thresh > 0
shared.sd_model.transformer.__class__.cnt = 0
shared.sd_model.transformer.__class__.num_steps = p.steps
shared.sd_model.transformer.__class__.rel_l1_thresh = shared.opts.teacache_thresh # 0.25 for 1.5x speedup, 0.4 for 1.8x speedup, 0.6 for 2.0x speedup, 0.8 for 2.25x speedup
shared.sd_model.transformer.__class__.accumulated_rel_l1_distance = 0
shared.sd_model.transformer.__class__.previous_modulated_input = None
shared.sd_model.transformer.__class__.previous_residual = None
shared.log.info(f'Transformers cache: type=teacache thresh={shared.opts.teacache_thresh} cls={shared.sd_model.__class__.__name__}')

View File

@ -1,15 +1,14 @@
"""
source: https://github.com/ali-vilab/TeaCache/blob/main/TeaCache4LTX-Video/teacache_ltx.py
"""
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers, logging
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import is_torch_version, scale_lora_layers, unscale_lora_layers
import numpy as np
def teacache_forward(
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def teacache_ltx_forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
@ -22,104 +21,73 @@ def teacache_forward(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
batch_size = hidden_states.size(0)
hidden_states = self.proj_in(hidden_states)
batch_size = hidden_states.size(0)
hidden_states = self.proj_in(hidden_states)
temb, embedded_timestep = self.time_embed(
timestep.flatten(),
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
temb, embedded_timestep = self.time_embed(
timestep.flatten(),
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(batch_size, -1, temb.size(-1))
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
temb = temb.view(batch_size, -1, temb.size(-1))
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
if self.enable_teacache:
inp = hidden_states.clone()
temb_ = temb.clone()
inp = self.transformer_blocks[0].norm1(inp)
num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0]
ada_values = self.transformer_blocks[0].scale_shift_table[None, None] + temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
modulated_inp = inp * (1 + scale_msa) + shift_msa
if self.cnt == 0 or self.cnt == self.num_steps-1:
if self.enable_teacache:
inp = hidden_states.clone()
temb_ = temb.clone()
inp = self.transformer_blocks[0].norm1(inp)
num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0]
ada_values = self.transformer_blocks[0].scale_shift_table[None, None] + temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
modulated_inp = inp * (1 + scale_msa) + shift_msa
if self.cnt == 0 or self.cnt == self.num_steps-1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.cnt += 1
if self.cnt == self.num_steps:
self.cnt = 0
if self.enable_teacache:
if not should_calc:
hidden_states += self.previous_residual
else:
ori_hidden_states = hidden_states.clone()
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
encoder_attention_mask,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
encoder_attention_mask=encoder_attention_mask,
)
scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
self.previous_residual = hidden_states - ori_hidden_states
self.previous_modulated_input = modulated_inp
self.cnt += 1
if self.cnt == self.num_steps:
self.cnt = 0
if self.enable_teacache:
if not should_calc:
hidden_states += self.previous_residual
else:
ori_hidden_states = hidden_states.clone()
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
@ -156,12 +124,52 @@ def teacache_forward(
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
self.previous_residual = hidden_states - ori_hidden_states
else:
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
encoder_attention_mask,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
encoder_attention_mask=encoder_attention_mask,
)
scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
output = self.proj_out(hidden_states)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@ -0,0 +1,157 @@
from typing import Any, Dict, Optional
import torch
import numpy as np
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers, logging
from diffusers.models.modeling_outputs import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def teacache_mochi_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    encoder_attention_mask: torch.Tensor,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> torch.Tensor:
    """TeaCache-enabled replacement for `MochiTransformer3DModel.forward`.

    When `self.enable_teacache` is set (state seeded externally: `cnt`, `num_steps`,
    `rel_l1_thresh`, `previous_modulated_input`, `previous_residual`), the relative
    L1 change of the first block's modulated input is accumulated across steps;
    while it stays below `rel_l1_thresh`, the transformer blocks (and `norm_out`)
    are skipped and the residual cached from the last full pass is re-applied.
    """
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )
    batch_size, num_channels, num_frames, height, width = hidden_states.shape
    p = self.config.patch_size
    post_patch_height = height // p
    post_patch_width = width // p
    temb, encoder_hidden_states = self.time_embed(
        timestep,
        encoder_hidden_states,
        encoder_attention_mask,
        hidden_dtype=hidden_states.dtype,
    )
    # patchify: fold frames into the batch dim, embed patches, then restore batch layout
    hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
    hidden_states = self.patch_embed(hidden_states)
    hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
    image_rotary_emb = self.rope(
        self.pos_frequencies,
        num_frames,
        post_patch_height,
        post_patch_width,
        device=hidden_states.device,
        dtype=torch.float32,
    )
    # TeaCache bookkeeping: decide whether this step needs a full transformer pass
    if self.enable_teacache:
        inp = hidden_states.clone()
        temb_ = temb.clone()
        # probe the first block's modulation to measure step-to-step change
        modulated_inp, gate_msa, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, temb_)
        if self.cnt == 0 or self.cnt == self.num_steps-1:
            # first and last steps are always computed fully
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            # polynomial rescaling of the relative change (coefficients from upstream TeaCache)
            coefficients = [-3.51241319e+03, 8.11675948e+02, -6.09400215e+01, 2.42429681e+00, 3.05291719e-03]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        if self.cnt == self.num_steps:
            # wrap the step counter so the next generation starts fresh
            self.cnt = 0
    if self.enable_teacache:
        if not should_calc:
            # skip the blocks: the cached residual already includes norm_out (see below)
            hidden_states += self.previous_residual
        else:
            ori_hidden_states = hidden_states.clone()
            for i, block in enumerate(self.transformer_blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward
                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        encoder_attention_mask,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=temb,
                        encoder_attention_mask=encoder_attention_mask,
                        image_rotary_emb=image_rotary_emb,
                    )
            hidden_states = self.norm_out(hidden_states, temb)
            # residual is measured AFTER norm_out, so skipped steps bypass norm_out too
            self.previous_residual = hidden_states - ori_hidden_states
    else:
        # TeaCache disabled: plain forward, identical to upstream diffusers
        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    encoder_attention_mask,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    encoder_attention_mask=encoder_attention_mask,
                    image_rotary_emb=image_rotary_emb,
                )
        hidden_states = self.norm_out(hidden_states, temb)
    hidden_states = self.proj_out(hidden_states)
    # unpatchify back to (batch, channels, frames, height, width)
    hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
    hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
    output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)
    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)

View File

@ -0,0 +1,16 @@
import diffusers
from modules import shared
def apply_teacache_patch(cls):
    """Patch `forward` of the given transformer class with its TeaCache variant.

    Called while loading a video model; `cls` is the DiT transformer class about to
    be instantiated. No-op when TeaCache is disabled in settings or the class has
    no TeaCache implementation.
    """
    if not shared.opts.teacache_enabled:
        return
    from modules import teacache
    # map each supported transformer class to its TeaCache forward implementation
    forwards = {
        'LTXVideoTransformer3DModel': teacache.teacache_ltx_forward,
        'MochiTransformer3DModel': teacache.teacache_mochi_forward,
        'CogVideoXTransformer3DModel': teacache.teacache_cog_forward,
        'FluxTransformer2DModel': teacache.teacache_flux_forward,
    }
    # fix: only patch the class actually being loaded; previously
    # FluxTransformer2DModel.forward was overwritten unconditionally whenever ANY
    # video model loaded (Flux is already patched at Flux load time in sd_models)
    patched = forwards.get(cls.__name__)
    if patched is None:
        return
    shared.log.debug(f'Transformers cache: type=teacache patch=forward cls={cls.__name__}')
    cls.forward = patched

View File

@ -1,7 +1,7 @@
import os
import time
from modules import shared, errors, sd_models, sd_checkpoint, model_quant, devices
from modules.video_models import models_def, video_utils, video_vae, video_overrides
from modules.video_models import models_def, video_utils, video_vae, video_overrides, video_cache
loaded_model = None
@ -17,6 +17,8 @@ def load_model(selected: models_def.Model):
sd_models.unload_model_weights()
t0 = time.time()
video_cache.apply_teacache_patch(selected.dit_cls)
# text encoder
try:
quant_args = model_quant.create_config(module='TE')

View File

@ -119,13 +119,11 @@ def link_or_copy(src: os.PathLike, dst: os.PathLike):
def load():
global core, ml # pylint: disable=global-statement
global core, ml, hipBLASLt_enabled, MIOpen_enabled # pylint: disable=global-statement
core = Core(ctypes.windll.LoadLibrary(os.path.join(path, 'nvcuda.dll')))
ml = ZLUDALibrary(ctypes.windll.LoadLibrary(os.path.join(path, 'nvml.dll')))
is_nightly = core.get_nightly_flag() == 1
hipBLASLt_enabled = is_nightly and os.path.exists(rocm.blaslt_tensile_libpath) and os.path.exists(os.path.join(rocm.path, "bin", "hipblaslt.dll"))
global MIOpen_enabled # pylint: disable=global-statement
MIOpen_enabled = is_nightly and os.path.exists(os.path.join(rocm.path, "bin", "MIOpen.dll"))
for k, v in DLL_MAPPING.items():

View File

@ -5,7 +5,6 @@ import gradio as gr
import diffusers
import transformers
from modules import scripts, processing, shared, images, devices, sd_models, sd_checkpoint, model_quant, timer
from modules.teacache.teacache_ltx import teacache_forward
repos = {
@ -113,7 +112,6 @@ class Script(scripts.Script):
if shared.sd_model.__class__ != cls:
sd_models.unload_model_weights()
kwargs = model_quant.create_config()
diffusers.LTXVideoTransformer3DModel.forward = teacache_forward
if os.path.isfile(repo_id):
shared.sd_model = cls.from_single_file(
repo_id,