diff --git a/.pylintrc b/.pylintrc index 2c9bbc0c2..5ecbbef48 100644 --- a/.pylintrc +++ b/.pylintrc @@ -27,6 +27,7 @@ ignore-paths=/usr/lib/.*$, modules/meissonic, modules/mod, modules/omnigen, + modules/omnigen2, modules/onnx_impl, modules/pag, modules/pixelsmith, diff --git a/.ruff.toml b/.ruff.toml index ab5e0601c..023678c33 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -20,6 +20,7 @@ exclude = [ "modules/meissonic", "modules/mod", "modules/omnigen", + "modules/omnigen2", "modules/hidream", "modules/pag", "modules/pixelsmith", diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ed2c89c0..6c11b9be5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ - **Models** - [Models Wiki page](https://vladmandic.github.io/sdnext-docs/Models/) is updated will all new models + *note* all new image models larger than 30GB, so [offloading](https://vladmandic.github.io/sdnext-docs/Offload/) and [quantization](https://vladmandic.github.io/sdnext-docs/Quantization/) are necessary! + - [OmniGen2](https://huggingface.co/OmniGen2/OmniGen2) + - OmniGen2 is a powerful unified multimodal model that supports t2i and i2i workflows and uses 4B transformer with Qwen-VL-2.5 4B VLM + - available via *networks -> models -> reference* - [nVidia Cosmos-Predict2 T2I](https://research.nvidia.com/labs/dir/cosmos-predict2/) *2B and 14B* - Cosmos-Predict2 T2I is a new foundational model from Nvidia in two variants: small 2B and large 14B - available via *networks -> models -> reference* @@ -11,6 +15,7 @@ - *note*: this is a gated model, you need to [accept terms](https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image) and set your [huggingface token](https://vladmandic.github.io/sdnext-docs/Gated/) - [Black Forest Labs FLUX.1 Kontext I2I](https://bfl.ai/announcements/flux-1-kontext-dev) *Dev* variant - FLUX.1-Kontext is a 12B model billion parameter capable of editing images based on text instructions + - model is primarily designed for image editing workflows, but also works for text-to-image workflows - requirements are similar to regular FLUX.1 although 2x slower - available via *networks -> models -> reference* - *note*: this is a gated model, you need to [accept terms](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) and set your [huggingface token](https://vladmandic.github.io/sdnext-docs/Gated/) diff --git a/TODO.md b/TODO.md index c8d6183e9..17963a7af 100644 --- a/TODO.md +++ b/TODO.md @@ -52,7 +52,6 @@ Main ToDo list can be found at [GitHub projects](https://github.com/users/vladma - [SkyReels-v2](https://github.com/SkyworkAI/SkyReels-V2)(https://github.com/huggingface/diffusers/pull/11518) #### External:Unified/MultiModal - [Bagel](https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT)(https://github.com/bytedance-seed/bagel) -- [OmniGen2](https://huggingface.co/OmniGen2/OmniGen2) - [Ming](https://github.com/inclusionAI/Ming) - [Liquid](https://github.com/FoundationVision/Liquid) #### External:Image2Image/Editing diff --git a/html/reference.json b/html/reference.json index 599ad1c5b..e90aad1d1 100644 --- a/html/reference.json +++ b/html/reference.json @@ -252,6 +252,13 @@ "skip": true }, + "VectorSpaceLab OmniGen v2": { + "path": "OmniGen2/OmniGen2", + "desc": "OmniGen2 is a powerful and efficient unified multimodal model. Unlike OmniGen v1, OmniGen2 features two distinct decoding pathways for text and image modalities, utilizing unshared parameters and a decoupled image tokenizer.", + "preview": "OmniGen2--OmniGen2.jpg", + "skip": true + }, + "AuraFlow 0.3": { "path": "fal/AuraFlow-v0.3", "desc": "AuraFlow v0.3 is the fully open-sourced flow-based text-to-image generation model. The model was trained with more compute compared to the previous version, AuraFlow-v0.2. Compared to AuraFlow-v0.2, the model is fine-tuned on more aesthetic datasets and now supports various aspect ratio, (now width and height up to 1536 pixels).", diff --git a/models/Reference/OmniGen2--OmniGen2.jpg b/models/Reference/OmniGen2--OmniGen2.jpg new file mode 100644 index 000000000..d8fce4558 Binary files /dev/null and b/models/Reference/OmniGen2--OmniGen2.jpg differ diff --git a/modules/model_omnigen.py b/modules/model_omnigen.py index 7de3a93fe..b2fb95830 100644 --- a/modules/model_omnigen.py +++ b/modules/model_omnigen.py @@ -9,21 +9,6 @@ def load_omnigen(checkpoint_info, diffusers_load_config={}): # pylint: disable=u repo_id = sd_models.path_to_repo(checkpoint_info.name) vae = None - if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic': - try: - debug(f'Load model: type=OmniGen vae="{shared.opts.sd_vae}"') - from modules import sd_vae - # vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override') - vae_file = sd_vae.vae_dict[shared.opts.sd_vae] - if os.path.exists(vae_file): - vae_config = os.path.join('configs', 'sdxl', 'vae', 'config.json') - vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config) - except Exception as e: - shared.log.error(f"Load model: type=OmniGen failed to load VAE: {e}") - shared.opts.sd_vae = 'Default' - if debug: - errors.display(e, 'OmniGen VAE:') - load_config, quant_config = model_quant.get_dit_args(diffusers_load_config, module='Model') transformer = diffusers.OmniGenTransformer2DModel.from_pretrained( repo_id, diff --git a/modules/model_omnigen2.py b/modules/model_omnigen2.py new file mode 100644 index 000000000..285e46f50 --- /dev/null +++ b/modules/model_omnigen2.py @@ -0,0 +1,49 @@ +import os +from modules import shared, devices, sd_models, model_quant + +debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None + + +def load_omnigen2(checkpoint_info, diffusers_load_config={}): # pylint: disable=unused-argument + repo_id = sd_models.path_to_repo(checkpoint_info.name) + + from modules.omnigen2 import OmniGen2Pipeline, OmniGen2Transformer2DModel, Qwen2_5_VLForConditionalGeneration + import diffusers + from diffusers import pipelines + diffusers.OmniGen2Pipeline = OmniGen2Pipeline # monkey-pathch + pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["omnigen2"] = diffusers.OmniGen2Pipeline + pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["omnigen2"] = diffusers.OmniGen2Pipeline + pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["omnigen2"] = diffusers.OmniGen2Pipeline + + load_config, quant_config = model_quant.get_dit_args(diffusers_load_config, module='Model') + transformer = OmniGen2Transformer2DModel.from_pretrained( + repo_id, + subfolder="transformer", + cache_dir=shared.opts.diffusers_dir, + trust_remote_code=True, + **load_config, + **quant_config, + ) + + load_config, quant_config = model_quant.get_dit_args(diffusers_load_config, module='TE') + mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained( + repo_id, + subfolder="mllm", + cache_dir=shared.opts.diffusers_dir, + trust_remote_code=True, + **load_config, + **quant_config, + ) + + pipe = OmniGen2Pipeline.from_pretrained( + repo_id, + # transformer=transformer, + mllm=mllm, + cache_dir=shared.opts.diffusers_dir, + trust_remote_code=True, + **load_config, + ) + pipe.transformer = transformer # for omnigen2 transformer must be loaded after pipeline + + devices.torch_gc(force=True) + return pipe diff --git a/modules/modeldata.py b/modules/modeldata.py index 065bbd09d..789cccfef 100644 --- a/modules/modeldata.py +++ b/modules/modeldata.py @@ -35,6 +35,8 @@ def get_model_type(pipe): model_type = 'lumina2' elif "Lumina" in name: model_type = 'lumina' + elif "OmniGen2" in name: + model_type = 'omnigen2' elif "OmniGen" in name: model_type = 'omnigen' elif "CogView3" in name: diff --git a/modules/omnigen2/__init__.py b/modules/omnigen2/__init__.py new file mode 100644 index 000000000..82d539e60 --- /dev/null +++ b/modules/omnigen2/__init__.py @@ -0,0 +1,3 @@ +from transformers import Qwen2_5_VLForConditionalGeneration +from .pipeline_omnigen2 import OmniGen2Pipeline +from .models.transformers import OmniGen2Transformer2DModel diff --git a/modules/omnigen2/image_processor.py b/modules/omnigen2/image_processor.py new file mode 100644 index 000000000..00b1fb3ed --- /dev/null +++ b/modules/omnigen2/image_processor.py @@ -0,0 +1,265 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist +from diffusers.configuration_utils import register_to_config + +class OmniGen2ImageProcessor(VaeImageProcessor): + """ + Image processor for PixArt image resize and crop. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `False`): + Whether to binarize the image to 0/1. + do_convert_rgb (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to RGB format. + do_convert_grayscale (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to grayscale format. + """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + resample: str = "lanczos", + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_grayscale: bool = False, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + resample=resample, + do_normalize=do_normalize, + do_binarize=do_binarize, + do_convert_grayscale=do_convert_grayscale, + ) + + self.max_pixels = max_pixels + self.max_side_length = max_side_length + + def get_new_height_width( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + ) -> Tuple[int, int]: + r""" + Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`. + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it + should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch + tensor, it should have shape `[batch, channels, height, width]`. + height (`Optional[int]`, *optional*, defaults to `None`): + The height of the preprocessed image. If `None`, the height of the `image` input will be used. + width (`Optional[int]`, *optional*, defaults to `None`): + The width of the preprocessed image. If `None`, the width of the `image` input will be used. + + Returns: + `Tuple[int, int]`: + A tuple containing the height and width, both resized to the nearest integer multiple of + `vae_scale_factor`. + """ + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + else: + height = image.shape[1] + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + else: + width = image.shape[2] + + if max_side_length is None: + max_side_length = self.max_side_length + + if max_pixels is None: + max_pixels = self.max_pixels + + ratio = 1.0 + if max_side_length is not None: + if height > width: + max_side_length_ratio = max_side_length / height + else: + max_side_length_ratio = max_side_length / width + + cur_pixels = height * width + max_pixels_ratio = (max_pixels / cur_pixels) ** 0.5 + ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0) # do not upscale input image + + new_height, new_width = int(height * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor, int(width * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor + return new_height, new_width + + def preprocess( + self, + image: PipelineImageInput, + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + resize_mode: str = "default", # "default", "fill", "crop" + crops_coords: Optional[Tuple[int, int, int, int]] = None, + ) -> torch.Tensor: + """ + Preprocess the image input. + + Args: + image (`PipelineImageInput`): + The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of + supported formats. + height (`int`, *optional*): + The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default + height. + width (`int`, *optional*): + The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. + resize_mode (`str`, *optional*, defaults to `default`): + The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within + the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will + resize the image to fit within the specified width and height, maintaining the aspect ratio, and then + center the image within the dimensions, filling empty with data from image. If `crop`, will resize the + image to fit within the specified width and height, maintaining the aspect ratio, and then center the + image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only + supported for PIL image input. + crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): + The crop coordinates for each image in the batch. If `None`, will not crop the image. + + Returns: + `torch.Tensor`: + The preprocessed image. + """ + supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) + + # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image + if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3: + if isinstance(image, torch.Tensor): + # if image is a pytorch tensor could have 2 possible shapes: + # 1. batch x height x width: we should insert the channel dimension at position 1 + # 2. channel x height x width: we should insert batch dimension at position 0, + # however, since both channel and batch dimension has same size 1, it is same to insert at position 1 + # for simplicity, we insert a dimension of size 1 at position 1 for both cases + image = image.unsqueeze(1) + else: + # if it is a numpy array, it could have 2 possible shapes: + # 1. batch x height x width: insert channel dimension on last position + # 2. height x width x channel: insert batch dimension on first position + if image.shape[-1] == 1: + image = np.expand_dims(image, axis=0) + else: + image = np.expand_dims(image, axis=-1) + + if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray", + FutureWarning, + ) + image = np.concatenate(image, axis=0) + if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor", + FutureWarning, + ) + image = torch.cat(image, axis=0) + + if not is_valid_image_imagelist(image): + raise ValueError( + f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}" + ) + if not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + if crops_coords is not None: + image = [i.crop(crops_coords) for i in image] + if self.config.do_resize: + height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length) + image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image] + if self.config.do_convert_rgb: + image = [self.convert_to_rgb(i) for i in image] + elif self.config.do_convert_grayscale: + image = [self.convert_to_grayscale(i) for i in image] + image = self.pil_to_numpy(image) # to np + image = self.numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + + image = self.numpy_to_pt(image) + + height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length) + if self.config.do_resize: + image = self.resize(image, height, width) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + if self.config.do_convert_grayscale and image.ndim == 3: + image = image.unsqueeze(1) + + channel = image.shape[1] + # don't need any preprocess if the image is latents + if channel == self.config.vae_latent_channels: + return image + + height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length) + if self.config.do_resize: + image = self.resize(image, height, width) + + # expected range [0,1], normalize to [-1,1] + do_normalize = self.config.do_normalize + if do_normalize and image.min() < 0: + warnings.warn( + "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " + f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", + FutureWarning, + ) + do_normalize = False + if do_normalize: + image = self.normalize(image) + + if self.config.do_binarize: + image = self.binarize(image) + + return image diff --git a/modules/omnigen2/import_utils.py b/modules/omnigen2/import_utils.py new file mode 100644 index 000000000..148b88684 --- /dev/null +++ b/modules/omnigen2/import_utils.py @@ -0,0 +1,46 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Import utilities: Utilities related to imports and our lazy inits. +""" + +import importlib.util +import sys + +# The package importlib_metadata is in a different place, depending on the python version. +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + +def _is_package_available(pkg_name: str): + pkg_exists = importlib.util.find_spec(pkg_name) is not None + pkg_version = "N/A" + + if pkg_exists: + try: + pkg_version = importlib_metadata.version(pkg_name) + except (ImportError, importlib_metadata.PackageNotFoundError): + pkg_exists = False + + return pkg_exists, pkg_version + +_triton_available, _triton_version = _is_package_available("triton") +_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") + +def is_triton_available(): + return _triton_available + +def is_flash_attn_available(): + return _flash_attn_available diff --git a/modules/omnigen2/models/__init__.py b/modules/omnigen2/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/omnigen2/models/attention_processor.py b/modules/omnigen2/models/attention_processor.py new file mode 100644 index 000000000..592d6b503 --- /dev/null +++ b/modules/omnigen2/models/attention_processor.py @@ -0,0 +1,357 @@ +""" +OmniGen2 Attention Processor Module + +Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import warnings +import math +from typing import Optional, Tuple, Dict, Any + +import torch +import torch.nn.functional as F +from einops import repeat + +from ..import_utils import is_flash_attn_available + +if is_flash_attn_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +else: + warnings.warn("Cannot import flash_attn, install flash_attn to use Flash2Varlen attention for better performance") + + +from diffusers.models.attention_processor import Attention +from .embeddings import apply_rotary_emb + + +class OmniGen2AttnProcessorFlash2Varlen: + """ + Processor for implementing scaled dot-product attention with flash attention and variable length sequences. + + This processor implements: + - Flash attention with variable length sequences + - Rotary position embeddings (RoPE) + - Query-Key normalization + - Proportional attention scaling + + Args: + None + """ + + def __init__(self) -> None: + """Initialize the attention processor.""" + if not is_flash_attn_available(): + raise ImportError( + "OmniGen2AttnProcessorFlash2Varlen requires flash_attn. " + "Please install flash_attn." + ) + + def _upad_input( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + num_heads: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]: + """ + Unpad the input tensors for flash attention. + + Args: + query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim) + key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim) + value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim) + attention_mask: Attention mask tensor of shape (batch_size, seq_len) + query_length: Length of the query sequence + num_heads: Number of attention heads + + Returns: + Tuple containing: + - Unpadded query tensor + - Unpadded key tensor + - Unpadded value tensor + - Query indices + - Tuple of cumulative sequence lengths for query and key + - Tuple of maximum sequence lengths for query and key + """ + def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: + """Helper function to get unpadding data from attention mask.""" + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return indices, cu_seqlens, max_seqlen_in_batch + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + # Unpad key and value layers + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + + # Handle different query length cases + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Process attention computation with flash attention. + + Args: + attn: Attention module + hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim) + encoder_hidden_states: Encoder hidden states tensor + attention_mask: Optional attention mask tensor + image_rotary_emb: Optional rotary embeddings for image tokens + base_sequence_length: Optional base sequence length for proportional attention + + Returns: + torch.Tensor: Processed hidden states after attention computation + """ + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + # Reshape tensors for attention computation + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply Query-Key normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply Rotary Position Embeddings + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, use_real=False) + key = apply_rotary_emb(key, image_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Calculate attention scale + if base_sequence_length is not None: + softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + else: + softmax_scale = attn.scale + + # Unpad input for flash attention + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + # Handle different number of heads + if kv_heads < attn.heads: + key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads) + value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads) + + # Apply flash attention + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=0.0, + causal=False, + softmax_scale=softmax_scale, + ) + + # Pad output and apply final transformations + hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length) + hidden_states = hidden_states.flatten(-2) + hidden_states = hidden_states.type_as(query) + + # Apply output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class OmniGen2AttnProcessor: + """ + Processor for implementing scaled dot-product attention with flash attention and variable length sequences. + + This processor is optimized for PyTorch 2.0 and implements: + - Flash attention with variable length sequences + - Rotary position embeddings (RoPE) + - Query-Key normalization + - Proportional attention scaling + + Args: + None + + Raises: + ImportError: If PyTorch version is less than 2.0 + """ + + def __init__(self) -> None: + """Initialize the attention processor.""" + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. " + "Please upgrade PyTorch to version 2.0 or later." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Process attention computation with flash attention. + + Args: + attn: Attention module + hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim) + encoder_hidden_states: Encoder hidden states tensor + attention_mask: Optional attention mask tensor + image_rotary_emb: Optional rotary embeddings for image tokens + base_sequence_length: Optional base sequence length for proportional attention + + Returns: + torch.Tensor: Processed hidden states after attention computation + """ + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + # Reshape tensors for attention computation + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply Query-Key normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply Rotary Position Embeddings + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, use_real=False) + key = apply_rotary_emb(key, image_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Calculate attention scale + if base_sequence_length is not None: + softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + else: + softmax_scale = attn.scale + + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + if attention_mask is not None: + attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # explicitly repeat key and value to match query length, otherwise using enable_gqa=True results in MATH backend of sdpa in our test of pytorch2.6 + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, scale=softmax_scale + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.type_as(query) + + # Apply output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states diff --git a/modules/omnigen2/models/embeddings.py b/modules/omnigen2/models/embeddings.py new file mode 100644 index 000000000..047526f69 --- /dev/null +++ b/modules/omnigen2/models/embeddings.py @@ -0,0 +1,126 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + + +from diffusers.models.activations import get_activation + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + self.initialize_weights() + + def initialize_weights(self): + nn.init.normal_(self.linear_1.weight, std=0.02) + nn.init.zeros_(self.linear_1.bias) + nn.init.normal_(self.linear_2.weight, std=0.02) + nn.init.zeros_(self.linear_2.bias) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen and CogView4 + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) diff --git a/modules/omnigen2/models/transformers/__init__.py b/modules/omnigen2/models/transformers/__init__.py new file mode 100644 index 000000000..b2a23df90 --- /dev/null +++ b/modules/omnigen2/models/transformers/__init__.py @@ -0,0 +1,3 @@ +from .transformer_omnigen2 import OmniGen2Transformer2DModel + +__all__ = ["OmniGen2Transformer2DModel"] diff --git a/modules/omnigen2/models/transformers/block_lumina2.py b/modules/omnigen2/models/transformers/block_lumina2.py new file mode 100644 index 000000000..e6d36e0f1 --- /dev/null +++ b/modules/omnigen2/models/transformers/block_lumina2.py @@ -0,0 +1,217 @@ + +# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from diffusers.models.embeddings import Timesteps +from ..embeddings import TimestepEmbedding +from ...import_utils import is_flash_attn_available, is_triton_available + +if is_triton_available(): + from ...triton_layer_norm import RMSNorm +else: + from torch.nn import RMSNorm + warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance") + +if is_flash_attn_available(): + from flash_attn.ops.activations import swiglu +else: + from .components import swiglu + warnings.warn("Cannot import flash_attn, install flash_attn to use fused SwiGLU for better performance") + +# try: +# from flash_attn.ops.activations import swiglu as fused_swiglu +# FUSEDSWIGLU_AVALIBLE = True +# except ImportError: + +# FUSEDSWIGLU_AVALIBLE = False +# warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") + +class LuminaRMSNormZero(nn.Module): + """ + Norm layer adaptive RMS normalization zero. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + def __init__( + self, + embedding_dim: int, + norm_eps: float, + norm_elementwise_affine: bool, + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear( + min(embedding_dim, 1024), + 4 * embedding_dim, + bias=True, + ) + + self.norm = RMSNorm(embedding_dim, eps=norm_eps) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + return x, gate_msa, scale_mlp, gate_mlp + + +class LuminaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + out_dim: Optional[int] = None, + ): + super().__init__() + + # AdaLN + self.silu = nn.SiLU() + self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + self.linear_2 = None + if out_dim is not None: + self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias) + + def forward( + self, + x: torch.Tensor, + conditioning_embedding: torch.Tensor, + ) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) + scale = emb + x = self.norm(x) * (1 + scale)[:, None, :] + + if self.linear_2 is not None: + x = self.linear_2(x) + + return x + + +class LuminaFeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + hidden_size (`int`): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + intermediate_size (`int`): The intermediate dimension of the feedforward layer. + multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden + dimension. Defaults to None. + """ + + def __init__( + self, + dim: int, + inner_dim: int, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + self.swiglu = swiglu + + # custom hidden_size factor multiplier + if ffn_dim_multiplier is not None: + inner_dim = int(ffn_dim_multiplier * inner_dim) + inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) + + self.linear_1 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.linear_2 = nn.Linear( + inner_dim, + dim, + bias=False, + ) + self.linear_3 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + + def forward(self, x): + h1, h2 = self.linear_1(x), self.linear_3(x) + return self.linear_2(self.swiglu(h1, h2)) + + +class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): + def __init__( + self, + hidden_size: int = 4096, + text_feat_dim: int = 2048, + frequency_embedding_size: int = 256, + norm_eps: float = 1e-5, + timestep_scale: float = 1.0, + ) -> None: + super().__init__() + + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale + ) + + self.timestep_embedder = TimestepEmbedding( + in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) + ) + + self.caption_embedder = nn.Sequential( + RMSNorm(text_feat_dim, eps=norm_eps), + nn.Linear(text_feat_dim, hidden_size, bias=True), + ) + + self._initialize_weights() + + def _initialize_weights(self): + nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02) + nn.init.zeros_(self.caption_embedder[1].bias) + + def forward( + self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + timestep_proj = self.time_proj(timestep).to(dtype=dtype) + time_embed = self.timestep_embedder(timestep_proj) + caption_embed = self.caption_embedder(text_hidden_states) + return time_embed, caption_embed diff --git a/modules/omnigen2/models/transformers/components.py b/modules/omnigen2/models/transformers/components.py new file mode 100644 index 000000000..05dd6f5f3 --- /dev/null +++ b/modules/omnigen2/models/transformers/components.py @@ -0,0 +1,4 @@ +import torch.nn.functional as F + +def swiglu(x, y): + return F.silu(x.float(), inplace=False).to(x.dtype) * y diff --git a/modules/omnigen2/models/transformers/repo.py b/modules/omnigen2/models/transformers/repo.py new file mode 100644 index 000000000..8f7c47566 --- /dev/null +++ b/modules/omnigen2/models/transformers/repo.py @@ -0,0 +1,129 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn + +from einops import repeat +from diffusers.models.embeddings import get_1d_rotary_pos_embed + +class OmniGen2RotaryPosEmbed(nn.Module): + def __init__(self, theta: int, + axes_dim: Tuple[int, int, int], + axes_lens: Tuple[int, int, int] = (300, 512, 512), + patch_size: int = 2): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + self.axes_lens = axes_lens + self.patch_size = patch_size + + @staticmethod + def get_freqs_cis(axes_dim: Tuple[int, int, int], + axes_lens: Tuple[int, int, int], + theta: int) -> List[torch.Tensor]: + freqs_cis = [] + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): + emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype) + freqs_cis.append(emb) + return freqs_cis + + def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor: + device = ids.device + if ids.device.type == "mps": + ids = ids.to("cpu") + + result = [] + for i in range(len(self.axes_dim)): + freqs = freqs_cis[i].to(ids.device) + index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) + result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) + return torch.cat(result, dim=-1).to(device) + + def forward( + self, + freqs_cis, + attention_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + device + ): + batch_size = len(attention_mask) + p = self.patch_size + + encoder_seq_len = attention_mask.shape[1] + l_effective_cap_len = attention_mask.sum(dim=1).tolist() + + seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)] + + max_seq_len = max(seq_lengths) + max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]) + max_img_len = max(l_effective_img_len) + + # Create position IDs + position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) + + for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + # add text position ids + position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3") + + pe_shift = cap_seq_len + pe_shift_len = cap_seq_len + + if ref_img_sizes[i] is not None: + for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]): + H, W = ref_img_size + ref_H_tokens, ref_W_tokens = H // p, W // p + assert ref_H_tokens * ref_W_tokens == ref_img_len + # add image position ids + + row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten() + col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten() + position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift + position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids + position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids + + pe_shift += max(ref_H_tokens, ref_W_tokens) + pe_shift_len += ref_img_len + + H, W = img_sizes[i] + H_tokens, W_tokens = H // p, W // p + assert H_tokens * W_tokens == l_effective_img_len[i] + + row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten() + col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten() + + assert pe_shift_len + l_effective_img_len[i] == seq_len + position_ids[i, pe_shift_len: seq_len, 0] = pe_shift + position_ids[i, pe_shift_len: seq_len, 1] = row_ids + position_ids[i, pe_shift_len: seq_len, 2] = col_ids + + # Get combined rotary embeddings + freqs_cis = self._get_freqs_cis(freqs_cis, position_ids) + + # create separate rotary embeddings for captions and images + cap_freqs_cis = torch.zeros( + batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype + ) + ref_img_freqs_cis = torch.zeros( + batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype + ) + img_freqs_cis = torch.zeros( + batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype + ) + + for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)): + cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len] + ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)] + img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len] + + return ( + cap_freqs_cis, + ref_img_freqs_cis, + img_freqs_cis, + freqs_cis, + l_effective_cap_len, + seq_lengths, + ) diff --git a/modules/omnigen2/models/transformers/transformer_omnigen2.py b/modules/omnigen2/models/transformers/transformer_omnigen2.py new file mode 100644 index 000000000..fe324b0d2 --- /dev/null +++ b/modules/omnigen2/models/transformers/transformer_omnigen2.py @@ -0,0 +1,617 @@ +import warnings +import itertools +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from einops import rearrange + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from diffusers.models.attention_processor import Attention +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin + +from ..attention_processor import OmniGen2AttnProcessorFlash2Varlen, OmniGen2AttnProcessor +from .repo import OmniGen2RotaryPosEmbed +from .block_lumina2 import LuminaLayerNormContinuous, LuminaRMSNormZero, LuminaFeedForward, Lumina2CombinedTimestepCaptionEmbedding + +from ...import_utils import is_triton_available, is_flash_attn_available + +if is_triton_available(): + from ...triton_layer_norm import RMSNorm +else: + from torch.nn import RMSNorm + +logger = logging.get_logger(__name__) + + +class OmniGen2TransformerBlock(nn.Module): + """ + Transformer block for OmniGen2 model. + + This block implements a transformer layer with: + - Multi-head attention with flash attention + - Feed-forward network with SwiGLU activation + - RMS normalization + - Optional modulation for conditional generation + + Args: + dim: Dimension of the input and output tensors + num_attention_heads: Number of attention heads + num_kv_heads: Number of key-value heads + multiple_of: Multiple of which the hidden dimension should be + ffn_dim_multiplier: Multiplier for the feed-forward network dimension + norm_eps: Epsilon value for normalization layers + modulation: Whether to use modulation for conditional generation + use_fused_rms_norm: Whether to use fused RMS normalization + use_fused_swiglu: Whether to use fused SwiGLU activation + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + modulation: bool = True, + ) -> None: + """Initialize the transformer block.""" + super().__init__() + self.head_dim = dim // num_attention_heads + self.modulation = modulation + + try: + processor = OmniGen2AttnProcessorFlash2Varlen() + except ImportError: + processor = OmniGen2AttnProcessor() + + # Initialize attention layer + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + qk_norm="rms_norm", + heads=num_attention_heads, + kv_heads=num_kv_heads, + eps=1e-5, + bias=False, + out_bias=False, + processor=processor, + ) + + # Initialize feed-forward network + self.feed_forward = LuminaFeedForward( + dim=dim, + inner_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier + ) + + # Initialize normalization layers + if modulation: + self.norm1 = LuminaRMSNormZero( + embedding_dim=dim, + norm_eps=norm_eps, + norm_elementwise_affine=True + ) + else: + self.norm1 = RMSNorm(dim, eps=norm_eps) + + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + self.norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.initialize_weights() + + def initialize_weights(self) -> None: + """ + Initialize the weights of the transformer block. + + Uses Xavier uniform initialization for linear layers and zero initialization for biases. + """ + nn.init.xavier_uniform_(self.attn.to_q.weight) + nn.init.xavier_uniform_(self.attn.to_k.weight) + nn.init.xavier_uniform_(self.attn.to_v.weight) + nn.init.xavier_uniform_(self.attn.to_out[0].weight) + + nn.init.xavier_uniform_(self.feed_forward.linear_1.weight) + nn.init.xavier_uniform_(self.feed_forward.linear_2.weight) + nn.init.xavier_uniform_(self.feed_forward.linear_3.weight) + + if self.modulation: + nn.init.zeros_(self.norm1.linear.weight) + nn.init.zeros_(self.norm1.linear.bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + image_rotary_emb: torch.Tensor, + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of the transformer block. + + Args: + hidden_states: Input hidden states tensor + attention_mask: Attention mask tensor + image_rotary_emb: Rotary embeddings for image tokens + temb: Optional timestep embedding tensor + + Returns: + torch.Tensor: Output hidden states after transformer block processing + """ + import time + if self.modulation: + if temb is None: + raise ValueError("temb must be provided when modulation is enabled") + + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) + hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) + else: + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + self.norm2(attn_output) + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) + hidden_states = hidden_states + self.ffn_norm2(mlp_output) + + return hidden_states + + +class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + OmniGen2 Transformer 2D Model. + + A transformer-based diffusion model for image generation with: + - Patch-based image processing + - Rotary position embeddings + - Multi-head attention + - Conditional generation support + + Args: + patch_size: Size of image patches + in_channels: Number of input channels + out_channels: Number of output channels (defaults to in_channels) + hidden_size: Size of hidden layers + num_layers: Number of transformer layers + num_refiner_layers: Number of refiner layers + num_attention_heads: Number of attention heads + num_kv_heads: Number of key-value heads + multiple_of: Multiple of which the hidden dimension should be + ffn_dim_multiplier: Multiplier for feed-forward network dimension + norm_eps: Epsilon value for normalization layers + axes_dim_rope: Dimensions for rotary position embeddings + axes_lens: Lengths for rotary position embeddings + text_feat_dim: Dimension of text features + timestep_scale: Scale factor for timestep embeddings + use_fused_rms_norm: Whether to use fused RMS normalization + use_fused_swiglu: Whether to use fused SwiGLU activation + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["Omnigen2TransformerBlock"] + _skip_layerwise_casting_patterns = ["x_embedder", "norm"] + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + out_channels: Optional[int] = None, + hidden_size: int = 2304, + num_layers: int = 26, + num_refiner_layers: int = 2, + num_attention_heads: int = 24, + num_kv_heads: int = 8, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + axes_dim_rope: Tuple[int, int, int] = (32, 32, 32), + axes_lens: Tuple[int, int, int] = (300, 512, 512), + text_feat_dim: int = 1024, + timestep_scale: float = 1.0 + ) -> None: + """Initialize the OmniGen2 transformer model.""" + super().__init__() + + # Validate configuration + if (hidden_size // num_attention_heads) != sum(axes_dim_rope): + raise ValueError( + f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) " + f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})" + ) + + self.out_channels = out_channels or in_channels + + # Initialize embeddings + self.rope_embedder = OmniGen2RotaryPosEmbed( + theta=10000, + axes_dim=axes_dim_rope, + axes_lens=axes_lens, + patch_size=patch_size, + ) + + self.x_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=hidden_size, + ) + + self.ref_image_patch_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=hidden_size, + ) + + self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( + hidden_size=hidden_size, + text_feat_dim=text_feat_dim, + norm_eps=norm_eps, + timestep_scale=timestep_scale + ) + + # Initialize transformer blocks + self.noise_refiner = nn.ModuleList([ + OmniGen2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True + ) + for _ in range(num_refiner_layers) + ]) + + self.ref_image_refiner = nn.ModuleList([ + OmniGen2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True + ) + for _ in range(num_refiner_layers) + ]) + + self.context_refiner = nn.ModuleList( + [ + OmniGen2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=False + ) + for _ in range(num_refiner_layers) + ] + ) + + # 3. Transformer blocks + self.layers = nn.ModuleList( + [ + OmniGen2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True + ) + for _ in range(num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + bias=True, + out_dim=patch_size * patch_size * self.out_channels + ) + + # Add learnable embeddings to distinguish different images + self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images + + self.gradient_checkpointing = False + + self.initialize_weights() + + def initialize_weights(self) -> None: + """ + Initialize the weights of the model. + + Uses Xavier uniform initialization for linear layers. + """ + nn.init.xavier_uniform_(self.x_embedder.weight) + nn.init.constant_(self.x_embedder.bias, 0.0) + + nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight) + nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0) + + nn.init.zeros_(self.norm_out.linear_1.weight) + nn.init.zeros_(self.norm_out.linear_1.bias) + nn.init.zeros_(self.norm_out.linear_2.weight) + nn.init.zeros_(self.norm_out.linear_2.bias) + + nn.init.normal_(self.image_index_embedding, std=0.02) + + def img_patch_embed_and_refine( + self, + hidden_states, + ref_image_hidden_states, + padded_img_mask, + padded_ref_img_mask, + noise_rotary_emb, + ref_img_rotary_emb, + l_effective_ref_img_len, + l_effective_img_len, + temb + ): + batch_size = len(hidden_states) + max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)]) + + hidden_states = self.x_embedder(hidden_states) + ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states) + + for i in range(batch_size): + shift = 0 + for j, ref_img_len in enumerate(l_effective_ref_img_len[i]): + ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j] + shift += ref_img_len + + for layer in self.noise_refiner: + hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb) + + flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len)) + num_ref_images = len(flat_l_effective_ref_img_len) + max_ref_img_len = max(flat_l_effective_ref_img_len) + + batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool) + batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size) + batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype) + batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype) + + # sequence of ref imgs to batch + idx = 0 + for i in range(batch_size): + shift = 0 + for ref_img_len in l_effective_ref_img_len[i]: + batch_ref_img_mask[idx, :ref_img_len] = True + batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len] + batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len] + batch_temb[idx] = temb[i] + shift += ref_img_len + idx += 1 + + # refine ref imgs separately + for layer in self.ref_image_refiner: + batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb) + + # batch of ref imgs to sequence + idx = 0 + for i in range(batch_size): + shift = 0 + for ref_img_len in l_effective_ref_img_len[i]: + ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len] + shift += ref_img_len + idx += 1 + + combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size) + for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)): + combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)] + combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len] + + return combined_img_hidden_states + + def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states): + batch_size = len(hidden_states) + p = self.config.patch_size + device = hidden_states[0].device + + img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] + l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes] + + if ref_image_hidden_states is not None: + ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states] + l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes] + else: + ref_img_sizes = [None for _ in range(batch_size)] + l_effective_ref_img_len = [[0] for _ in range(batch_size)] + + max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]) + max_img_len = max(l_effective_img_len) + + # ref image patch embeddings + flat_ref_img_hidden_states = [] + for i in range(batch_size): + if ref_img_sizes[i] is not None: + imgs = [] + for ref_img in ref_image_hidden_states[i]: + C, H, W = ref_img.size() + ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p) + imgs.append(ref_img) + + img = torch.cat(imgs, dim=0) + flat_ref_img_hidden_states.append(img) + else: + flat_ref_img_hidden_states.append(None) + + # image patch embeddings + flat_hidden_states = [] + for i in range(batch_size): + img = hidden_states[i] + C, H, W = img.size() + + img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p) + flat_hidden_states.append(img) + + padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype) + padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device) + for i in range(batch_size): + if ref_img_sizes[i] is not None: + padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i] + padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True + + padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype) + padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device) + for i in range(batch_size): + padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i] + padded_img_mask[i, :l_effective_img_len[i]] = True + + return ( + padded_hidden_states, + padded_ref_img_hidden_states, + padded_img_mask, + padded_ref_img_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + ) + + def forward( + self, + hidden_states: Union[torch.Tensor, List[torch.Tensor]], + timestep: torch.Tensor, + text_hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + text_attention_mask: torch.Tensor, + ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # 1. Condition, positional & patch embedding + batch_size = len(hidden_states) + is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor) + + if is_hidden_states_tensor: + assert hidden_states.ndim == 4 + hidden_states = [_hidden_states for _hidden_states in hidden_states] + + device = hidden_states[0].device + + temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype) + + ( + hidden_states, + ref_image_hidden_states, + img_mask, + ref_img_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states) + + ( + context_rotary_emb, + ref_img_rotary_emb, + noise_rotary_emb, + rotary_emb, + encoder_seq_lengths, + seq_lengths, + ) = self.rope_embedder( + freqs_cis, + text_attention_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + device, + ) + + # 2. Context refinement + for layer in self.context_refiner: + text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb) + + combined_img_hidden_states = self.img_patch_embed_and_refine( + hidden_states, + ref_image_hidden_states, + img_mask, + ref_img_mask, + noise_rotary_emb, + ref_img_rotary_emb, + l_effective_ref_img_len, + l_effective_img_len, + temb, + ) + + # 3. Joint Transformer blocks + max_seq_len = max(seq_lengths) + + attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) + joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + attention_mask[i, :seq_len] = True + joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len] + joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len] + + hidden_states = joint_hidden_states + + for layer_idx, layer in enumerate(self.layers): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + layer, hidden_states, attention_mask, rotary_emb, temb + ) + else: + hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) + + # 4. Output norm & projection + hidden_states = self.norm_out(hidden_states, temb) + + p = self.config.patch_size + output = [] + for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)): + height, width = img_size + output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p)) + if is_hidden_states_tensor: + output = torch.stack(output, dim=0) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return output + return Transformer2DModelOutput(sample=output) diff --git a/modules/omnigen2/pipeline_omnigen2.py b/modules/omnigen2/pipeline_omnigen2.py new file mode 100644 index 000000000..a7d6c2ca3 --- /dev/null +++ b/modules/omnigen2/pipeline_omnigen2.py @@ -0,0 +1,718 @@ +""" +OmniGen2 Diffusion Pipeline + +Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Any, Dict, List, Optional, Tuple, Union +from dataclasses import dataclass +import inspect +import numpy as np +import torch +import torch.nn.functional as F +import PIL.Image + +from transformers import Qwen2_5_VLForConditionalGeneration +from diffusers.utils import BaseOutput +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + is_torch_xla_available, + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from .models.transformers import OmniGen2Transformer2DModel +from .models.transformers.repo import OmniGen2RotaryPosEmbed +from .image_processor import OmniGen2ImageProcessor + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +@dataclass +class FMPipelineOutput(BaseOutput): + """ + Output class for OmniGen2 pipeline. + + Args: + images (Union[List[PIL.Image.Image], np.ndarray]): + List of denoised PIL images of length `batch_size` or numpy array of shape + `(batch_size, height, width, num_channels)`. Contains the generated images. + """ + images: Union[List[PIL.Image.Image], np.ndarray] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class OmniGen2Pipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using OmniGen2. + + This pipeline implements a text-to-image generation model that uses: + - Qwen2.5-VL for text encoding + - A custom transformer architecture for image generation + - VAE for image encoding/decoding + - FlowMatchEulerDiscreteScheduler for noise scheduling + + Args: + transformer (OmniGen2Transformer2DModel): The transformer model for image generation. + vae (AutoencoderKL): The VAE model for image encoding/decoding. + scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling. + text_encoder (Qwen2_5_VLModel): The text encoder model. + tokenizer (Union[Qwen2Tokenizer, Qwen2TokenizerFast]): The tokenizer for text processing. + """ + + model_cpu_offload_seq = "mllm->transformer->vae" + + def __init__( + self, + transformer: OmniGen2Transformer2DModel, + vae: AutoencoderKL, + scheduler: FlowMatchEulerDiscreteScheduler, + mllm: Qwen2_5_VLForConditionalGeneration, + processor, + ) -> None: + """ + Initialize the OmniGen2 pipeline. + + Args: + transformer: The transformer model for image generation. + vae: The VAE model for image encoding/decoding. + scheduler: The scheduler for noise scheduling. + text_encoder: The text encoder model. + tokenizer: The tokenizer for text processing. + """ + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + scheduler=scheduler, + mllm=mllm, + processor=processor + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True) + self.default_sample_size = 128 + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator], + latents: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Prepare the initial latents for the diffusion process. + + Args: + batch_size: The number of images to generate. + num_channels_latents: The number of channels in the latent space. + height: The height of the generated image. + width: The width of the generated image. + dtype: The data type of the latents. + device: The device to place the latents on. + generator: The random number generator to use. + latents: Optional pre-computed latents to use instead of random initialization. + + Returns: + torch.FloatTensor: The prepared latents tensor. + """ + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + return latents + + def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor: + """ + Encode an image into the VAE latent space. + + Args: + img: The input image tensor to encode. + + Returns: + torch.FloatTensor: The encoded latent representation. + """ + z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample() + if self.vae.config.shift_factor is not None: + z0 = z0 - self.vae.config.shift_factor + if self.vae.config.scaling_factor is not None: + z0 = z0 * self.vae.config.scaling_factor + z0 = z0.to(dtype=self.vae.dtype) + return z0 + + def prepare_image( + self, + images: Union[List[PIL.Image.Image], PIL.Image.Image], + batch_size: int, + num_images_per_prompt: int, + max_pixels: int, + max_side_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> List[Optional[torch.FloatTensor]]: + """ + Prepare input images for processing by encoding them into the VAE latent space. + + Args: + images: Single image or list of images to process. + batch_size: The number of images to generate per prompt. + num_images_per_prompt: The number of images to generate for each prompt. + device: The device to place the encoded latents on. + dtype: The data type of the encoded latents. + + Returns: + List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image. + """ + if batch_size == 1: + images = [images] + latents = [] + for i, img in enumerate(images): + if img is not None and len(img) > 0: + ref_latents = [] + for j, img_j in enumerate(img): + img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length) + ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0)) + else: + ref_latents = None + for _ in range(num_images_per_prompt): + latents.append(ref_latents) + + return latents + + def _get_qwen2_prompt_embeds( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + max_sequence_length: int = 256, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get prompt embeddings from the Qwen2 text encoder. + + Args: + prompt: The prompt or list of prompts to encode. + device: The device to place the embeddings on. If None, uses the pipeline's device. + max_sequence_length: Maximum sequence length for tokenization. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The prompt embeddings tensor + - The attention mask tensor + + Raises: + Warning: If the input text is truncated due to sequence length limitations. + """ + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + # text_inputs = self.processor.tokenizer( + # prompt, + # padding="max_length", + # max_length=max_sequence_length, + # truncation=True, + # return_tensors="pt", + # ) + text_inputs = self.processor.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because Gemma can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.mllm( + text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-1] + + if self.mllm is not None: + dtype = self.mllm.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def _apply_chat_template(self, prompt: str): + prompt = [ + { + "role": "system", + "content": "You are a helpful assistant that generates high-quality images based on user instructions.", + }, + {"role": "user", "content": prompt}, + ] + prompt = self.processor.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False) + return prompt + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Lumina-T2I, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. + max_sequence_length (`int`, defaults to `256`): + Maximum sequence length to use for the prompt. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [self._apply_chat_template(_prompt) for _prompt in prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds( + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length + ) + + batch_size, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) + + # Get negative embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else "" + + # Normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt] + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds( + prompt=negative_prompt, + device=device, + max_sequence_length=max_sequence_length, + ) + + batch_size, seq_len, _ = negative_prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + batch_size * num_images_per_prompt, -1 + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def text_guidance_scale(self): + return self._text_guidance_scale + + @property + def image_guidance_scale(self): + return self._image_guidance_scale + + @property + def cfg_range(self): + return self._cfg_range + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.LongTensor] = None, + negative_prompt_attention_mask: Optional[torch.LongTensor] = None, + max_sequence_length: Optional[int] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + input_images: Optional[List[PIL.Image.Image]] = None, + num_images_per_prompt: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: int = 2048 * 2048, + max_input_image_side_length: int = 2048, + align_res: bool = True, + num_inference_steps: int = 28, + text_guidance_scale: float = 4.0, + image_guidance_scale: float = 1.0, + cfg_range: Tuple[float, float] = (0.0, 1.0), + attention_kwargs: Optional[Dict[str, Any]] = None, + timesteps: List[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + verbose: bool = False, + step_func=None, + ): + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + self._text_guidance_scale = text_guidance_scale + self._image_guidance_scale = image_guidance_scale + self._cfg_range = cfg_range + self._attention_kwargs = attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.text_guidance_scale > 1.0, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + ) + + dtype = self.vae.dtype + # 3. Prepare control image + ref_latents = self.prepare_image( + images=input_images, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + max_pixels=max_pixels, + max_side_length=max_input_image_side_length, + device=device, + dtype=dtype, + ) + + if input_images is None: + input_images = [] + + if len(input_images) == 1 and align_res: + width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor + ori_width, ori_height = width, height + else: + ori_width, ori_height = width, height + + cur_pixels = height * width + ratio = (max_pixels / cur_pixels) ** 0.5 + ratio = min(ratio, 1.0) + + height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16 + + if len(input_images) == 0: + self._image_guidance_scale = 1 + + # 4. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis( + self.transformer.config.axes_dim_rope, + self.transformer.config.axes_lens, + theta=10000, + ) + + image = self.processing( + latents=latents, + ref_latents=ref_latents, + prompt_embeds=prompt_embeds, + freqs_cis=freqs_cis, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + device=device, + dtype=dtype, + verbose=verbose, + step_func=step_func, + ) + + image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear') + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return image + else: + return FMPipelineOutput(images=image) + + def processing( + self, + latents, + ref_latents, + prompt_embeds, + freqs_cis, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + num_inference_steps, + timesteps, + device, + dtype, + verbose, + step_func=None + ): + batch_size = latents.shape[0] + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + num_tokens=latents.shape[-2] * latents.shape[-1] + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_pred = self.predict( + t=t, + latents=latents, + prompt_embeds=prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=prompt_attention_mask, + ref_image_hidden_states=ref_latents, + ) + text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + + if text_guidance_scale > 1.0 and image_guidance_scale > 1.0: + model_pred_ref = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=ref_latents, + ) + + if image_guidance_scale != 1: + model_pred_uncond = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=None, + ) + else: + model_pred_uncond = torch.zeros_like(model_pred) + + model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \ + text_guidance_scale * (model_pred - model_pred_ref) + elif text_guidance_scale > 1.0: + model_pred_uncond = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=None, + ) + model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond) + + latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] + + latents = latents.to(dtype=dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if step_func is not None: + step_func(i, self._num_timesteps) + + latents = latents.to(dtype=dtype) + if self.vae.config.scaling_factor is not None: + latents = latents / self.vae.config.scaling_factor + if self.vae.config.shift_factor is not None: + latents = latents + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + + return image + + def predict( + self, + t, + latents, + prompt_embeds, + freqs_cis, + prompt_attention_mask, + ref_image_hidden_states, + ): + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + batch_size, num_channels_latents, height, width = latents.shape + + optional_kwargs = {} + if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()): + optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states + + model_pred = self.transformer( + latents, + timestep, + prompt_embeds, + freqs_cis, + prompt_attention_mask, + **optional_kwargs + ) + return model_pred diff --git a/modules/omnigen2/pipeline_utils.py b/modules/omnigen2/pipeline_utils.py new file mode 100644 index 000000000..4efebc260 --- /dev/null +++ b/modules/omnigen2/pipeline_utils.py @@ -0,0 +1,62 @@ +import torch + + +def get_pipeline_embeds(pipeline, prompt, negative_prompt, device): + """ Get pipeline embeds for prompts bigger than the maxlength of the pipe + :param pipeline: + :param prompt: + :param negative_prompt: + :param device: + :return: + """ + max_length = pipeline.tokenizer.model_max_length + + # simple way to determine length of tokens + # count_prompt = len(prompt.split(" ")) + # count_negative_prompt = len(negative_prompt.split(" ")) + + # create the tensor based on which prompt is longer + # if count_prompt >= count_negative_prompt: + input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding='longest').input_ids.to(device) + # input_ids = pipeline.tokenizer(prompt, padding="max_length", + # max_length=pipeline.tokenizer.model_max_length, + # truncation=True, + # return_tensors="pt",).input_ids.to(device) + shape_max_length = input_ids.shape[-1] + + if negative_prompt is not None: + negative_ids = pipeline.tokenizer(negative_prompt, truncation=True, padding="max_length", + max_length=shape_max_length, return_tensors="pt").input_ids.to(device) + + # else: + # negative_ids = pipeline.tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device) + # shape_max_length = negative_ids.shape[-1] + # input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length", + # max_length=shape_max_length).input_ids.to(device) + + concat_embeds = [] + neg_embeds = [] + for i in range(0, shape_max_length, max_length): + if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask: + attention_mask = input_ids[:, i: i + max_length].attention_mask.to(device) + else: + attention_mask = None + concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length], + attention_mask=attention_mask)[0]) + + if negative_prompt is not None: + if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask: + attention_mask = negative_ids[:, i: i + max_length].attention_mask.to(device) + else: + attention_mask = None + neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length], + attention_mask=attention_mask)[0]) + + concat_embeds = torch.cat(concat_embeds, dim=1) + + if negative_prompt is not None: + neg_embeds = torch.cat(neg_embeds, dim=1) + else: + neg_embeds = None + + return concat_embeds, neg_embeds diff --git a/modules/omnigen2/triton_layer_norm.py b/modules/omnigen2/triton_layer_norm.py new file mode 100644 index 000000000..51a70b990 --- /dev/null +++ b/modules/omnigen2/triton_layer_norm.py @@ -0,0 +1,1257 @@ +# Copyright (c) 2024, Tri Dao. +# Implement dropout + residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +import math + +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + + +from typing import Callable + + +def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): + def decorator(*args, **kwargs): + if cuda_amp_deprecated: + kwargs["device_type"] = "cuda" + return dec(*args, **kwargs) + return decorator + + +if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] + deprecated = True + from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined] +else: + deprecated = False + from torch.cuda.amp import custom_fwd, custom_bwd + +custom_fwd = custom_amp_decorator(custom_fwd, deprecated) +custom_bwd = custom_amp_decorator(custom_bwd, deprecated) + + +def triton_autotune_configs(): + # Return configs with a valid warp count for the current device + configs=[] + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block=1024 + # Default to warp size 32 if not defined by device + warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) + # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + warp_count=1 + while warp_count*warp_size <= max_threads_per_block: + configs.append(triton.Config({}, num_warps=warp_count)) + warp_count*=2 + return configs + +def layer_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + zero_centered_weight=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = F.layer_norm( + x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps + ).to(dtype) + return (out, out1) if not prenorm else (out, out1, x) + + +def rms_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + zero_centered_weight=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to( + dtype + ) + return (out, out1) if not prenorm else (out, out1, x) + + +@triton.autotune( + configs=triton_autotune_configs(), + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) +@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + X1, + W1, + B1, + Y1, + RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability + zero_centered_weight, # If true, add 1.0 to the weight + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N) + x += x1 + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) + + +def _layer_norm_fwd( + x, + weight, + bias, + eps, + residual=None, + x1=None, + weight1=None, + bias1=None, + dropout_p=0.0, + rowscale=None, + out_dtype=None, + residual_dtype=None, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + else: + assert out.shape == x.shape + assert out.stride(-1) == 1 + if weight1 is not None: + y1 = torch.empty_like(out) + assert y1.stride(-1) == 1 + else: + y1 = None + if ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + or rowscale is not None + or x1 is not None + ): + if residual_out is None: + residual_out = torch.empty( + M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype + ) + else: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint( + 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 + ) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask = None + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)]( + x, + out, + weight, + bias, + residual, + x1, + weight1, + bias1, + y1, + residual_out, + rowscale, + seeds, + dropout_mask, + mean, + rstd, + x.stride(0), + out.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, + N, + eps, + dropout_p, + zero_centered_weight, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if dropout_mask is not None and x1 is not None: + dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) + else: + dropout_mask1 = None + return ( + out, + y1, + mean, + rstd, + residual_out if residual_out is not None else x, + seeds, + dropout_mask, + dropout_mask1, + ) + + +@triton.autotune( + configs=triton_autotune_configs(), + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) +@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) +@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) +@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + W1, + DY1, + DX1, + DW1, + DB1, + DRESIDUAL_IN, + ROWSCALE, + SEEDS, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dy1_row, + stride_dx1_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, + zero_centered_weight, + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_DY1: tl.constexpr, + HAS_DX1: tl.constexpr, + HAS_B1: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + # Do not early exit if row_start >= M, because we need to write DW and DB + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if HAS_DY1: + DY1 += row_start * stride_dy1_row + if HAS_DX1: + DX1 += row_start * stride_dx1_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_DY1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_DY1: + dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_B1: + db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if HAS_DY1: + dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_DY1: + wdy += w1 * dy1 + dw1 += dy1 * xhat + if HAS_B1: + db1 += dy1 + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + if HAS_DX1: + if HAS_DROPOUT: + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + else: + dx1 = dx + tl.store(DX1 + cols, dx1, mask=mask) + if HAS_DROPOUT: + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + dx *= rowscale + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + if HAS_DY1: + DY1 += stride_dy1_row + if HAS_DX1: + DX1 += stride_dx1_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + if HAS_DY1: + tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) + if HAS_B1: + tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + dy1=None, + weight1=None, + bias1=None, + seeds=None, + dropout_p=0.0, + rowscale=None, + has_residual=False, + has_x1=False, + zero_centered_weight=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if dy1 is not None: + assert weight1 is not None + assert dy1.shape == dy.shape + assert dy1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if seeds is not None: + assert seeds.is_contiguous() + assert seeds.shape == (M if not has_x1 else M * 2,) + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = ( + torch.empty_like(x) + if has_residual + and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1) + else None + ) + dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + if recompute_output: + assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the + # latency of the gmem reads/writes, but will increase the time of summing up dw / db. + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + _dw1 = torch.empty_like(_dw) if weight1 is not None else None + _db1 = torch.empty_like(_db) if bias1 is not None else None + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + weight1, + dy1, + dx1, + _dw1, + _db1, + dresidual_in, + rowscale, + seeds, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dy1.stride(0) if dy1 is not None else 0, + dx1.stride(0) if dx1 is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + dropout_p, + zero_centered_weight, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + dropout_p > 0.0, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None + db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: + dresidual_in = dx + if has_x1 and dropout_p == 0.0: + dx1 = dx + return ( + (dx, dw, db, dresidual_in, dx1, dw1, db1) + if not recompute_output + else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) + ) + + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None + ): + x_shape_og = x.shape + # Check for zero sequence length + if x.numel() == 0: + ctx.zero_seq_length = True + # Only save minimal required tensors for backward + # ctx.save_for_backward(weight, bias, weight1, bias1) + ctx.x_shape_og = x_shape_og + ctx.weight_shape = weight.shape + ctx.weight_dtype = weight.dtype + ctx.weight_device = weight.device + + ctx.has_bias = bias is not None + ctx.bias_shape = bias.shape if bias is not None else None + ctx.bias_dtype = bias.dtype if bias is not None else None + ctx.bias_device = bias.device if bias is not None else None + + ctx.has_weight1 = weight1 is not None + ctx.weight1_shape = weight1.shape if weight1 is not None else None + ctx.weight1_dtype = weight1.dtype if weight1 is not None else None + ctx.weight1_device = weight1.device if weight1 is not None else None + + ctx.has_bias1 = bias1 is not None + ctx.bias1_shape = bias1.shape if bias1 is not None else None + ctx.bias1_dtype = bias1.dtype if bias1 is not None else None + ctx.bias1_device = bias1.device if bias1 is not None else None + + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.dropout_p = dropout_p + + # Handle output tensors with correct dtype + y = x # Preserve input tensor properties + y1 = torch.empty_like(x) if x1 is not None else None + + # Only create residual_out if prenorm is True + residual_out = torch.empty(x.shape, + dtype=torch.float32 if residual_in_fp32 else x.dtype, + device=x.device) if prenorm else None + + # Handle dropout masks + dropout_mask = None + dropout_mask1 = None + if return_dropout_mask: + dropout_mask = torch.empty_like(x, dtype=torch.uint8) + if x1 is not None: + dropout_mask1 = torch.empty_like(x, dtype=torch.uint8) + + # Return based on configuration + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ((y, dropout_mask, dropout_mask1) if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1)) + else: + return ((y, y1, dropout_mask, dropout_mask1) if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1)) + + ctx.zero_seq_length = False + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = x1.reshape(-1, x1.shape[-1]) + if x1.stride(-1) != 1: + x1 = x1.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + if weight1 is not None: + weight1 = weight1.contiguous() + if bias1 is not None: + bias1 = bias1.contiguous() + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + if out is not None: + out = out.reshape(-1, out.shape[-1]) + if residual_out is not None: + residual_out = residual_out.reshape(-1, residual_out.shape[-1]) + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + residual_dtype=residual_dtype, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + out=out, + residual_out=residual_out + ) + ctx.save_for_backward( + residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd + ) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.dropout_p = dropout_p + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.zero_centered_weight = zero_centered_weight + y = y.reshape(x_shape_og) + y1 = y1.reshape(x_shape_og) if y1 is not None else None + residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None + dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None + dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ( + (y, dropout_mask, dropout_mask1) + if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1) + ) + else: + return ( + (y, y1, dropout_mask, dropout_mask1) + if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1) + ) + + @staticmethod + def backward(ctx, dy, *args): + if ctx.zero_seq_length: + return ( + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device), + torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device), + torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None, + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None, + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None, + torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None, + torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if weight1 is not None: + dy1, args = args[0], args[1:] + dy1 = dy1.reshape(-1, dy1.shape[-1]) + if dy1.stride(-1) != 1: + dy1 = dy1.contiguous() + assert dy1.shape == x.shape + else: + dy1 = None + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + + dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + dy1, + weight1, + bias1, + seeds, + ctx.dropout_p, + rowscale, + ctx.has_residual, + ctx.has_x1, + ctx.zero_centered_weight, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, + dw1, + db1, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + is_rms_norm, + return_dropout_mask, + out, + residual_out + ) + + +def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + True, + return_dropout_mask, + out, + residual_out + ) + + +class RMSNorm(torch.nn.Module): + + def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False, + device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if dropout_p > 0.0: + self.drop = torch.nn.Dropout(dropout_p) + else: + self.drop = None + self.zero_centered_weight = zero_centered_weight + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + if not self.zero_centered_weight: + torch.nn.init.ones_(self.weight) + else: + torch.nn.init.zeros_(self.weight) + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + zero_centered_weight=self.zero_centered_weight, + ) + + +class LayerNormLinearFn(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + norm_weight = norm_weight.contiguous() + if norm_bias is not None: + norm_bias = norm_bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual=dresidual, + has_residual=ctx.has_residual, + is_rms_norm=ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_linear_fn( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) diff --git a/modules/processing_args.py b/modules/processing_args.py index 154f5e272..7721fe371 100644 --- a/modules/processing_args.py +++ b/modules/processing_args.py @@ -61,7 +61,7 @@ def task_specific_kwargs(p, model): p.width, p.height = p.width // vae_scale_factor * vae_scale_factor, p.height // vae_scale_factor * vae_scale_factor task_args['max_area'] = max_area task_args['width'], task_args['height'] = p.width, p.height - if model.__class__.__name__ == 'OmniGenPipeline': + elif model.__class__.__name__ == 'OmniGenPipeline' or model.__class__.__name__ == 'OmniGen2Pipeline': p.width, p.height = 16 * math.ceil(p.init_images[0].width / 16), 16 * math.ceil(p.init_images[0].height / 16) task_args = { 'width': p.width, @@ -285,7 +285,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t kwargs['output_type'] = 'np' # only set latent if model has vae # model specific - if 'Kandinsky' in model.__class__.__name__: + if 'Kandinsky' in model.__class__.__name__ or 'Cosmos2' in model.__class__.__name__ or 'OmniGen2' in model.__class__.__name__: kwargs['output_type'] = 'np' # only set latent if model has vae if 'StableCascade' in model.__class__.__name__: kwargs.pop("guidance_scale") # remove @@ -305,8 +305,6 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t args['control_strength'] = p.denoising_strength args['width'] = p.width args['height'] = p.height - if 'Cosmos2' in model.__class__.__name__: - kwargs['output_type'] = 'np' # cosmos uses wan-vae which is weird # set callbacks if 'prior_callback_steps' in possible: # Wuerstchen / Cascade args['prior_callback_steps'] = 1 diff --git a/modules/sd_detect.py b/modules/sd_detect.py index cb99f8d17..be57edb4b 100644 --- a/modules/sd_detect.py +++ b/modules/sd_detect.py @@ -88,6 +88,9 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False): if 'omnigen' in f.lower(): guess = 'OmniGen' pipeline = 'custom' + if 'omnigen2' in f.lower(): + guess = 'OmniGen2' + pipeline = 'custom' if 'sd3' in f.lower(): guess = 'Stable Diffusion 3' if 'hidream' in f.lower(): diff --git a/modules/sd_models.py b/modules/sd_models.py index 97e36b2e6..3a74eac56 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -37,6 +37,7 @@ pipe_switch_task_exclude = [ 'InstantIRPipeline', 'LTXConditionPipeline', 'OmniGenPipeline', + 'OmniGen2Pipeline', 'PhotoMakerStableDiffusionXLPipeline', 'PixelSmithXLPipeline', 'StableDiffusion3ControlNetPipeline', @@ -364,6 +365,10 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op=' from modules.model_meissonic import load_meissonic sd_model = load_meissonic(checkpoint_info, diffusers_load_config) allow_post_quant = True + elif model_type in ['OmniGen2']: # forced pipeline + from modules.model_omnigen2 import load_omnigen2 + sd_model = load_omnigen2(checkpoint_info, diffusers_load_config) + allow_post_quant = False elif model_type in ['OmniGen']: # forced pipeline from modules.model_omnigen import load_omnigen sd_model = load_omnigen(checkpoint_info, diffusers_load_config) diff --git a/modules/sd_offload.py b/modules/sd_offload.py index 188275f71..037f46ca2 100644 --- a/modules/sd_offload.py +++ b/modules/sd_offload.py @@ -12,7 +12,7 @@ from modules.timer import process as process_timer debug = os.environ.get('SD_MOVE_DEBUG', None) is not None debug_move = log.trace if debug else lambda *args, **kwargs: None -offload_warn = ['sc', 'sd3', 'f1', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'cogview4', 'cosmos', 'chroma'] +offload_warn = ['sc', 'sd3', 'f1', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'omnigen2', 'cogview4', 'cosmos', 'chroma'] offload_post = ['h1'] offload_hook_instance = None balanced_offload_exclude = ['CogView4Pipeline'] diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 2d6c2858c..15e2cf1ad 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -9,7 +9,7 @@ from modules import shared, devices, processing, images, sd_vae_approx, sd_vae_t SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) approximation_indexes = { "Simple": 0, "Approximate": 1, "TAESD": 2, "Full VAE": 3 } -flow_models = ['f1', 'sd3', 'lumina', 'auraflow', 'sana', 'lumina2', 'cogview4', 'h1', 'cosmos', 'chroma'] +flow_models = ['f1', 'sd3', 'lumina', 'auraflow', 'sana', 'lumina2', 'cogview4', 'h1', 'cosmos', 'chroma', 'omnigen2'] warned = False queue_lock = threading.Lock() diff --git a/modules/shared.py b/modules/shared.py index cc224dc00..d87d5b252 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -522,11 +522,11 @@ options_templates.update(options_section(("quantization", "Quantization Settings "sdnq_quantize_weights_mode": OptionInfo("int8", "Quantization type", gr.Dropdown, {"choices": sdnq_quant_modes, "visible": native}), "sdnq_quantize_weights_mode_te": OptionInfo("default", "Quantization type for Text Encoders", gr.Dropdown, {"choices": ['default'] + sdnq_quant_modes, "visible": native}), "sdnq_quantize_weights_group_size": OptionInfo(0, "Group size", gr.Slider, {"minimum": -1, "maximum": 4096, "step": 1, "visible": native}), - "sdnq_quantize_conv_layers": OptionInfo(False, "Quantize the convolutional layers", gr.Checkbox, {"visible": native}), + "sdnq_quantize_conv_layers": OptionInfo(False, "Quantize convolutional layers", gr.Checkbox, {"visible": native}), "sdnq_dequantize_compile": OptionInfo(devices.has_triton(), "Dequantize using torch.compile", gr.Checkbox, {"visible": native}), - "sdnq_use_quantized_matmul": OptionInfo(False, "Use Quantized MatMul", gr.Checkbox, {"visible": native}), - "sdnq_use_quantized_matmul_conv": OptionInfo(False, "Use Quantized MatMul with convolutional layers", gr.Checkbox, {"visible": native}), - "sdnq_quantize_with_gpu": OptionInfo(True, "Quantize with the GPU", gr.Checkbox, {"visible": native}), + "sdnq_use_quantized_matmul": OptionInfo(False, "Use quantized MatMul", gr.Checkbox, {"visible": native}), + "sdnq_use_quantized_matmul_conv": OptionInfo(False, "Use quantized MatMul with convolutional layers", gr.Checkbox, {"visible": native}), + "sdnq_quantize_with_gpu": OptionInfo(True, "Quantize using GPU", gr.Checkbox, {"visible": native}), "sdnq_dequantize_fp32": OptionInfo(False, "Dequantize using full precision", gr.Checkbox, {"visible": native}), "sdnq_quantize_shuffle_weights": OptionInfo(False, "Shuffle weights in post mode", gr.Checkbox, {"visible": native}), diff --git a/wiki b/wiki index c1ae36ab4..711b61ebf 160000 --- a/wiki +++ b/wiki @@ -1 +1 @@ -Subproject commit c1ae36ab4c2197487306c5c9fe4e328a038d1367 +Subproject commit 711b61ebfdb8cb09b06b4f8c627ae97519c6f74c