mirror of https://github.com/vladmandic/automatic — Python source, 120 lines, 5.4 KiB
from typing import Any, Callable, Dict, List, Optional, Union
|
|
import torch
|
|
import diffusers
|
|
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
|
from diffusers.image_processor import PipelineImageInput
|
|
|
|
from modules import devices
|
|
|
|
|
|
class WanImagePipeline(diffusers.WanPipeline):
    """Img2img wrapper around `diffusers.WanPipeline`.

    Adds `image` and `strength` arguments: the init image is VAE-encoded,
    noised to the timestep selected by `strength`, and denoising then starts
    from that partially-noised latent instead of from pure noise.
    """

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        guidance_scale_2: Optional[float] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],  # mutable default kept to mirror the upstream diffusers signature; treated as read-only here
        max_sequence_length: int = 512,
        strength: float = 0.3,  # new: img2img denoise strength; higher = more noise, less of the init image preserved
        image: PipelineImageInput = None,  # new: init image encoded into the starting latents
    ):
        """Run the Wan pipeline in img2img mode.

        All parameters except `strength` and `image` are forwarded verbatim to
        `diffusers.WanPipeline.__call__`; see the upstream documentation for
        their semantics. Returns whatever the parent pipeline returns.
        """
        # get img2img timesteps: build the full schedule, then truncate it by strength
        self.scheduler.set_timesteps(num_inference_steps, device=devices.device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)

        # monkey patch the scheduler so the parent pipeline cannot rebuild the
        # full schedule: pin the truncated timesteps and no-op set_timesteps
        self.scheduler.timesteps = timesteps
        orig_set_timesteps = self.scheduler.set_timesteps
        self.scheduler.orig_set_timesteps = orig_set_timesteps  # kept for external introspection, as before
        self.scheduler.set_timesteps = lambda *args, **kwargs: None

        # prepare starting latents from the init image
        latents = self.img2img_prepare_latents(
            image=image,
            timesteps=timesteps,
            dtype=devices.dtype,
            device=devices.device,
            generator=generator,
        )

        # call original pipeline; restore the scheduler in a finally block so a
        # raised exception cannot leave set_timesteps permanently no-op'ed,
        # which would silently break every subsequent pipeline run
        try:
            result = super().__call__(  # pylint: disable=no-member
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                guidance_scale_2=guidance_scale_2,
                num_videos_per_prompt=num_videos_per_prompt,
                generator=generator,
                latents=latents,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                output_type=output_type,
                return_dict=return_dict,
                attention_kwargs=attention_kwargs,
                callback_on_step_end=callback_on_step_end,
                callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
                max_sequence_length=max_sequence_length,
            )
        finally:
            # un-monkey patch original pipeline
            self.scheduler.set_timesteps = orig_set_timesteps
        return result

    def get_timesteps(self, num_inference_steps, strength):
        """Return the tail of the scheduler's timestep schedule used for img2img.

        `strength` selects how much of the schedule to keep: 1.0 keeps the whole
        schedule (denoise from pure noise), smaller values skip the earliest
        (noisiest) steps so more of the init image survives. Returns the
        truncated timesteps tensor and the remaining step count.
        """
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            # begin index stays 0 because the caller re-assigns the truncated
            # schedule as the scheduler's full schedule before stepping
            self.scheduler.set_begin_index(0)
        return timesteps, num_inference_steps - t_start

    def img2img_prepare_latents(
        self,
        image: torch.Tensor = None,
        timesteps: torch.Tensor = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    ) -> torch.Tensor:
        """Encode `image` with the VAE and noise it to the first retained timestep.

        Returns a latent tensor shaped [1, z_dim, 1, H', W'] suitable to pass as
        `latents` to the parent pipeline.
        """
        from diffusers.utils.torch_utils import randn_tensor
        from diffusers.video_processor import VideoProcessor

        if isinstance(image, list):
            image = image[0]  # ignore batch for now

        video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
        image_tensor = video_processor.preprocess(image, None, None)  # convert PIL to [B, C, H, W] # channels may need rearrange
        image_tensor = image_tensor.squeeze(0).to(device=device, dtype=dtype)
        image_tensor = image_tensor[None, :, None, :, :]  # expand before encode to [B, C, N, H, W]
        encoder_output = self.vae.encode(image_tensor)
        # stochastic sample from the posterior; .mode() would give the deterministic argmax instead
        init_latents = encoder_output.latent_dist.sample(generator)

        # undo the VAE's per-channel latent shift/scale so the latents match the
        # normalized distribution the denoiser expects
        latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=torch.float32).view(1, self.vae.config.z_dim, 1, 1, 1)
        inv_latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=torch.float32).view(1, self.vae.config.z_dim, 1, 1, 1)
        init_latents = ((init_latents.float() - latents_mean) * inv_latents_std).to(dtype)

        # noise the clean latents to the first (noisiest) retained timestep
        init_noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
        init_timestep = timesteps[:1]
        noised_latents = self.scheduler.add_noise(init_latents, init_noise, init_timestep)

        return noised_latents
|