diff --git a/.eslintrc.json b/.eslintrc.json index 4ca2afbdb..05c467b50 100644 --- a/.eslintrc.json +++ b/.eslintrc.json @@ -103,6 +103,7 @@ // progressbar.js "randomId": "readonly", "requestProgress": "readonly", + "setRefreshInterval": "readonly", // imageviewer.js "modalPrevImage": "readonly", "modalNextImage": "readonly", diff --git a/CHANGELOG.md b/CHANGELOG.md index c481a7b34..d3387a931 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,10 +3,13 @@ ## Update for 2025-04-05 - **Features** - - Flux: TeaCache for Flux.1 - - Video: add `FasterCache` and `PAB` support to WanDB and LTX models - - ZLUDA: add more GPUs to recognized list - Pipe: [SoftFill](https://github.com/zacheryvaughn/softfill-pipelines) +- **Caching** + - add `TeaCache` support to *Flux, CogVideoX, Mochi, LTX* + - add `FasterCache` support to *WanAI, LTX* (other video models already supported) + - add `PyramidAttentionBroadcast` support to *WanAI, LTX* (other video models already supported) +- **Other** + - ZLUDA: add more GPUs to recognized list select in scripts, available for sdxl in inpaint model - LoRA: add option to force-reload LoRA on every generate - Grid: add of max-rows and max-columns in settings to control grid format diff --git a/javascript/progressBar.js b/javascript/progressBar.js index 801abfffb..a244232d1 100644 --- a/javascript/progressBar.js +++ b/javascript/progressBar.js @@ -1,4 +1,15 @@ let lastState = {}; +let refreshInterval = 10000; + +function setRefreshInterval() { + refreshInterval = opts.live_preview_refresh_period || 500; + log('refreshInterval', document.visibilityState, refreshInterval); + document.addEventListener('visibilitychange', () => { + if (document.hidden) refreshInterval = Math.max(2500, opts.live_preview_refresh_period || 1000); + else refreshInterval = opts.live_preview_refresh_period || 1000; + log('refreshInterval', document.visibilityState, refreshInterval); + }); +} function pad2(x) { return x < 10 ? `0${x}` : x; @@ -152,7 +163,8 @@ function requestProgress(id_task, progressEl, galleryEl, atEnd = null, onProgres done(); }; - xhrPost('./internal/progress', { id_task, id_live_preview }, onProgressHandler, onProgressErrorHandler, false, 30000); + const request_id = document.hidden ? 
-1 : id_live_preview;
+    xhrPost('./internal/progress', { id_task, id_live_preview: request_id }, onProgressHandler, onProgressErrorHandler, false, 30000);
   };
   debug('livePreview start:', dateStart);
   start(id_task, 0);
diff --git a/javascript/startup.js b/javascript/startup.js
index 328c167b6..842f925ba 100644
--- a/javascript/startup.js
+++ b/javascript/startup.js
@@ -35,6 +35,7 @@ async function initStartup() {
     window.subpath = window.opts.subpath;
     window.api = `${window.subpath}/sdapi/v1`;
   }
+  setRefreshInterval();
   executeCallbacks(uiReadyCallbacks);
   initLogMonitor();
   setupExtraNetworks();
diff --git a/modules/model_flux.py b/modules/model_flux.py
index 71b5f1e96..7f1ae035d 100644
--- a/modules/model_flux.py
+++ b/modules/model_flux.py
@@ -225,7 +225,8 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
 
     if shared.opts.teacache_enabled:
         from modules import teacache
-        diffusers.FluxTransformer2DModel.forward = teacache.teacache_forward
+        shared.log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.FluxTransformer2DModel.__name__}')
+        diffusers.FluxTransformer2DModel.forward = teacache.teacache_flux_forward
 
     # load overrides if any
     if shared.opts.sd_unet != 'Default':
diff --git a/modules/progress.py b/modules/progress.py
index 6413e7188..ab1230ff7 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -87,17 +87,20 @@ def api_progress(req: ProgressRequest):
     id_live_preview = req.id_live_preview
     live_preview = None
     textinfo = shared.state.textinfo
-    updated = shared.state.set_current_image()
     if not active:
         id_live_preview = -1
         textinfo = "Queued..." if queued else "Waiting..."
-    debug_log(f'Preview: job={shared.state.job} active={active} progress={step}/{steps}/{progress} image={shared.state.current_image_sampling_step} request={id_live_preview} last={shared.state.id_live_preview} enabled={shared.opts.live_previews_enable} job={shared.state.preview_job} updated={updated} image={shared.state.current_image} elapsed={elapsed:.3f}')
+    debug_log(f'Preview: job={shared.state.job} active={active} progress={step}/{steps}/{progress} image={shared.state.current_image_sampling_step} request={id_live_preview} last={shared.state.id_live_preview} enabled={shared.opts.live_previews_enable} job={shared.state.preview_job} elapsed={elapsed:.3f}')
 
-    if shared.opts.live_previews_enable and active and (shared.state.id_live_preview != req.id_live_preview) and (shared.state.current_image is not None):
-        buffered = io.BytesIO()
-        shared.state.current_image.save(buffered, format='jpeg')
-        live_preview = f'data:image/jpeg;base64,{base64.b64encode(buffered.getvalue()).decode("ascii")}'
+    if shared.opts.live_previews_enable and active and (req.id_live_preview != -1) and (shared.state.id_live_preview != req.id_live_preview) and (shared.state.current_image is not None):
+        shared.state.set_current_image()
+        if shared.state.current_image is not None:
+            buffered = io.BytesIO()
+            shared.state.current_image.save(buffered, format='jpeg')
+            live_preview = f'data:image/jpeg;base64,{base64.b64encode(buffered.getvalue()).decode("ascii")}'
+        else:
+            live_preview = None
     id_live_preview = shared.state.id_live_preview
diff --git a/modules/teacache/__init__.py b/modules/teacache/__init__.py
index 84fbd3d0f..cb06c1f26 100644
--- a/modules/teacache/__init__.py
+++ b/modules/teacache/__init__.py
@@ -1 +1,23 @@
-from .teacache_flux import apply_teacache, teacache_forward
+from .teacache_flux import teacache_flux_forward
+from 
.teacache_ltx import teacache_ltx_forward +from .teacache_mochi import teacache_mochi_forward +from .teacache_cogvideox import teacache_cog_forward + + +supported_models = ['Flux', 'CogVideoX', 'Mochi', 'LTX'] + + +def apply_teacache(p): + from modules import shared + if not any(shared.sd_model.__class__.__name__.startswith(x) for x in supported_models): + return + if not hasattr(shared.sd_model, 'transformer'): + return + shared.sd_model.transformer.__class__.enable_teacache = shared.opts.teacache_thresh > 0 + shared.sd_model.transformer.__class__.cnt = 0 + shared.sd_model.transformer.__class__.num_steps = p.steps + shared.sd_model.transformer.__class__.rel_l1_thresh = shared.opts.teacache_thresh # 0.25 for 1.5x speedup, 0.4 for 1.8x speedup, 0.6 for 2.0x speedup, 0.8 for 2.25x speedup + shared.sd_model.transformer.__class__.accumulated_rel_l1_distance = 0 + shared.sd_model.transformer.__class__.previous_modulated_input = None + shared.sd_model.transformer.__class__.previous_residual = None + shared.log.info(f'Transformers cache: type=teacache cls={shared.sd_model.__class__.__name__} thresh={shared.opts.teacache_thresh}') diff --git a/modules/teacache/teacache_cogvideox.py b/modules/teacache/teacache_cogvideox.py new file mode 100644 index 000000000..436338b37 --- /dev/null +++ b/modules/teacache/teacache_cogvideox.py @@ -0,0 +1,182 @@ +from typing import Any, Dict, Optional, Union, Tuple +import torch +import numpy as np +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers, logging +from diffusers.models.modeling_outputs import Transformer2DModelOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def teacache_cog_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + ofs: Optional[Union[int, float, torch.LongTensor]] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + if self.ofs_embedding is not None: + ofs_emb = self.ofs_proj(ofs) + ofs_emb = ofs_emb.to(dtype=hidden_states.dtype) + ofs_emb = self.ofs_embedding(ofs_emb) + emb = emb + ofs_emb + + # 2. 
Patch embedding + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + if self.enable_teacache: + if self.cnt == 0 or self.cnt == self.num_steps-1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03] + else: + # CogVideoX-5B and CogvideoX1.5-5B + coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((emb-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = emb + self.cnt += 1 + if self.cnt == self.num_steps: + self.cnt = 0 + + if self.enable_teacache: + if not should_calc: + hidden_states += self.previous_residual + encoder_hidden_states += self.previous_residual_encoder + else: + ori_hidden_states = hidden_states.clone() + ori_encoder_hidden_states = encoder_hidden_states.clone() + # 4. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + + self.previous_residual = hidden_states - ori_hidden_states + self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states + else: + # 4. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B and CogvideoX1.5-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 5. 
Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 6. Unpatchify + p = self.config.patch_size + p_t = self.config.patch_size_t + + if p_t is None: + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + else: + output = hidden_states.reshape( + batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p + ) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/modules/teacache/teacache_flux.py b/modules/teacache/teacache_flux.py index 8e57341fd..aa750d3f0 100644 --- a/modules/teacache/teacache_flux.py +++ b/modules/teacache/teacache_flux.py @@ -1,14 +1,14 @@ from typing import Any, Dict, Optional, Union -from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers import torch import numpy as np +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def teacache_forward( +def teacache_flux_forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, @@ -306,17 +306,3 @@ def teacache_forward( return (output,) return Transformer2DModelOutput(sample=output) - - -def apply_teacache(p): - from modules import shared - if not shared.native or not shared.opts.teacache_enabled or not shared.sd_model.__class__.__name__.startswith('Flux'): - return - shared.sd_model.transformer.__class__.enable_teacache = shared.opts.teacache_thresh > 0 - shared.sd_model.transformer.__class__.cnt = 0 - shared.sd_model.transformer.__class__.num_steps = p.steps - shared.sd_model.transformer.__class__.rel_l1_thresh = shared.opts.teacache_thresh # 0.25 for 1.5x speedup, 0.4 for 1.8x speedup, 0.6 for 2.0x speedup, 0.8 for 2.25x speedup - shared.sd_model.transformer.__class__.accumulated_rel_l1_distance = 0 - shared.sd_model.transformer.__class__.previous_modulated_input = None - shared.sd_model.transformer.__class__.previous_residual = None - shared.log.info(f'Transformers cache: type=teacache thresh={shared.opts.teacache_thresh} cls={shared.sd_model.__class__.__name__}') diff --git a/modules/teacache/teacache_ltx.py b/modules/teacache/teacache_ltx.py index f7f9cd83d..8a4e1b392 100644 --- a/modules/teacache/teacache_ltx.py +++ b/modules/teacache/teacache_ltx.py @@ -1,15 +1,14 @@ -""" -source: https://github.com/ali-vilab/TeaCache/blob/main/TeaCache4LTX-Video/teacache_ltx.py -""" - from typing import Any, Dict, Optional, Tuple -import numpy as np import torch +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers, logging from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers.utils import is_torch_version, scale_lora_layers, unscale_lora_layers +import numpy as np -def teacache_forward( +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def teacache_ltx_forward( self, hidden_states: torch.Tensor, encoder_hidden_states: 
torch.Tensor, @@ -22,104 +21,73 @@ def teacache_forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> torch.Tensor: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) - image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) + image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - batch_size = hidden_states.size(0) - hidden_states = self.proj_in(hidden_states) + batch_size = hidden_states.size(0) + hidden_states = self.proj_in(hidden_states) - temb, embedded_timestep = self.time_embed( - timestep.flatten(), - batch_size=batch_size, - hidden_dtype=hidden_states.dtype, - ) + temb, embedded_timestep = self.time_embed( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) - temb = temb.view(batch_size, -1, temb.size(-1)) - embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) - if self.enable_teacache: - inp = hidden_states.clone() - temb_ = temb.clone() - inp = self.transformer_blocks[0].norm1(inp) - num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0] - ada_values = self.transformer_blocks[0].scale_shift_table[None, None] + temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) - modulated_inp = inp * (1 + scale_msa) + shift_msa - if self.cnt == 0 or self.cnt == self.num_steps-1: + if self.enable_teacache: + inp = hidden_states.clone() + temb_ = temb.clone() + inp = self.transformer_blocks[0].norm1(inp) + num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0] + ada_values = self.transformer_blocks[0].scale_shift_table[None, None] + temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1) + shift_msa, 
scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + modulated_inp = inp * (1 + scale_msa) + shift_msa + if self.cnt == 0 or self.cnt == self.num_steps-1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: should_calc = True self.accumulated_rel_l1_distance = 0 - else: - coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03] - rescale_func = np.poly1d(coefficients) - self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) - if self.accumulated_rel_l1_distance < self.rel_l1_thresh: - should_calc = False - else: - should_calc = True - self.accumulated_rel_l1_distance = 0 - self.previous_modulated_input = modulated_inp - self.cnt += 1 - if self.cnt == self.num_steps: - self.cnt = 0 - - if self.enable_teacache: - if not should_calc: - hidden_states += self.previous_residual - else: - ori_hidden_states = hidden_states.clone() - for block in self.transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - encoder_attention_mask, - **ckpt_kwargs, - ) - else: - hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - encoder_attention_mask=encoder_attention_mask, - ) - - scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states * (1 + scale) + shift - self.previous_residual = hidden_states - ori_hidden_states + self.previous_modulated_input = modulated_inp + self.cnt += 1 + if self.cnt == self.num_steps: + self.cnt = 0 + + if self.enable_teacache: + if not should_calc: + hidden_states += self.previous_residual else: + ori_hidden_states = hidden_states.clone() for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -156,12 +124,52 @@ def teacache_forward( hidden_states = self.norm_out(hidden_states) hidden_states = hidden_states * (1 + scale) + shift + self.previous_residual = hidden_states - ori_hidden_states + else: + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = 
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + encoder_attention_mask, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + ) + + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift - output = self.proj_out(hidden_states) + output = self.proj_out(hidden_states) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/modules/teacache/teacache_mochi.py b/modules/teacache/teacache_mochi.py new file mode 100644 index 000000000..e899fb164 --- /dev/null +++ b/modules/teacache/teacache_mochi.py @@ -0,0 +1,157 @@ +from typing import Any, Dict, Optional +import torch +import numpy as np +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers, logging +from diffusers.models.modeling_outputs import Transformer2DModelOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def teacache_mochi_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> torch.Tensor: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p = self.config.patch_size + + post_patch_height = height // p + post_patch_width = width // p + + temb, encoder_hidden_states = self.time_embed( + timestep, + encoder_hidden_states, + encoder_attention_mask, + hidden_dtype=hidden_states.dtype, + ) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.patch_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) + + image_rotary_emb = self.rope( + self.pos_frequencies, + num_frames, + post_patch_height, + post_patch_width, + device=hidden_states.device, + dtype=torch.float32, + ) + + if self.enable_teacache: + inp = hidden_states.clone() + temb_ = temb.clone() + modulated_inp, gate_msa, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, temb_) + if self.cnt == 0 or self.cnt == self.num_steps-1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = [-3.51241319e+03, 8.11675948e+02, -6.09400215e+01, 2.42429681e+00, 3.05291719e-03] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.cnt += 1 + if self.cnt == self.num_steps: + self.cnt = 0 + + if self.enable_teacache: + if not should_calc: + hidden_states += self.previous_residual + else: + ori_hidden_states = hidden_states.clone() + for i, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + encoder_attention_mask, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + encoder_attention_mask=encoder_attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = self.norm_out(hidden_states, temb) + self.previous_residual = hidden_states - ori_hidden_states + else: + for i, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + encoder_attention_mask, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + encoder_attention_mask=encoder_attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = self.norm_out(hidden_states, temb) + + hidden_states = self.proj_out(hidden_states) + + 
hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
+    hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
+    output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
+
+    if USE_PEFT_BACKEND:
+        # remove `lora_scale` from each PEFT layer
+        unscale_lora_layers(self, lora_scale)
+
+    if not return_dict:
+        return (output,)
+    return Transformer2DModelOutput(sample=output)
diff --git a/modules/video_models/video_cache.py b/modules/video_models/video_cache.py
new file mode 100644
index 000000000..01f541be3
--- /dev/null
+++ b/modules/video_models/video_cache.py
@@ -0,0 +1,16 @@
+import diffusers
+from modules import shared
+
+
+def apply_teacache_patch(cls):
+    if shared.opts.teacache_enabled:
+        from modules import teacache
+        shared.log.debug(f'Transformers cache: type=teacache patch=forward cls={cls.__name__}')
+        if cls.__name__ == 'LTXVideoTransformer3DModel':
+            cls.forward = teacache.teacache_ltx_forward
+        elif cls.__name__ == 'MochiTransformer3DModel':
+            cls.forward = teacache.teacache_mochi_forward
+        elif cls.__name__ == 'CogVideoXTransformer3DModel':
+            cls.forward = teacache.teacache_cog_forward
+
+        diffusers.FluxTransformer2DModel.forward = teacache.teacache_flux_forward
diff --git a/modules/video_models/video_load.py b/modules/video_models/video_load.py
index 13d0dc935..ab2669939 100644
--- a/modules/video_models/video_load.py
+++ b/modules/video_models/video_load.py
@@ -1,7 +1,7 @@
 import os
 import time
 from modules import shared, errors, sd_models, sd_checkpoint, model_quant, devices
-from modules.video_models import models_def, video_utils, video_vae, video_overrides
+from modules.video_models import models_def, video_utils, video_vae, video_overrides, video_cache
 
 
 loaded_model = None
@@ -17,6 +17,8 @@ def load_model(selected: models_def.Model):
     sd_models.unload_model_weights()
     t0 = time.time()
 
+    video_cache.apply_teacache_patch(selected.dit_cls)
+
     # text encoder
     try:
         quant_args = model_quant.create_config(module='TE')
diff --git a/modules/zluda_installer.py b/modules/zluda_installer.py
index 93854c365..2f097f707 100644
--- a/modules/zluda_installer.py
+++ b/modules/zluda_installer.py
@@ -119,13 +119,11 @@ def link_or_copy(src: os.PathLike, dst: os.PathLike):
 
 
 def load():
-    global core, ml # pylint: disable=global-statement
+    global core, ml, hipBLASLt_enabled, MIOpen_enabled # pylint: disable=global-statement
     core = Core(ctypes.windll.LoadLibrary(os.path.join(path, 'nvcuda.dll')))
     ml = ZLUDALibrary(ctypes.windll.LoadLibrary(os.path.join(path, 'nvml.dll')))
-
     is_nightly = core.get_nightly_flag() == 1
     hipBLASLt_enabled = is_nightly and os.path.exists(rocm.blaslt_tensile_libpath) and os.path.exists(os.path.join(rocm.path, "bin", "hipblaslt.dll"))
-    global MIOpen_enabled # pylint: disable=global-statement
     MIOpen_enabled = is_nightly and os.path.exists(os.path.join(rocm.path, "bin", "MIOpen.dll"))
 
     for k, v in DLL_MAPPING.items():
diff --git a/scripts/ltxvideo.py b/scripts/ltxvideo.py
index 697e64021..abd426993 100644
--- a/scripts/ltxvideo.py
+++ b/scripts/ltxvideo.py
@@ -5,7 +5,6 @@ import gradio as gr
 import diffusers
 import transformers
 from modules import scripts, processing, shared, images, devices, sd_models, sd_checkpoint, model_quant, timer
-from modules.teacache.teacache_ltx import teacache_forward
 
 
 repos = {
@@ -113,7 +112,6 @@ class Script(scripts.Script):
         if shared.sd_model.__class__ != cls:
             sd_models.unload_model_weights()
             kwargs = model_quant.create_config()
-            
diffusers.LTXVideoTransformer3DModel.forward = teacache_forward if os.path.isfile(repo_id): shared.sd_model = cls.from_single_file( repo_id,
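

Note: all four teacache_*_forward patches in this change share the same skip rule. On every denoising step the modulated input is compared against the previous step, the relative L1 change is rescaled with a model-specific polynomial (the coefficient arrays embedded in each file), and the transformer blocks are only re-evaluated once the accumulated value crosses rel_l1_thresh; otherwise the residual cached from the last full evaluation is re-applied to hidden_states. The sketch below restates that rule outside the diffusers forward methods. TeaCacheState is an illustrative helper that does not exist in this patch, the coefficients shown are the LTX values from teacache_ltx.py, and per the comment in apply_teacache a threshold of 0.25 gives roughly 1.5x speedup, 0.4 about 1.8x, 0.6 about 2.0x, 0.8 about 2.25x.

# Illustrative only: condensed TeaCache skip rule, not the patched diffusers forward
import numpy as np
import torch


class TeaCacheState:  # hypothetical helper, not part of this patch
    def __init__(self, num_steps: int, rel_l1_thresh: float, coefficients: list):
        self.num_steps = num_steps
        self.rel_l1_thresh = rel_l1_thresh
        self.rescale = np.poly1d(coefficients)
        self.cnt = 0
        self.accumulated = 0.0
        self.previous_modulated_input = None

    def should_calc(self, modulated_inp: torch.Tensor) -> bool:
        # first and last steps are always computed; in between, accumulate the
        # rescaled relative L1 change and skip while it stays under the threshold
        if self.cnt == 0 or self.cnt == self.num_steps - 1 or self.previous_modulated_input is None:
            calc = True
            self.accumulated = 0.0
        else:
            prev = self.previous_modulated_input
            rel_l1 = ((modulated_inp - prev).abs().mean() / prev.abs().mean()).item()
            self.accumulated += float(self.rescale(rel_l1))
            calc = self.accumulated >= self.rel_l1_thresh
            if calc:
                self.accumulated = 0.0
        self.previous_modulated_input = modulated_inp
        self.cnt = (self.cnt + 1) % self.num_steps
        return calc


# LTX coefficients from teacache_ltx.py; Flux, Mochi and CogVideoX use their own arrays
state = TeaCacheState(num_steps=20, rel_l1_thresh=0.25, coefficients=[2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03])

When should_calc() returns False, the patched forwards simply add previous_residual (captured on the last full pass) to hidden_states instead of running the transformer blocks.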
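

The progressBar.js and progress.py changes also tie live-preview polling to tab visibility: when document.hidden is true the client lengthens its refresh interval and sends id_live_preview = -1, and api_progress only encodes a JPEG preview for a request whose id is not -1 and differs from the last preview it served. Condensed as a predicate (the function name is illustrative, not code from this patch):

def should_encode_preview(previews_enabled: bool, active: bool, request_id: int, last_sent_id: int, have_image: bool) -> bool:
    # mirrors the gate in api_progress: skip hidden tabs (-1), duplicate requests, and missing images
    return previews_enabled and active and request_id != -1 and request_id != last_sent_id and have_image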