# automatic/pipelines/wan/wan_image.py


from typing import Any, Callable, Dict, List, Optional, Union
import torch
import diffusers
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput
from modules import devices
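

# Img2img wrapper for the Wan text-to-video pipeline: diffusers.WanPipeline has no
# `image`/`strength` arguments, so this subclass truncates the timestep schedule,
# encodes the init image into noised latents, and temporarily no-ops
# scheduler.set_timesteps so the parent __call__ cannot rebuild the full schedule.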
class WanImagePipeline(diffusers.WanPipeline):
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        guidance_scale_2: Optional[float] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"], # pylint: disable=dangerous-default-value
        max_sequence_length: int = 512,
        strength: float = 0.3, # new: img2img denoise strength
        image: PipelineImageInput = None, # new: img2img init image
    ):
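        """Img2img variant of diffusers.WanPipeline.__call__.

        Arguments added on top of the parent signature:
            strength: fraction of the schedule applied as denoising; lower values
                keep the result closer to `image` (0.3 reuses the last 30% of steps).
            image: init image encoded by the VAE into the starting latents.
        All other arguments are forwarded unchanged to the parent pipeline.
        """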
        # get img2img timesteps: build the full schedule, then keep the tail per strength
        self.scheduler.set_timesteps(num_inference_steps, device=devices.device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
        # monkey patch original pipeline: pin the truncated schedule and no-op set_timesteps
        # so the parent __call__ cannot rebuild the full schedule
        self.scheduler.timesteps = timesteps
        # self.scheduler._step_index = 0
        self.scheduler.orig_set_timesteps = self.scheduler.set_timesteps
        self.scheduler.set_timesteps = lambda *args, **kwargs: None
        # prepare latents: encode the init image and noise it to the first kept timestep
        latents = self.img2img_prepare_latents(
            image=image,
            timesteps=timesteps,
            dtype=devices.dtype,
            device=devices.device,
            generator=generator,
        )
        # call original pipeline with the pre-noised latents and reduced step count
        try:
            result = super().__call__( # pylint: disable=no-member
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                guidance_scale_2=guidance_scale_2,
                num_videos_per_prompt=num_videos_per_prompt,
                generator=generator,
                latents=latents,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                output_type=output_type,
                return_dict=return_dict,
                attention_kwargs=attention_kwargs,
                callback_on_step_end=callback_on_step_end,
                callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
                max_sequence_length=max_sequence_length,
            )
        finally:
            # un-monkey patch original pipeline even if the call raised
            self.scheduler.set_timesteps = self.scheduler.orig_set_timesteps
        return result

    def get_timesteps(self, num_inference_steps, strength):
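        """Truncate the schedule to its last `strength` fraction, standard img2img style.

        Example: num_inference_steps=50, strength=0.3 gives init_timestep=15 and
        t_start=35, so the last 15 timesteps are kept and 15 steps are returned.
        """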
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            # self.scheduler.set_begin_index(t_start * self.scheduler.order)
            self.scheduler.set_begin_index(0) # scheduler.timesteps is already truncated, so begin at 0
        return timesteps, num_inference_steps - t_start

    def img2img_prepare_latents(
        self,
        image: PipelineImageInput = None,
        timesteps: torch.Tensor = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    ) -> torch.Tensor:
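        """Encode the init image with the Wan VAE and noise it to the first kept timestep.

        Returns a noised [B, C, N, H', W'] video latent (N=1, a single frame) that the
        parent pipeline consumes via its `latents` argument.
        """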
        from diffusers.utils.torch_utils import randn_tensor
        from diffusers.video_processor import VideoProcessor
        if isinstance(image, list):
            image = image[0] # ignore batch for now
        video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
        image_tensor = video_processor.preprocess(image, None, None) # convert PIL to [B, C, H, W] # channels may need rearrange
        image_tensor = image_tensor.squeeze(0).to(device=device, dtype=dtype)
        image_tensor = image_tensor[None, :, None, :, :] # expand before encode to [B, C, N, H, W]
        encoder_output = self.vae.encode(image_tensor)
        # init_latents = encoder_output.latent_dist.mode() # argmax or sample?
        init_latents = encoder_output.latent_dist.sample(generator)
        latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=torch.float32).view(1, self.vae.config.z_dim, 1, 1, 1)
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=torch.float32).view(1, self.vae.config.z_dim, 1, 1, 1)
        init_latents = ((init_latents.float() - latents_mean) * latents_std).to(dtype) # normalize to the standard latent range; latents_std holds the reciprocal std
        init_noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
        init_timestep = timesteps[:1] # first kept timestep sets the noise level
        noised_latents = self.scheduler.add_noise(init_latents, init_noise, init_timestep)
        return noised_latents
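

# Minimal usage sketch (an assumption, not part of the original module): the checkpoint
# id is illustrative, from_pretrained/to are inherited from diffusers.WanPipeline, and
# the SD.Next `modules.devices` settings are assumed to be initialized.
if __name__ == "__main__":
    from PIL import Image
    from diffusers.utils import export_to_video

    pipe = WanImagePipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=devices.dtype) # hypothetical checkpoint choice
    pipe.to(devices.device)
    init_image = Image.open("init.png").convert("RGB").resize((832, 480)) # hypothetical input file
    result = pipe(
        prompt="a cat surfing a wave at sunset",
        image=init_image, # new img2img argument
        strength=0.3, # lower values stay closer to init_image
        num_frames=81,
        num_inference_steps=50,
    )
    export_to_video(result.frames[0], "output.mp4", fps=16)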