mirror of https://github.com/vladmandic/automatic
parent
01a960b19e
commit
bff75f0db3
|
|
@ -27,6 +27,7 @@ ignore-paths=/usr/lib/.*$,
|
|||
modules/meissonic,
|
||||
modules/mod,
|
||||
modules/omnigen,
|
||||
modules/omnigen2,
|
||||
modules/onnx_impl,
|
||||
modules/pag,
|
||||
modules/pixelsmith,
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ exclude = [
|
|||
"modules/meissonic",
|
||||
"modules/mod",
|
||||
"modules/omnigen",
|
||||
"modules/omnigen2",
|
||||
"modules/hidream",
|
||||
"modules/pag",
|
||||
"modules/pixelsmith",
|
||||
|
|
|
|||
|
|
@ -4,6 +4,10 @@
|
|||
|
||||
- **Models**
|
||||
- [Models Wiki page](https://vladmandic.github.io/sdnext-docs/Models/) is updated will all new models
|
||||
*note* all new image models larger than 30GB, so [offloading](https://vladmandic.github.io/sdnext-docs/Offload/) and [quantization](https://vladmandic.github.io/sdnext-docs/Quantization/) are necessary!
|
||||
- [OmniGen2](https://huggingface.co/OmniGen2/OmniGen2)
|
||||
- OmniGen2 is a powerful unified multimodal model that supports t2i and i2i workflows and uses 4B transformer with Qwen-VL-2.5 4B VLM
|
||||
- available via *networks -> models -> reference*
|
||||
- [nVidia Cosmos-Predict2 T2I](https://research.nvidia.com/labs/dir/cosmos-predict2/) *2B and 14B*
|
||||
- Cosmos-Predict2 T2I is a new foundational model from Nvidia in two variants: small 2B and large 14B
|
||||
- available via *networks -> models -> reference*
|
||||
|
|
@ -11,6 +15,7 @@
|
|||
- *note*: this is a gated model, you need to [accept terms](https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image) and set your [huggingface token](https://vladmandic.github.io/sdnext-docs/Gated/)
|
||||
- [Black Forest Labs FLUX.1 Kontext I2I](https://bfl.ai/announcements/flux-1-kontext-dev) *Dev* variant
|
||||
- FLUX.1-Kontext is a 12B model billion parameter capable of editing images based on text instructions
|
||||
- model is primarily designed for image editing workflows, but also works for text-to-image workflows
|
||||
- requirements are similar to regular FLUX.1 although 2x slower
|
||||
- available via *networks -> models -> reference*
|
||||
- *note*: this is a gated model, you need to [accept terms](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) and set your [huggingface token](https://vladmandic.github.io/sdnext-docs/Gated/)
|
||||
|
|
|
|||
1
TODO.md
1
TODO.md
|
|
@ -52,7 +52,6 @@ Main ToDo list can be found at [GitHub projects](https://github.com/users/vladma
|
|||
- [SkyReels-v2](https://github.com/SkyworkAI/SkyReels-V2)(https://github.com/huggingface/diffusers/pull/11518)
|
||||
#### External:Unified/MultiModal
|
||||
- [Bagel](https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT)(https://github.com/bytedance-seed/bagel)
|
||||
- [OmniGen2](https://huggingface.co/OmniGen2/OmniGen2)
|
||||
- [Ming](https://github.com/inclusionAI/Ming)
|
||||
- [Liquid](https://github.com/FoundationVision/Liquid)
|
||||
#### External:Image2Image/Editing
|
||||
|
|
|
|||
|
|
@ -252,6 +252,13 @@
|
|||
"skip": true
|
||||
},
|
||||
|
||||
"VectorSpaceLab OmniGen v2": {
|
||||
"path": "OmniGen2/OmniGen2",
|
||||
"desc": "OmniGen2 is a powerful and efficient unified multimodal model. Unlike OmniGen v1, OmniGen2 features two distinct decoding pathways for text and image modalities, utilizing unshared parameters and a decoupled image tokenizer.",
|
||||
"preview": "OmniGen2--OmniGen2.jpg",
|
||||
"skip": true
|
||||
},
|
||||
|
||||
"AuraFlow 0.3": {
|
||||
"path": "fal/AuraFlow-v0.3",
|
||||
"desc": "AuraFlow v0.3 is the fully open-sourced flow-based text-to-image generation model. The model was trained with more compute compared to the previous version, AuraFlow-v0.2. Compared to AuraFlow-v0.2, the model is fine-tuned on more aesthetic datasets and now supports various aspect ratio, (now width and height up to 1536 pixels).",
|
||||
|
|
|
|||
Binary file not shown.
|
After Width: | Height: | Size: 50 KiB |
|
|
@ -9,21 +9,6 @@ def load_omnigen(checkpoint_info, diffusers_load_config={}): # pylint: disable=u
|
|||
repo_id = sd_models.path_to_repo(checkpoint_info.name)
|
||||
vae = None
|
||||
|
||||
if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
|
||||
try:
|
||||
debug(f'Load model: type=OmniGen vae="{shared.opts.sd_vae}"')
|
||||
from modules import sd_vae
|
||||
# vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
|
||||
vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
|
||||
if os.path.exists(vae_file):
|
||||
vae_config = os.path.join('configs', 'sdxl', 'vae', 'config.json')
|
||||
vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
|
||||
except Exception as e:
|
||||
shared.log.error(f"Load model: type=OmniGen failed to load VAE: {e}")
|
||||
shared.opts.sd_vae = 'Default'
|
||||
if debug:
|
||||
errors.display(e, 'OmniGen VAE:')
|
||||
|
||||
load_config, quant_config = model_quant.get_dit_args(diffusers_load_config, module='Model')
|
||||
transformer = diffusers.OmniGenTransformer2DModel.from_pretrained(
|
||||
repo_id,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
import os
|
||||
from modules import shared, devices, sd_models, model_quant
|
||||
|
||||
debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
|
||||
|
||||
def load_omnigen2(checkpoint_info, diffusers_load_config={}): # pylint: disable=unused-argument
|
||||
repo_id = sd_models.path_to_repo(checkpoint_info.name)
|
||||
|
||||
from modules.omnigen2 import OmniGen2Pipeline, OmniGen2Transformer2DModel, Qwen2_5_VLForConditionalGeneration
|
||||
import diffusers
|
||||
from diffusers import pipelines
|
||||
diffusers.OmniGen2Pipeline = OmniGen2Pipeline # monkey-pathch
|
||||
pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["omnigen2"] = diffusers.OmniGen2Pipeline
|
||||
pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["omnigen2"] = diffusers.OmniGen2Pipeline
|
||||
pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["omnigen2"] = diffusers.OmniGen2Pipeline
|
||||
|
||||
load_config, quant_config = model_quant.get_dit_args(diffusers_load_config, module='Model')
|
||||
transformer = OmniGen2Transformer2DModel.from_pretrained(
|
||||
repo_id,
|
||||
subfolder="transformer",
|
||||
cache_dir=shared.opts.diffusers_dir,
|
||||
trust_remote_code=True,
|
||||
**load_config,
|
||||
**quant_config,
|
||||
)
|
||||
|
||||
load_config, quant_config = model_quant.get_dit_args(diffusers_load_config, module='TE')
|
||||
mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
repo_id,
|
||||
subfolder="mllm",
|
||||
cache_dir=shared.opts.diffusers_dir,
|
||||
trust_remote_code=True,
|
||||
**load_config,
|
||||
**quant_config,
|
||||
)
|
||||
|
||||
pipe = OmniGen2Pipeline.from_pretrained(
|
||||
repo_id,
|
||||
# transformer=transformer,
|
||||
mllm=mllm,
|
||||
cache_dir=shared.opts.diffusers_dir,
|
||||
trust_remote_code=True,
|
||||
**load_config,
|
||||
)
|
||||
pipe.transformer = transformer # for omnigen2 transformer must be loaded after pipeline
|
||||
|
||||
devices.torch_gc(force=True)
|
||||
return pipe
|
||||
|
|
@ -35,6 +35,8 @@ def get_model_type(pipe):
|
|||
model_type = 'lumina2'
|
||||
elif "Lumina" in name:
|
||||
model_type = 'lumina'
|
||||
elif "OmniGen2" in name:
|
||||
model_type = 'omnigen2'
|
||||
elif "OmniGen" in name:
|
||||
model_type = 'omnigen'
|
||||
elif "CogView3" in name:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from transformers import Qwen2_5_VLForConditionalGeneration
|
||||
from .pipeline_omnigen2 import OmniGen2Pipeline
|
||||
from .models.transformers import OmniGen2Transformer2DModel
|
||||
|
|
@ -0,0 +1,265 @@
|
|||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
|
||||
class OmniGen2ImageProcessor(VaeImageProcessor):
|
||||
"""
|
||||
Image processor for PixArt image resize and crop.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
|
||||
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
|
||||
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
||||
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
||||
resample (`str`, *optional*, defaults to `lanczos`):
|
||||
Resampling filter to use when resizing the image.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image to [-1,1].
|
||||
do_binarize (`bool`, *optional*, defaults to `False`):
|
||||
Whether to binarize the image to 0/1.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
|
||||
Whether to convert the images to RGB format.
|
||||
do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
|
||||
Whether to convert the images to grayscale format.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
vae_scale_factor: int = 16,
|
||||
resample: str = "lanczos",
|
||||
max_pixels: Optional[int] = None,
|
||||
max_side_length: Optional[int] = None,
|
||||
do_normalize: bool = True,
|
||||
do_binarize: bool = False,
|
||||
do_convert_grayscale: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
do_resize=do_resize,
|
||||
vae_scale_factor=vae_scale_factor,
|
||||
resample=resample,
|
||||
do_normalize=do_normalize,
|
||||
do_binarize=do_binarize,
|
||||
do_convert_grayscale=do_convert_grayscale,
|
||||
)
|
||||
|
||||
self.max_pixels = max_pixels
|
||||
self.max_side_length = max_side_length
|
||||
|
||||
def get_new_height_width(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
max_pixels: Optional[int] = None,
|
||||
max_side_length: Optional[int] = None,
|
||||
) -> Tuple[int, int]:
|
||||
r"""
|
||||
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
|
||||
|
||||
Args:
|
||||
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
|
||||
The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
|
||||
should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
|
||||
tensor, it should have shape `[batch, channels, height, width]`.
|
||||
height (`Optional[int]`, *optional*, defaults to `None`):
|
||||
The height of the preprocessed image. If `None`, the height of the `image` input will be used.
|
||||
width (`Optional[int]`, *optional*, defaults to `None`):
|
||||
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
|
||||
|
||||
Returns:
|
||||
`Tuple[int, int]`:
|
||||
A tuple containing the height and width, both resized to the nearest integer multiple of
|
||||
`vae_scale_factor`.
|
||||
"""
|
||||
|
||||
if height is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
height = image.height
|
||||
elif isinstance(image, torch.Tensor):
|
||||
height = image.shape[2]
|
||||
else:
|
||||
height = image.shape[1]
|
||||
|
||||
if width is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
width = image.width
|
||||
elif isinstance(image, torch.Tensor):
|
||||
width = image.shape[3]
|
||||
else:
|
||||
width = image.shape[2]
|
||||
|
||||
if max_side_length is None:
|
||||
max_side_length = self.max_side_length
|
||||
|
||||
if max_pixels is None:
|
||||
max_pixels = self.max_pixels
|
||||
|
||||
ratio = 1.0
|
||||
if max_side_length is not None:
|
||||
if height > width:
|
||||
max_side_length_ratio = max_side_length / height
|
||||
else:
|
||||
max_side_length_ratio = max_side_length / width
|
||||
|
||||
cur_pixels = height * width
|
||||
max_pixels_ratio = (max_pixels / cur_pixels) ** 0.5
|
||||
ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0) # do not upscale input image
|
||||
|
||||
new_height, new_width = int(height * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor, int(width * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor
|
||||
return new_height, new_width
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
image: PipelineImageInput,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
max_pixels: Optional[int] = None,
|
||||
max_side_length: Optional[int] = None,
|
||||
resize_mode: str = "default", # "default", "fill", "crop"
|
||||
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess the image input.
|
||||
|
||||
Args:
|
||||
image (`PipelineImageInput`):
|
||||
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
|
||||
supported formats.
|
||||
height (`int`, *optional*):
|
||||
The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
|
||||
height.
|
||||
width (`int`, *optional*):
|
||||
The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
|
||||
resize_mode (`str`, *optional*, defaults to `default`):
|
||||
The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
|
||||
the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
|
||||
resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
|
||||
center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
|
||||
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
|
||||
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
|
||||
supported for PIL image input.
|
||||
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
||||
The crop coordinates for each image in the batch. If `None`, will not crop the image.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The preprocessed image.
|
||||
"""
|
||||
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
||||
|
||||
# Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
|
||||
if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
|
||||
if isinstance(image, torch.Tensor):
|
||||
# if image is a pytorch tensor could have 2 possible shapes:
|
||||
# 1. batch x height x width: we should insert the channel dimension at position 1
|
||||
# 2. channel x height x width: we should insert batch dimension at position 0,
|
||||
# however, since both channel and batch dimension has same size 1, it is same to insert at position 1
|
||||
# for simplicity, we insert a dimension of size 1 at position 1 for both cases
|
||||
image = image.unsqueeze(1)
|
||||
else:
|
||||
# if it is a numpy array, it could have 2 possible shapes:
|
||||
# 1. batch x height x width: insert channel dimension on last position
|
||||
# 2. height x width x channel: insert batch dimension on first position
|
||||
if image.shape[-1] == 1:
|
||||
image = np.expand_dims(image, axis=0)
|
||||
else:
|
||||
image = np.expand_dims(image, axis=-1)
|
||||
|
||||
if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
|
||||
warnings.warn(
|
||||
"Passing `image` as a list of 4d np.ndarray is deprecated."
|
||||
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
|
||||
FutureWarning,
|
||||
)
|
||||
image = np.concatenate(image, axis=0)
|
||||
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
|
||||
warnings.warn(
|
||||
"Passing `image` as a list of 4d torch.Tensor is deprecated."
|
||||
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
|
||||
FutureWarning,
|
||||
)
|
||||
image = torch.cat(image, axis=0)
|
||||
|
||||
if not is_valid_image_imagelist(image):
|
||||
raise ValueError(
|
||||
f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
|
||||
)
|
||||
if not isinstance(image, list):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
if crops_coords is not None:
|
||||
image = [i.crop(crops_coords) for i in image]
|
||||
if self.config.do_resize:
|
||||
height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length)
|
||||
image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
|
||||
if self.config.do_convert_rgb:
|
||||
image = [self.convert_to_rgb(i) for i in image]
|
||||
elif self.config.do_convert_grayscale:
|
||||
image = [self.convert_to_grayscale(i) for i in image]
|
||||
image = self.pil_to_numpy(image) # to np
|
||||
image = self.numpy_to_pt(image) # to pt
|
||||
|
||||
elif isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
|
||||
|
||||
image = self.numpy_to_pt(image)
|
||||
|
||||
height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
|
||||
if self.config.do_resize:
|
||||
image = self.resize(image, height, width)
|
||||
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
||||
|
||||
if self.config.do_convert_grayscale and image.ndim == 3:
|
||||
image = image.unsqueeze(1)
|
||||
|
||||
channel = image.shape[1]
|
||||
# don't need any preprocess if the image is latents
|
||||
if channel == self.config.vae_latent_channels:
|
||||
return image
|
||||
|
||||
height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
|
||||
if self.config.do_resize:
|
||||
image = self.resize(image, height, width)
|
||||
|
||||
# expected range [0,1], normalize to [-1,1]
|
||||
do_normalize = self.config.do_normalize
|
||||
if do_normalize and image.min() < 0:
|
||||
warnings.warn(
|
||||
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
||||
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
||||
FutureWarning,
|
||||
)
|
||||
do_normalize = False
|
||||
if do_normalize:
|
||||
image = self.normalize(image)
|
||||
|
||||
if self.config.do_binarize:
|
||||
image = self.binarize(image)
|
||||
|
||||
return image
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Import utilities: Utilities related to imports and our lazy inits.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
# The package importlib_metadata is in a different place, depending on the python version.
|
||||
if sys.version_info < (3, 8):
|
||||
import importlib_metadata
|
||||
else:
|
||||
import importlib.metadata as importlib_metadata
|
||||
|
||||
def _is_package_available(pkg_name: str):
|
||||
pkg_exists = importlib.util.find_spec(pkg_name) is not None
|
||||
pkg_version = "N/A"
|
||||
|
||||
if pkg_exists:
|
||||
try:
|
||||
pkg_version = importlib_metadata.version(pkg_name)
|
||||
except (ImportError, importlib_metadata.PackageNotFoundError):
|
||||
pkg_exists = False
|
||||
|
||||
return pkg_exists, pkg_version
|
||||
|
||||
_triton_available, _triton_version = _is_package_available("triton")
|
||||
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
|
||||
|
||||
def is_triton_available():
|
||||
return _triton_available
|
||||
|
||||
def is_flash_attn_available():
|
||||
return _flash_attn_available
|
||||
|
|
@ -0,0 +1,357 @@
|
|||
"""
|
||||
OmniGen2 Attention Processor Module
|
||||
|
||||
Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
import math
|
||||
from typing import Optional, Tuple, Dict, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
|
||||
from ..import_utils import is_flash_attn_available
|
||||
|
||||
if is_flash_attn_available():
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
|
||||
else:
|
||||
warnings.warn("Cannot import flash_attn, install flash_attn to use Flash2Varlen attention for better performance")
|
||||
|
||||
|
||||
from diffusers.models.attention_processor import Attention
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
|
||||
class OmniGen2AttnProcessorFlash2Varlen:
|
||||
"""
|
||||
Processor for implementing scaled dot-product attention with flash attention and variable length sequences.
|
||||
|
||||
This processor implements:
|
||||
- Flash attention with variable length sequences
|
||||
- Rotary position embeddings (RoPE)
|
||||
- Query-Key normalization
|
||||
- Proportional attention scaling
|
||||
|
||||
Args:
|
||||
None
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the attention processor."""
|
||||
if not is_flash_attn_available():
|
||||
raise ImportError(
|
||||
"OmniGen2AttnProcessorFlash2Varlen requires flash_attn. "
|
||||
"Please install flash_attn."
|
||||
)
|
||||
|
||||
def _upad_input(
|
||||
self,
|
||||
query_layer: torch.Tensor,
|
||||
key_layer: torch.Tensor,
|
||||
value_layer: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
query_length: int,
|
||||
num_heads: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
|
||||
"""
|
||||
Unpad the input tensors for flash attention.
|
||||
|
||||
Args:
|
||||
query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
|
||||
key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
|
||||
value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
|
||||
attention_mask: Attention mask tensor of shape (batch_size, seq_len)
|
||||
query_length: Length of the query sequence
|
||||
num_heads: Number of attention heads
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- Unpadded query tensor
|
||||
- Unpadded key tensor
|
||||
- Unpadded value tensor
|
||||
- Query indices
|
||||
- Tuple of cumulative sequence lengths for query and key
|
||||
- Tuple of maximum sequence lengths for query and key
|
||||
"""
|
||||
def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||
"""Helper function to get unpadding data from attention mask."""
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
return indices, cu_seqlens, max_seqlen_in_batch
|
||||
|
||||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
||||
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||
|
||||
# Unpad key and value layers
|
||||
key_layer = index_first_axis(
|
||||
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
||||
indices_k,
|
||||
)
|
||||
value_layer = index_first_axis(
|
||||
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
||||
indices_k,
|
||||
)
|
||||
|
||||
# Handle different query length cases
|
||||
if query_length == kv_seq_len:
|
||||
query_layer = index_first_axis(
|
||||
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
|
||||
indices_k,
|
||||
)
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||
indices_q = indices_k
|
||||
elif query_length == 1:
|
||||
max_seqlen_in_batch_q = 1
|
||||
cu_seqlens_q = torch.arange(
|
||||
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
||||
)
|
||||
indices_q = cu_seqlens_q[:-1]
|
||||
query_layer = query_layer.squeeze(1)
|
||||
else:
|
||||
attention_mask = attention_mask[:, -query_length:]
|
||||
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||
|
||||
return (
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
indices_q,
|
||||
(cu_seqlens_q, cu_seqlens_k),
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
base_sequence_length: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Process attention computation with flash attention.
|
||||
|
||||
Args:
|
||||
attn: Attention module
|
||||
hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
|
||||
encoder_hidden_states: Encoder hidden states tensor
|
||||
attention_mask: Optional attention mask tensor
|
||||
image_rotary_emb: Optional rotary embeddings for image tokens
|
||||
base_sequence_length: Optional base sequence length for proportional attention
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Processed hidden states after attention computation
|
||||
"""
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
# Get Query-Key-Value Pair
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query_dim = query.shape[-1]
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = query_dim // attn.heads
|
||||
dtype = query.dtype
|
||||
|
||||
# Get key-value heads
|
||||
kv_heads = inner_dim // head_dim
|
||||
|
||||
# Reshape tensors for attention computation
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim)
|
||||
key = key.view(batch_size, -1, kv_heads, head_dim)
|
||||
value = value.view(batch_size, -1, kv_heads, head_dim)
|
||||
|
||||
# Apply Query-Key normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply Rotary Position Embeddings
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
|
||||
|
||||
query, key = query.to(dtype), key.to(dtype)
|
||||
|
||||
# Calculate attention scale
|
||||
if base_sequence_length is not None:
|
||||
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
|
||||
else:
|
||||
softmax_scale = attn.scale
|
||||
|
||||
# Unpad input for flash attention
|
||||
(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
indices_q,
|
||||
cu_seq_lens,
|
||||
max_seq_lens,
|
||||
) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)
|
||||
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
# Handle different number of heads
|
||||
if kv_heads < attn.heads:
|
||||
key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
|
||||
value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
|
||||
|
||||
# Apply flash attention
|
||||
attn_output_unpad = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
softmax_scale=softmax_scale,
|
||||
)
|
||||
|
||||
# Pad output and apply final transformations
|
||||
hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
|
||||
hidden_states = hidden_states.flatten(-2)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
# Apply output projection
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OmniGen2AttnProcessor:
|
||||
"""
|
||||
Processor for implementing scaled dot-product attention with flash attention and variable length sequences.
|
||||
|
||||
This processor is optimized for PyTorch 2.0 and implements:
|
||||
- Flash attention with variable length sequences
|
||||
- Rotary position embeddings (RoPE)
|
||||
- Query-Key normalization
|
||||
- Proportional attention scaling
|
||||
|
||||
Args:
|
||||
None
|
||||
|
||||
Raises:
|
||||
ImportError: If PyTorch version is less than 2.0
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the attention processor."""
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. "
|
||||
"Please upgrade PyTorch to version 2.0 or later."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
base_sequence_length: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Process attention computation with flash attention.
|
||||
|
||||
Args:
|
||||
attn: Attention module
|
||||
hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
|
||||
encoder_hidden_states: Encoder hidden states tensor
|
||||
attention_mask: Optional attention mask tensor
|
||||
image_rotary_emb: Optional rotary embeddings for image tokens
|
||||
base_sequence_length: Optional base sequence length for proportional attention
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Processed hidden states after attention computation
|
||||
"""
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
# Get Query-Key-Value Pair
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query_dim = query.shape[-1]
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = query_dim // attn.heads
|
||||
dtype = query.dtype
|
||||
|
||||
# Get key-value heads
|
||||
kv_heads = inner_dim // head_dim
|
||||
|
||||
# Reshape tensors for attention computation
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim)
|
||||
key = key.view(batch_size, -1, kv_heads, head_dim)
|
||||
value = value.view(batch_size, -1, kv_heads, head_dim)
|
||||
|
||||
# Apply Query-Key normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply Rotary Position Embeddings
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
|
||||
|
||||
query, key = query.to(dtype), key.to(dtype)
|
||||
|
||||
# Calculate attention scale
|
||||
if base_sequence_length is not None:
|
||||
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
|
||||
else:
|
||||
softmax_scale = attn.scale
|
||||
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
# explicitly repeat key and value to match query length, otherwise using enable_gqa=True results in MATH backend of sdpa in our test of pytorch2.6
|
||||
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
|
||||
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, scale=softmax_scale
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
# Apply output projection
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
from diffusers.models.activations import get_activation
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
act_fn: str = "silu",
|
||||
out_dim: int = None,
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
||||
else:
|
||||
self.cond_proj = None
|
||||
|
||||
self.act = get_activation(act_fn)
|
||||
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
else:
|
||||
self.post_act = get_activation(post_act_fn)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
nn.init.normal_(self.linear_1.weight, std=0.02)
|
||||
nn.init.zeros_(self.linear_1.bias)
|
||||
nn.init.normal_(self.linear_2.weight, std=0.02)
|
||||
nn.init.zeros_(self.linear_2.bias)
|
||||
|
||||
def forward(self, sample, condition=None):
|
||||
if condition is not None:
|
||||
sample = sample + self.cond_proj(condition)
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
|
||||
sample = self.linear_2(sample)
|
||||
|
||||
if self.post_act is not None:
|
||||
sample = self.post_act(sample)
|
||||
return sample
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
use_real: bool = True,
|
||||
use_real_unbind_dim: int = -1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||
tensors contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
if use_real:
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
if use_real_unbind_dim == -1:
|
||||
# Used for flux, cogvideox, hunyuan-dit
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
elif use_real_unbind_dim == -2:
|
||||
# Used for Stable Audio, OmniGen and CogView4
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
||||
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
||||
else:
|
||||
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
||||
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
else:
|
||||
# used for lumina
|
||||
# x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
||||
|
||||
return x_out.type_as(x)
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from .transformer_omnigen2 import OmniGen2Transformer2DModel
|
||||
|
||||
__all__ = ["OmniGen2Transformer2DModel"]
|
||||
|
|
@ -0,0 +1,217 @@
|
|||
|
||||
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from diffusers.models.embeddings import Timesteps
|
||||
from ..embeddings import TimestepEmbedding
|
||||
from ...import_utils import is_flash_attn_available, is_triton_available
|
||||
|
||||
if is_triton_available():
|
||||
from ...triton_layer_norm import RMSNorm
|
||||
else:
|
||||
from torch.nn import RMSNorm
|
||||
warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance")
|
||||
|
||||
if is_flash_attn_available():
|
||||
from flash_attn.ops.activations import swiglu
|
||||
else:
|
||||
from .components import swiglu
|
||||
warnings.warn("Cannot import flash_attn, install flash_attn to use fused SwiGLU for better performance")
|
||||
|
||||
# try:
|
||||
# from flash_attn.ops.activations import swiglu as fused_swiglu
|
||||
# FUSEDSWIGLU_AVALIBLE = True
|
||||
# except ImportError:
|
||||
|
||||
# FUSEDSWIGLU_AVALIBLE = False
|
||||
# warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
|
||||
|
||||
class LuminaRMSNormZero(nn.Module):
|
||||
"""
|
||||
Norm layer adaptive RMS normalization zero.
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
norm_eps: float,
|
||||
norm_elementwise_affine: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(
|
||||
min(embedding_dim, 1024),
|
||||
4 * embedding_dim,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(embedding_dim, eps=norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None])
|
||||
return x, gate_msa, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class LuminaLayerNormContinuous(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
conditioning_embedding_dim: int,
|
||||
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
|
||||
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
|
||||
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
|
||||
# However, this is how it was implemented in the original code, and it's rather likely you should
|
||||
# set `elementwise_affine` to False.
|
||||
elementwise_affine=True,
|
||||
eps=1e-5,
|
||||
bias=True,
|
||||
norm_type="layer_norm",
|
||||
out_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# AdaLN
|
||||
self.silu = nn.SiLU()
|
||||
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type {norm_type}")
|
||||
|
||||
self.linear_2 = None
|
||||
if out_dim is not None:
|
||||
self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
conditioning_embedding: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
|
||||
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
|
||||
scale = emb
|
||||
x = self.norm(x) * (1 + scale)[:, None, :]
|
||||
|
||||
if self.linear_2 is not None:
|
||||
x = self.linear_2(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LuminaFeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
|
||||
Parameters:
|
||||
hidden_size (`int`):
|
||||
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
|
||||
hidden representations.
|
||||
intermediate_size (`int`): The intermediate dimension of the feedforward layer.
|
||||
multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
|
||||
of this value.
|
||||
ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
|
||||
dimension. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
inner_dim: int,
|
||||
multiple_of: Optional[int] = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.swiglu = swiglu
|
||||
|
||||
# custom hidden_size factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
inner_dim = int(ffn_dim_multiplier * inner_dim)
|
||||
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.linear_1 = nn.Linear(
|
||||
dim,
|
||||
inner_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.linear_2 = nn.Linear(
|
||||
inner_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
)
|
||||
self.linear_3 = nn.Linear(
|
||||
dim,
|
||||
inner_dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h1, h2 = self.linear_1(x), self.linear_3(x)
|
||||
return self.linear_2(self.swiglu(h1, h2))
|
||||
|
||||
|
||||
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 4096,
|
||||
text_feat_dim: int = 2048,
|
||||
frequency_embedding_size: int = 256,
|
||||
norm_eps: float = 1e-5,
|
||||
timestep_scale: float = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(
|
||||
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
|
||||
)
|
||||
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
|
||||
)
|
||||
|
||||
self.caption_embedder = nn.Sequential(
|
||||
RMSNorm(text_feat_dim, eps=norm_eps),
|
||||
nn.Linear(text_feat_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
|
||||
nn.init.zeros_(self.caption_embedder[1].bias)
|
||||
|
||||
def forward(
|
||||
self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
|
||||
time_embed = self.timestep_embedder(timestep_proj)
|
||||
caption_embed = self.caption_embedder(text_hidden_states)
|
||||
return time_embed, caption_embed
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
import torch.nn.functional as F
|
||||
|
||||
def swiglu(x, y):
|
||||
return F.silu(x.float(), inplace=False).to(x.dtype) * y
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import repeat
|
||||
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
||||
|
||||
class OmniGen2RotaryPosEmbed(nn.Module):
|
||||
def __init__(self, theta: int,
|
||||
axes_dim: Tuple[int, int, int],
|
||||
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
||||
patch_size: int = 2):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
self.axes_lens = axes_lens
|
||||
self.patch_size = patch_size
|
||||
|
||||
@staticmethod
|
||||
def get_freqs_cis(axes_dim: Tuple[int, int, int],
|
||||
axes_lens: Tuple[int, int, int],
|
||||
theta: int) -> List[torch.Tensor]:
|
||||
freqs_cis = []
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
|
||||
emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
|
||||
freqs_cis.append(emb)
|
||||
return freqs_cis
|
||||
|
||||
def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
|
||||
device = ids.device
|
||||
if ids.device.type == "mps":
|
||||
ids = ids.to("cpu")
|
||||
|
||||
result = []
|
||||
for i in range(len(self.axes_dim)):
|
||||
freqs = freqs_cis[i].to(ids.device)
|
||||
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
|
||||
result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
|
||||
return torch.cat(result, dim=-1).to(device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
freqs_cis,
|
||||
attention_mask,
|
||||
l_effective_ref_img_len,
|
||||
l_effective_img_len,
|
||||
ref_img_sizes,
|
||||
img_sizes,
|
||||
device
|
||||
):
|
||||
batch_size = len(attention_mask)
|
||||
p = self.patch_size
|
||||
|
||||
encoder_seq_len = attention_mask.shape[1]
|
||||
l_effective_cap_len = attention_mask.sum(dim=1).tolist()
|
||||
|
||||
seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
|
||||
|
||||
max_seq_len = max(seq_lengths)
|
||||
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
|
||||
max_img_len = max(l_effective_img_len)
|
||||
|
||||
# Create position IDs
|
||||
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
|
||||
|
||||
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
||||
# add text position ids
|
||||
position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
|
||||
|
||||
pe_shift = cap_seq_len
|
||||
pe_shift_len = cap_seq_len
|
||||
|
||||
if ref_img_sizes[i] is not None:
|
||||
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
|
||||
H, W = ref_img_size
|
||||
ref_H_tokens, ref_W_tokens = H // p, W // p
|
||||
assert ref_H_tokens * ref_W_tokens == ref_img_len
|
||||
# add image position ids
|
||||
|
||||
row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
|
||||
col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
|
||||
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
|
||||
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
|
||||
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
|
||||
|
||||
pe_shift += max(ref_H_tokens, ref_W_tokens)
|
||||
pe_shift_len += ref_img_len
|
||||
|
||||
H, W = img_sizes[i]
|
||||
H_tokens, W_tokens = H // p, W // p
|
||||
assert H_tokens * W_tokens == l_effective_img_len[i]
|
||||
|
||||
row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
|
||||
col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
|
||||
|
||||
assert pe_shift_len + l_effective_img_len[i] == seq_len
|
||||
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
|
||||
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
|
||||
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
|
||||
|
||||
# Get combined rotary embeddings
|
||||
freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
|
||||
|
||||
# create separate rotary embeddings for captions and images
|
||||
cap_freqs_cis = torch.zeros(
|
||||
batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
||||
)
|
||||
ref_img_freqs_cis = torch.zeros(
|
||||
batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
||||
)
|
||||
img_freqs_cis = torch.zeros(
|
||||
batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
||||
)
|
||||
|
||||
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
|
||||
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
|
||||
ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
|
||||
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
|
||||
|
||||
return (
|
||||
cap_freqs_cis,
|
||||
ref_img_freqs_cis,
|
||||
img_freqs_cis,
|
||||
freqs_cis,
|
||||
l_effective_cap_len,
|
||||
seq_lengths,
|
||||
)
|
||||
|
|
@ -0,0 +1,617 @@
|
|||
import warnings
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import PeftAdapterMixin
|
||||
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
||||
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from diffusers.models.attention_processor import Attention
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from ..attention_processor import OmniGen2AttnProcessorFlash2Varlen, OmniGen2AttnProcessor
|
||||
from .repo import OmniGen2RotaryPosEmbed
|
||||
from .block_lumina2 import LuminaLayerNormContinuous, LuminaRMSNormZero, LuminaFeedForward, Lumina2CombinedTimestepCaptionEmbedding
|
||||
|
||||
from ...import_utils import is_triton_available, is_flash_attn_available
|
||||
|
||||
if is_triton_available():
|
||||
from ...triton_layer_norm import RMSNorm
|
||||
else:
|
||||
from torch.nn import RMSNorm
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class OmniGen2TransformerBlock(nn.Module):
|
||||
"""
|
||||
Transformer block for OmniGen2 model.
|
||||
|
||||
This block implements a transformer layer with:
|
||||
- Multi-head attention with flash attention
|
||||
- Feed-forward network with SwiGLU activation
|
||||
- RMS normalization
|
||||
- Optional modulation for conditional generation
|
||||
|
||||
Args:
|
||||
dim: Dimension of the input and output tensors
|
||||
num_attention_heads: Number of attention heads
|
||||
num_kv_heads: Number of key-value heads
|
||||
multiple_of: Multiple of which the hidden dimension should be
|
||||
ffn_dim_multiplier: Multiplier for the feed-forward network dimension
|
||||
norm_eps: Epsilon value for normalization layers
|
||||
modulation: Whether to use modulation for conditional generation
|
||||
use_fused_rms_norm: Whether to use fused RMS normalization
|
||||
use_fused_swiglu: Whether to use fused SwiGLU activation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
num_kv_heads: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: float,
|
||||
norm_eps: float,
|
||||
modulation: bool = True,
|
||||
) -> None:
|
||||
"""Initialize the transformer block."""
|
||||
super().__init__()
|
||||
self.head_dim = dim // num_attention_heads
|
||||
self.modulation = modulation
|
||||
|
||||
try:
|
||||
processor = OmniGen2AttnProcessorFlash2Varlen()
|
||||
except ImportError:
|
||||
processor = OmniGen2AttnProcessor()
|
||||
|
||||
# Initialize attention layer
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=dim // num_attention_heads,
|
||||
qk_norm="rms_norm",
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_kv_heads,
|
||||
eps=1e-5,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
# Initialize feed-forward network
|
||||
self.feed_forward = LuminaFeedForward(
|
||||
dim=dim,
|
||||
inner_dim=4 * dim,
|
||||
multiple_of=multiple_of,
|
||||
ffn_dim_multiplier=ffn_dim_multiplier
|
||||
)
|
||||
|
||||
# Initialize normalization layers
|
||||
if modulation:
|
||||
self.norm1 = LuminaRMSNormZero(
|
||||
embedding_dim=dim,
|
||||
norm_eps=norm_eps,
|
||||
norm_elementwise_affine=True
|
||||
)
|
||||
else:
|
||||
self.norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
self.norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self) -> None:
|
||||
"""
|
||||
Initialize the weights of the transformer block.
|
||||
|
||||
Uses Xavier uniform initialization for linear layers and zero initialization for biases.
|
||||
"""
|
||||
nn.init.xavier_uniform_(self.attn.to_q.weight)
|
||||
nn.init.xavier_uniform_(self.attn.to_k.weight)
|
||||
nn.init.xavier_uniform_(self.attn.to_v.weight)
|
||||
nn.init.xavier_uniform_(self.attn.to_out[0].weight)
|
||||
|
||||
nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
|
||||
nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
|
||||
nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
|
||||
|
||||
if self.modulation:
|
||||
nn.init.zeros_(self.norm1.linear.weight)
|
||||
nn.init.zeros_(self.norm1.linear.bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
image_rotary_emb: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the transformer block.
|
||||
|
||||
Args:
|
||||
hidden_states: Input hidden states tensor
|
||||
attention_mask: Attention mask tensor
|
||||
image_rotary_emb: Rotary embeddings for image tokens
|
||||
temb: Optional timestep embedding tensor
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output hidden states after transformer block processing
|
||||
"""
|
||||
import time
|
||||
if self.modulation:
|
||||
if temb is None:
|
||||
raise ValueError("temb must be provided when modulation is enabled")
|
||||
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
||||
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
OmniGen2 Transformer 2D Model.
|
||||
|
||||
A transformer-based diffusion model for image generation with:
|
||||
- Patch-based image processing
|
||||
- Rotary position embeddings
|
||||
- Multi-head attention
|
||||
- Conditional generation support
|
||||
|
||||
Args:
|
||||
patch_size: Size of image patches
|
||||
in_channels: Number of input channels
|
||||
out_channels: Number of output channels (defaults to in_channels)
|
||||
hidden_size: Size of hidden layers
|
||||
num_layers: Number of transformer layers
|
||||
num_refiner_layers: Number of refiner layers
|
||||
num_attention_heads: Number of attention heads
|
||||
num_kv_heads: Number of key-value heads
|
||||
multiple_of: Multiple of which the hidden dimension should be
|
||||
ffn_dim_multiplier: Multiplier for feed-forward network dimension
|
||||
norm_eps: Epsilon value for normalization layers
|
||||
axes_dim_rope: Dimensions for rotary position embeddings
|
||||
axes_lens: Lengths for rotary position embeddings
|
||||
text_feat_dim: Dimension of text features
|
||||
timestep_scale: Scale factor for timestep embeddings
|
||||
use_fused_rms_norm: Whether to use fused RMS normalization
|
||||
use_fused_swiglu: Whether to use fused SwiGLU activation
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Omnigen2TransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["x_embedder", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = None,
|
||||
hidden_size: int = 2304,
|
||||
num_layers: int = 26,
|
||||
num_refiner_layers: int = 2,
|
||||
num_attention_heads: int = 24,
|
||||
num_kv_heads: int = 8,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
|
||||
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
||||
text_feat_dim: int = 1024,
|
||||
timestep_scale: float = 1.0
|
||||
) -> None:
|
||||
"""Initialize the OmniGen2 transformer model."""
|
||||
super().__init__()
|
||||
|
||||
# Validate configuration
|
||||
if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
|
||||
raise ValueError(
|
||||
f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
|
||||
f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
|
||||
)
|
||||
|
||||
self.out_channels = out_channels or in_channels
|
||||
|
||||
# Initialize embeddings
|
||||
self.rope_embedder = OmniGen2RotaryPosEmbed(
|
||||
theta=10000,
|
||||
axes_dim=axes_dim_rope,
|
||||
axes_lens=axes_lens,
|
||||
patch_size=patch_size,
|
||||
)
|
||||
|
||||
self.x_embedder = nn.Linear(
|
||||
in_features=patch_size * patch_size * in_channels,
|
||||
out_features=hidden_size,
|
||||
)
|
||||
|
||||
self.ref_image_patch_embedder = nn.Linear(
|
||||
in_features=patch_size * patch_size * in_channels,
|
||||
out_features=hidden_size,
|
||||
)
|
||||
|
||||
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
|
||||
hidden_size=hidden_size,
|
||||
text_feat_dim=text_feat_dim,
|
||||
norm_eps=norm_eps,
|
||||
timestep_scale=timestep_scale
|
||||
)
|
||||
|
||||
# Initialize transformer blocks
|
||||
self.noise_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
num_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
modulation=True
|
||||
)
|
||||
for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.ref_image_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
num_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
modulation=True
|
||||
)
|
||||
for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
num_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
modulation=False
|
||||
)
|
||||
for _ in range(num_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
num_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
modulation=True
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
self.norm_out = LuminaLayerNormContinuous(
|
||||
embedding_dim=hidden_size,
|
||||
conditioning_embedding_dim=min(hidden_size, 1024),
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
bias=True,
|
||||
out_dim=patch_size * patch_size * self.out_channels
|
||||
)
|
||||
|
||||
# Add learnable embeddings to distinguish different images
|
||||
self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self) -> None:
|
||||
"""
|
||||
Initialize the weights of the model.
|
||||
|
||||
Uses Xavier uniform initialization for linear layers.
|
||||
"""
|
||||
nn.init.xavier_uniform_(self.x_embedder.weight)
|
||||
nn.init.constant_(self.x_embedder.bias, 0.0)
|
||||
|
||||
nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
|
||||
nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
|
||||
|
||||
nn.init.zeros_(self.norm_out.linear_1.weight)
|
||||
nn.init.zeros_(self.norm_out.linear_1.bias)
|
||||
nn.init.zeros_(self.norm_out.linear_2.weight)
|
||||
nn.init.zeros_(self.norm_out.linear_2.bias)
|
||||
|
||||
nn.init.normal_(self.image_index_embedding, std=0.02)
|
||||
|
||||
def img_patch_embed_and_refine(
|
||||
self,
|
||||
hidden_states,
|
||||
ref_image_hidden_states,
|
||||
padded_img_mask,
|
||||
padded_ref_img_mask,
|
||||
noise_rotary_emb,
|
||||
ref_img_rotary_emb,
|
||||
l_effective_ref_img_len,
|
||||
l_effective_img_len,
|
||||
temb
|
||||
):
|
||||
batch_size = len(hidden_states)
|
||||
max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
|
||||
|
||||
for i in range(batch_size):
|
||||
shift = 0
|
||||
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
|
||||
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
|
||||
shift += ref_img_len
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
|
||||
|
||||
flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
|
||||
num_ref_images = len(flat_l_effective_ref_img_len)
|
||||
max_ref_img_len = max(flat_l_effective_ref_img_len)
|
||||
|
||||
batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
|
||||
batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
|
||||
batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
|
||||
batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
|
||||
|
||||
# sequence of ref imgs to batch
|
||||
idx = 0
|
||||
for i in range(batch_size):
|
||||
shift = 0
|
||||
for ref_img_len in l_effective_ref_img_len[i]:
|
||||
batch_ref_img_mask[idx, :ref_img_len] = True
|
||||
batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
|
||||
batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
|
||||
batch_temb[idx] = temb[i]
|
||||
shift += ref_img_len
|
||||
idx += 1
|
||||
|
||||
# refine ref imgs separately
|
||||
for layer in self.ref_image_refiner:
|
||||
batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)
|
||||
|
||||
# batch of ref imgs to sequence
|
||||
idx = 0
|
||||
for i in range(batch_size):
|
||||
shift = 0
|
||||
for ref_img_len in l_effective_ref_img_len[i]:
|
||||
ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
|
||||
shift += ref_img_len
|
||||
idx += 1
|
||||
|
||||
combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
|
||||
for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
|
||||
combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
|
||||
combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
|
||||
|
||||
return combined_img_hidden_states
|
||||
|
||||
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
|
||||
batch_size = len(hidden_states)
|
||||
p = self.config.patch_size
|
||||
device = hidden_states[0].device
|
||||
|
||||
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
|
||||
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
|
||||
|
||||
if ref_image_hidden_states is not None:
|
||||
ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
|
||||
l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
|
||||
else:
|
||||
ref_img_sizes = [None for _ in range(batch_size)]
|
||||
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
|
||||
|
||||
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
|
||||
max_img_len = max(l_effective_img_len)
|
||||
|
||||
# ref image patch embeddings
|
||||
flat_ref_img_hidden_states = []
|
||||
for i in range(batch_size):
|
||||
if ref_img_sizes[i] is not None:
|
||||
imgs = []
|
||||
for ref_img in ref_image_hidden_states[i]:
|
||||
C, H, W = ref_img.size()
|
||||
ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
|
||||
imgs.append(ref_img)
|
||||
|
||||
img = torch.cat(imgs, dim=0)
|
||||
flat_ref_img_hidden_states.append(img)
|
||||
else:
|
||||
flat_ref_img_hidden_states.append(None)
|
||||
|
||||
# image patch embeddings
|
||||
flat_hidden_states = []
|
||||
for i in range(batch_size):
|
||||
img = hidden_states[i]
|
||||
C, H, W = img.size()
|
||||
|
||||
img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
|
||||
flat_hidden_states.append(img)
|
||||
|
||||
padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
|
||||
padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
|
||||
for i in range(batch_size):
|
||||
if ref_img_sizes[i] is not None:
|
||||
padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
|
||||
padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
|
||||
|
||||
padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
|
||||
padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
|
||||
for i in range(batch_size):
|
||||
padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
|
||||
padded_img_mask[i, :l_effective_img_len[i]] = True
|
||||
|
||||
return (
|
||||
padded_hidden_states,
|
||||
padded_ref_img_hidden_states,
|
||||
padded_img_mask,
|
||||
padded_ref_img_mask,
|
||||
l_effective_ref_img_len,
|
||||
l_effective_img_len,
|
||||
ref_img_sizes,
|
||||
img_sizes,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, List[torch.Tensor]],
|
||||
timestep: torch.Tensor,
|
||||
text_hidden_states: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
text_attention_mask: torch.Tensor,
|
||||
ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = False,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# 1. Condition, positional & patch embedding
|
||||
batch_size = len(hidden_states)
|
||||
is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
|
||||
|
||||
if is_hidden_states_tensor:
|
||||
assert hidden_states.ndim == 4
|
||||
hidden_states = [_hidden_states for _hidden_states in hidden_states]
|
||||
|
||||
device = hidden_states[0].device
|
||||
|
||||
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
|
||||
|
||||
(
|
||||
hidden_states,
|
||||
ref_image_hidden_states,
|
||||
img_mask,
|
||||
ref_img_mask,
|
||||
l_effective_ref_img_len,
|
||||
l_effective_img_len,
|
||||
ref_img_sizes,
|
||||
img_sizes,
|
||||
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
|
||||
|
||||
(
|
||||
context_rotary_emb,
|
||||
ref_img_rotary_emb,
|
||||
noise_rotary_emb,
|
||||
rotary_emb,
|
||||
encoder_seq_lengths,
|
||||
seq_lengths,
|
||||
) = self.rope_embedder(
|
||||
freqs_cis,
|
||||
text_attention_mask,
|
||||
l_effective_ref_img_len,
|
||||
l_effective_img_len,
|
||||
ref_img_sizes,
|
||||
img_sizes,
|
||||
device,
|
||||
)
|
||||
|
||||
# 2. Context refinement
|
||||
for layer in self.context_refiner:
|
||||
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
|
||||
|
||||
combined_img_hidden_states = self.img_patch_embed_and_refine(
|
||||
hidden_states,
|
||||
ref_image_hidden_states,
|
||||
img_mask,
|
||||
ref_img_mask,
|
||||
noise_rotary_emb,
|
||||
ref_img_rotary_emb,
|
||||
l_effective_ref_img_len,
|
||||
l_effective_img_len,
|
||||
temb,
|
||||
)
|
||||
|
||||
# 3. Joint Transformer blocks
|
||||
max_seq_len = max(seq_lengths)
|
||||
|
||||
attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
|
||||
joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
|
||||
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
|
||||
attention_mask[i, :seq_len] = True
|
||||
joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
|
||||
joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
|
||||
|
||||
hidden_states = joint_hidden_states
|
||||
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
layer, hidden_states, attention_mask, rotary_emb, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
|
||||
|
||||
# 4. Output norm & projection
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
p = self.config.patch_size
|
||||
output = []
|
||||
for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
|
||||
height, width = img_size
|
||||
output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
|
||||
if is_hidden_states_tensor:
|
||||
output = torch.stack(output, dim=0)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return output
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
|
@ -0,0 +1,718 @@
|
|||
"""
|
||||
OmniGen2 Diffusion Pipeline
|
||||
|
||||
Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
import inspect
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import PIL.Image
|
||||
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration
|
||||
from diffusers.utils import BaseOutput
|
||||
from diffusers.models.autoencoders import AutoencoderKL
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
|
||||
from .models.transformers import OmniGen2Transformer2DModel
|
||||
from .models.transformers.repo import OmniGen2RotaryPosEmbed
|
||||
from .image_processor import OmniGen2ImageProcessor
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@dataclass
|
||||
class FMPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for OmniGen2 pipeline.
|
||||
|
||||
Args:
|
||||
images (Union[List[PIL.Image.Image], np.ndarray]):
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape
|
||||
`(batch_size, height, width, num_channels)`. Contains the generated images.
|
||||
"""
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class OmniGen2Pipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for text-to-image generation using OmniGen2.
|
||||
|
||||
This pipeline implements a text-to-image generation model that uses:
|
||||
- Qwen2.5-VL for text encoding
|
||||
- A custom transformer architecture for image generation
|
||||
- VAE for image encoding/decoding
|
||||
- FlowMatchEulerDiscreteScheduler for noise scheduling
|
||||
|
||||
Args:
|
||||
transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
|
||||
vae (AutoencoderKL): The VAE model for image encoding/decoding.
|
||||
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
|
||||
text_encoder (Qwen2_5_VLModel): The text encoder model.
|
||||
tokenizer (Union[Qwen2Tokenizer, Qwen2TokenizerFast]): The tokenizer for text processing.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "mllm->transformer->vae"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: OmniGen2Transformer2DModel,
|
||||
vae: AutoencoderKL,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
mllm: Qwen2_5_VLForConditionalGeneration,
|
||||
processor,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the OmniGen2 pipeline.
|
||||
|
||||
Args:
|
||||
transformer: The transformer model for image generation.
|
||||
vae: The VAE model for image encoding/decoding.
|
||||
scheduler: The scheduler for noise scheduling.
|
||||
text_encoder: The text encoder model.
|
||||
tokenizer: The tokenizer for text processing.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
mllm=mllm,
|
||||
processor=processor
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
|
||||
self.default_sample_size = 128
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: Optional[torch.Generator],
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Prepare the initial latents for the diffusion process.
|
||||
|
||||
Args:
|
||||
batch_size: The number of images to generate.
|
||||
num_channels_latents: The number of channels in the latent space.
|
||||
height: The height of the generated image.
|
||||
width: The width of the generated image.
|
||||
dtype: The data type of the latents.
|
||||
device: The device to place the latents on.
|
||||
generator: The random number generator to use.
|
||||
latents: Optional pre-computed latents to use instead of random initialization.
|
||||
|
||||
Returns:
|
||||
torch.FloatTensor: The prepared latents tensor.
|
||||
"""
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
return latents
|
||||
|
||||
def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
Encode an image into the VAE latent space.
|
||||
|
||||
Args:
|
||||
img: The input image tensor to encode.
|
||||
|
||||
Returns:
|
||||
torch.FloatTensor: The encoded latent representation.
|
||||
"""
|
||||
z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
|
||||
if self.vae.config.shift_factor is not None:
|
||||
z0 = z0 - self.vae.config.shift_factor
|
||||
if self.vae.config.scaling_factor is not None:
|
||||
z0 = z0 * self.vae.config.scaling_factor
|
||||
z0 = z0.to(dtype=self.vae.dtype)
|
||||
return z0
|
||||
|
||||
def prepare_image(
|
||||
self,
|
||||
images: Union[List[PIL.Image.Image], PIL.Image.Image],
|
||||
batch_size: int,
|
||||
num_images_per_prompt: int,
|
||||
max_pixels: int,
|
||||
max_side_length: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> List[Optional[torch.FloatTensor]]:
|
||||
"""
|
||||
Prepare input images for processing by encoding them into the VAE latent space.
|
||||
|
||||
Args:
|
||||
images: Single image or list of images to process.
|
||||
batch_size: The number of images to generate per prompt.
|
||||
num_images_per_prompt: The number of images to generate for each prompt.
|
||||
device: The device to place the encoded latents on.
|
||||
dtype: The data type of the encoded latents.
|
||||
|
||||
Returns:
|
||||
List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
|
||||
"""
|
||||
if batch_size == 1:
|
||||
images = [images]
|
||||
latents = []
|
||||
for i, img in enumerate(images):
|
||||
if img is not None and len(img) > 0:
|
||||
ref_latents = []
|
||||
for j, img_j in enumerate(img):
|
||||
img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
|
||||
ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
|
||||
else:
|
||||
ref_latents = None
|
||||
for _ in range(num_images_per_prompt):
|
||||
latents.append(ref_latents)
|
||||
|
||||
return latents
|
||||
|
||||
def _get_qwen2_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
max_sequence_length: int = 256,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get prompt embeddings from the Qwen2 text encoder.
|
||||
|
||||
Args:
|
||||
prompt: The prompt or list of prompts to encode.
|
||||
device: The device to place the embeddings on. If None, uses the pipeline's device.
|
||||
max_sequence_length: Maximum sequence length for tokenization.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- The prompt embeddings tensor
|
||||
- The attention mask tensor
|
||||
|
||||
Raises:
|
||||
Warning: If the input text is truncated due to sequence length limitations.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
# text_inputs = self.processor.tokenizer(
|
||||
# prompt,
|
||||
# padding="max_length",
|
||||
# max_length=max_sequence_length,
|
||||
# truncation=True,
|
||||
# return_tensors="pt",
|
||||
# )
|
||||
text_inputs = self.processor.tokenizer(
|
||||
prompt,
|
||||
padding="longest",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because Gemma can only handle sequences up to"
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
||||
prompt_embeds = self.mllm(
|
||||
text_input_ids,
|
||||
attention_mask=prompt_attention_mask,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-1]
|
||||
|
||||
if self.mllm is not None:
|
||||
dtype = self.mllm.dtype
|
||||
elif self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
def _apply_chat_template(self, prompt: str):
|
||||
prompt = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that generates high-quality images based on user instructions.",
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
prompt = self.processor.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
|
||||
return prompt
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 256,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
|
||||
Lumina-T2I, this should be "".
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string.
|
||||
max_sequence_length (`int`, defaults to `256`):
|
||||
Maximum sequence length to use for the prompt.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [self._apply_chat_template(_prompt) for _prompt in prompt]
|
||||
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
# Get negative embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else ""
|
||||
|
||||
# Normalize str to list
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = negative_prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.view(
|
||||
batch_size * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def text_guidance_scale(self):
|
||||
return self._text_guidance_scale
|
||||
|
||||
@property
|
||||
def image_guidance_scale(self):
|
||||
return self._image_guidance_scale
|
||||
|
||||
@property
|
||||
def cfg_range(self):
|
||||
return self._cfg_range
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_attention_mask: Optional[torch.LongTensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
|
||||
max_sequence_length: Optional[int] = None,
|
||||
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
||||
input_images: Optional[List[PIL.Image.Image]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
max_pixels: int = 2048 * 2048,
|
||||
max_input_image_side_length: int = 2048,
|
||||
align_res: bool = True,
|
||||
num_inference_steps: int = 28,
|
||||
text_guidance_scale: float = 4.0,
|
||||
image_guidance_scale: float = 1.0,
|
||||
cfg_range: Tuple[float, float] = (0.0, 1.0),
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
timesteps: List[int] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
verbose: bool = False,
|
||||
step_func=None,
|
||||
):
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
self._text_guidance_scale = text_guidance_scale
|
||||
self._image_guidance_scale = image_guidance_scale
|
||||
self._cfg_range = cfg_range
|
||||
self._attention_kwargs = attention_kwargs
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
self.text_guidance_scale > 1.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
dtype = self.vae.dtype
|
||||
# 3. Prepare control image
|
||||
ref_latents = self.prepare_image(
|
||||
images=input_images,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_pixels=max_pixels,
|
||||
max_side_length=max_input_image_side_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if input_images is None:
|
||||
input_images = []
|
||||
|
||||
if len(input_images) == 1 and align_res:
|
||||
width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
|
||||
ori_width, ori_height = width, height
|
||||
else:
|
||||
ori_width, ori_height = width, height
|
||||
|
||||
cur_pixels = height * width
|
||||
ratio = (max_pixels / cur_pixels) ** 0.5
|
||||
ratio = min(ratio, 1.0)
|
||||
|
||||
height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
|
||||
|
||||
if len(input_images) == 0:
|
||||
self._image_guidance_scale = 1
|
||||
|
||||
# 4. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
latent_channels,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
|
||||
self.transformer.config.axes_dim_rope,
|
||||
self.transformer.config.axes_lens,
|
||||
theta=10000,
|
||||
)
|
||||
|
||||
image = self.processing(
|
||||
latents=latents,
|
||||
ref_latents=ref_latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
num_inference_steps=num_inference_steps,
|
||||
timesteps=timesteps,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
verbose=verbose,
|
||||
step_func=step_func,
|
||||
)
|
||||
|
||||
image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return image
|
||||
else:
|
||||
return FMPipelineOutput(images=image)
|
||||
|
||||
def processing(
|
||||
self,
|
||||
latents,
|
||||
ref_latents,
|
||||
prompt_embeds,
|
||||
freqs_cis,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
num_inference_steps,
|
||||
timesteps,
|
||||
device,
|
||||
dtype,
|
||||
verbose,
|
||||
step_func=None
|
||||
):
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
num_tokens=latents.shape[-2] * latents.shape[-1]
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
model_pred = self.predict(
|
||||
t=t,
|
||||
latents=latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
ref_image_hidden_states=ref_latents,
|
||||
)
|
||||
text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
|
||||
image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
|
||||
|
||||
if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
|
||||
model_pred_ref = self.predict(
|
||||
t=t,
|
||||
latents=latents,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
prompt_attention_mask=negative_prompt_attention_mask,
|
||||
ref_image_hidden_states=ref_latents,
|
||||
)
|
||||
|
||||
if image_guidance_scale != 1:
|
||||
model_pred_uncond = self.predict(
|
||||
t=t,
|
||||
latents=latents,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
prompt_attention_mask=negative_prompt_attention_mask,
|
||||
ref_image_hidden_states=None,
|
||||
)
|
||||
else:
|
||||
model_pred_uncond = torch.zeros_like(model_pred)
|
||||
|
||||
model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
|
||||
text_guidance_scale * (model_pred - model_pred_ref)
|
||||
elif text_guidance_scale > 1.0:
|
||||
model_pred_uncond = self.predict(
|
||||
t=t,
|
||||
latents=latents,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
prompt_attention_mask=negative_prompt_attention_mask,
|
||||
ref_image_hidden_states=None,
|
||||
)
|
||||
model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
latents = latents.to(dtype=dtype)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if step_func is not None:
|
||||
step_func(i, self._num_timesteps)
|
||||
|
||||
latents = latents.to(dtype=dtype)
|
||||
if self.vae.config.scaling_factor is not None:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
if self.vae.config.shift_factor is not None:
|
||||
latents = latents + self.vae.config.shift_factor
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
return image
|
||||
|
||||
def predict(
|
||||
self,
|
||||
t,
|
||||
latents,
|
||||
prompt_embeds,
|
||||
freqs_cis,
|
||||
prompt_attention_mask,
|
||||
ref_image_hidden_states,
|
||||
):
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
|
||||
optional_kwargs = {}
|
||||
if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
|
||||
optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
|
||||
|
||||
model_pred = self.transformer(
|
||||
latents,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
freqs_cis,
|
||||
prompt_attention_mask,
|
||||
**optional_kwargs
|
||||
)
|
||||
return model_pred
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
import torch
|
||||
|
||||
|
||||
def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
|
||||
""" Get pipeline embeds for prompts bigger than the maxlength of the pipe
|
||||
:param pipeline:
|
||||
:param prompt:
|
||||
:param negative_prompt:
|
||||
:param device:
|
||||
:return:
|
||||
"""
|
||||
max_length = pipeline.tokenizer.model_max_length
|
||||
|
||||
# simple way to determine length of tokens
|
||||
# count_prompt = len(prompt.split(" "))
|
||||
# count_negative_prompt = len(negative_prompt.split(" "))
|
||||
|
||||
# create the tensor based on which prompt is longer
|
||||
# if count_prompt >= count_negative_prompt:
|
||||
input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding='longest').input_ids.to(device)
|
||||
# input_ids = pipeline.tokenizer(prompt, padding="max_length",
|
||||
# max_length=pipeline.tokenizer.model_max_length,
|
||||
# truncation=True,
|
||||
# return_tensors="pt",).input_ids.to(device)
|
||||
shape_max_length = input_ids.shape[-1]
|
||||
|
||||
if negative_prompt is not None:
|
||||
negative_ids = pipeline.tokenizer(negative_prompt, truncation=True, padding="max_length",
|
||||
max_length=shape_max_length, return_tensors="pt").input_ids.to(device)
|
||||
|
||||
# else:
|
||||
# negative_ids = pipeline.tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)
|
||||
# shape_max_length = negative_ids.shape[-1]
|
||||
# input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
|
||||
# max_length=shape_max_length).input_ids.to(device)
|
||||
|
||||
concat_embeds = []
|
||||
neg_embeds = []
|
||||
for i in range(0, shape_max_length, max_length):
|
||||
if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
|
||||
attention_mask = input_ids[:, i: i + max_length].attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length],
|
||||
attention_mask=attention_mask)[0])
|
||||
|
||||
if negative_prompt is not None:
|
||||
if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
|
||||
attention_mask = negative_ids[:, i: i + max_length].attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length],
|
||||
attention_mask=attention_mask)[0])
|
||||
|
||||
concat_embeds = torch.cat(concat_embeds, dim=1)
|
||||
|
||||
if negative_prompt is not None:
|
||||
neg_embeds = torch.cat(neg_embeds, dim=1)
|
||||
else:
|
||||
neg_embeds = None
|
||||
|
||||
return concat_embeds, neg_embeds
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -61,7 +61,7 @@ def task_specific_kwargs(p, model):
|
|||
p.width, p.height = p.width // vae_scale_factor * vae_scale_factor, p.height // vae_scale_factor * vae_scale_factor
|
||||
task_args['max_area'] = max_area
|
||||
task_args['width'], task_args['height'] = p.width, p.height
|
||||
if model.__class__.__name__ == 'OmniGenPipeline':
|
||||
elif model.__class__.__name__ == 'OmniGenPipeline' or model.__class__.__name__ == 'OmniGen2Pipeline':
|
||||
p.width, p.height = 16 * math.ceil(p.init_images[0].width / 16), 16 * math.ceil(p.init_images[0].height / 16)
|
||||
task_args = {
|
||||
'width': p.width,
|
||||
|
|
@ -285,7 +285,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
|
|||
kwargs['output_type'] = 'np' # only set latent if model has vae
|
||||
|
||||
# model specific
|
||||
if 'Kandinsky' in model.__class__.__name__:
|
||||
if 'Kandinsky' in model.__class__.__name__ or 'Cosmos2' in model.__class__.__name__ or 'OmniGen2' in model.__class__.__name__:
|
||||
kwargs['output_type'] = 'np' # only set latent if model has vae
|
||||
if 'StableCascade' in model.__class__.__name__:
|
||||
kwargs.pop("guidance_scale") # remove
|
||||
|
|
@ -305,8 +305,6 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
|
|||
args['control_strength'] = p.denoising_strength
|
||||
args['width'] = p.width
|
||||
args['height'] = p.height
|
||||
if 'Cosmos2' in model.__class__.__name__:
|
||||
kwargs['output_type'] = 'np' # cosmos uses wan-vae which is weird
|
||||
# set callbacks
|
||||
if 'prior_callback_steps' in possible: # Wuerstchen / Cascade
|
||||
args['prior_callback_steps'] = 1
|
||||
|
|
|
|||
|
|
@ -88,6 +88,9 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False):
|
|||
if 'omnigen' in f.lower():
|
||||
guess = 'OmniGen'
|
||||
pipeline = 'custom'
|
||||
if 'omnigen2' in f.lower():
|
||||
guess = 'OmniGen2'
|
||||
pipeline = 'custom'
|
||||
if 'sd3' in f.lower():
|
||||
guess = 'Stable Diffusion 3'
|
||||
if 'hidream' in f.lower():
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ pipe_switch_task_exclude = [
|
|||
'InstantIRPipeline',
|
||||
'LTXConditionPipeline',
|
||||
'OmniGenPipeline',
|
||||
'OmniGen2Pipeline',
|
||||
'PhotoMakerStableDiffusionXLPipeline',
|
||||
'PixelSmithXLPipeline',
|
||||
'StableDiffusion3ControlNetPipeline',
|
||||
|
|
@ -364,6 +365,10 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
|
|||
from modules.model_meissonic import load_meissonic
|
||||
sd_model = load_meissonic(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = True
|
||||
elif model_type in ['OmniGen2']: # forced pipeline
|
||||
from modules.model_omnigen2 import load_omnigen2
|
||||
sd_model = load_omnigen2(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['OmniGen']: # forced pipeline
|
||||
from modules.model_omnigen import load_omnigen
|
||||
sd_model = load_omnigen(checkpoint_info, diffusers_load_config)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from modules.timer import process as process_timer
|
|||
|
||||
debug = os.environ.get('SD_MOVE_DEBUG', None) is not None
|
||||
debug_move = log.trace if debug else lambda *args, **kwargs: None
|
||||
offload_warn = ['sc', 'sd3', 'f1', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'cogview4', 'cosmos', 'chroma']
|
||||
offload_warn = ['sc', 'sd3', 'f1', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'omnigen2', 'cogview4', 'cosmos', 'chroma']
|
||||
offload_post = ['h1']
|
||||
offload_hook_instance = None
|
||||
balanced_offload_exclude = ['CogView4Pipeline']
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from modules import shared, devices, processing, images, sd_vae_approx, sd_vae_t
|
|||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
approximation_indexes = { "Simple": 0, "Approximate": 1, "TAESD": 2, "Full VAE": 3 }
|
||||
flow_models = ['f1', 'sd3', 'lumina', 'auraflow', 'sana', 'lumina2', 'cogview4', 'h1', 'cosmos', 'chroma']
|
||||
flow_models = ['f1', 'sd3', 'lumina', 'auraflow', 'sana', 'lumina2', 'cogview4', 'h1', 'cosmos', 'chroma', 'omnigen2']
|
||||
warned = False
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
|
|
|
|||
|
|
@ -522,11 +522,11 @@ options_templates.update(options_section(("quantization", "Quantization Settings
|
|||
"sdnq_quantize_weights_mode": OptionInfo("int8", "Quantization type", gr.Dropdown, {"choices": sdnq_quant_modes, "visible": native}),
|
||||
"sdnq_quantize_weights_mode_te": OptionInfo("default", "Quantization type for Text Encoders", gr.Dropdown, {"choices": ['default'] + sdnq_quant_modes, "visible": native}),
|
||||
"sdnq_quantize_weights_group_size": OptionInfo(0, "Group size", gr.Slider, {"minimum": -1, "maximum": 4096, "step": 1, "visible": native}),
|
||||
"sdnq_quantize_conv_layers": OptionInfo(False, "Quantize the convolutional layers", gr.Checkbox, {"visible": native}),
|
||||
"sdnq_quantize_conv_layers": OptionInfo(False, "Quantize convolutional layers", gr.Checkbox, {"visible": native}),
|
||||
"sdnq_dequantize_compile": OptionInfo(devices.has_triton(), "Dequantize using torch.compile", gr.Checkbox, {"visible": native}),
|
||||
"sdnq_use_quantized_matmul": OptionInfo(False, "Use Quantized MatMul", gr.Checkbox, {"visible": native}),
|
||||
"sdnq_use_quantized_matmul_conv": OptionInfo(False, "Use Quantized MatMul with convolutional layers", gr.Checkbox, {"visible": native}),
|
||||
"sdnq_quantize_with_gpu": OptionInfo(True, "Quantize with the GPU", gr.Checkbox, {"visible": native}),
|
||||
"sdnq_use_quantized_matmul": OptionInfo(False, "Use quantized MatMul", gr.Checkbox, {"visible": native}),
|
||||
"sdnq_use_quantized_matmul_conv": OptionInfo(False, "Use quantized MatMul with convolutional layers", gr.Checkbox, {"visible": native}),
|
||||
"sdnq_quantize_with_gpu": OptionInfo(True, "Quantize using GPU", gr.Checkbox, {"visible": native}),
|
||||
"sdnq_dequantize_fp32": OptionInfo(False, "Dequantize using full precision", gr.Checkbox, {"visible": native}),
|
||||
"sdnq_quantize_shuffle_weights": OptionInfo(False, "Shuffle weights in post mode", gr.Checkbox, {"visible": native}),
|
||||
|
||||
|
|
|
|||
2
wiki
2
wiki
|
|
@ -1 +1 @@
|
|||
Subproject commit c1ae36ab4c2197487306c5c9fe4e328a038d1367
|
||||
Subproject commit 711b61ebfdb8cb09b06b4f8c627ae97519c6f74c
|
||||
Loading…
Reference in New Issue