mirror of https://github.com/vladmandic/automatic
91 lines
4.9 KiB
Python
91 lines
4.9 KiB
Python
import cv2
|
|
import torch
|
|
import numpy as np
|
|
from PIL import Image
|
|
from modules import devices, masking
|
|
from modules.shared import opts
|
|
|
|
|
|
class LotusDetector:
    """Depth estimator assembled from a UNet, a VAE, a CLIP text encoder, a tokenizer and a scheduler.

    Calling an instance with an image runs one UNet forward pass conditioned on
    the VAE-encoded image and returns a min-max normalized 8-bit depth map at
    the input resolution, optionally colorized.
    """

    def __init__(self, unet, vae, text_encoder, tokenizer, scheduler):
        """Store the model components.

        Args:
            unet: Model called with the concatenated latents, timestep, prompt embeddings and class labels.
            vae: Autoencoder used to encode the input image and decode the prediction.
            text_encoder: Encoder producing the (empty) prompt embedding.
            tokenizer: Tokenizer paired with ``text_encoder``.
            scheduler: Scheduler that supplies the single timestep used at inference.
        """
        self.unet = unet
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.scheduler = scheduler
        # Derived from the number of entries in the VAE's block_out_channels.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path="jingheya/lotus-depth-g-v2-1-disparity", cache_dir=None, local_files_only=False, **kwargs):
        """Load every component from its subfolder of one checkpoint and build a detector.

        Args:
            pretrained_model_or_path: Hub id or local path passed to each component loader.
            cache_dir: Forwarded to each ``from_pretrained`` call.
            local_files_only: Forwarded to each ``from_pretrained`` call.
            **kwargs: Accepted for call-site compatibility; not used.

        Returns:
            A new instance of ``cls``.
        """
        # Imported lazily so the heavy libraries are only loaded when a model is requested.
        from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
        from transformers import CLIPTextModel, CLIPTokenizer

        load_kwargs = {"cache_dir": cache_dir, "local_files_only": local_files_only}
        unet = UNet2DConditionModel.from_pretrained(pretrained_model_or_path, subfolder="unet", **load_kwargs)
        vae = AutoencoderKL.from_pretrained(pretrained_model_or_path, subfolder="vae", **load_kwargs)
        text_encoder = CLIPTextModel.from_pretrained(pretrained_model_or_path, subfolder="text_encoder", **load_kwargs)
        tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_or_path, subfolder="tokenizer", **load_kwargs)
        scheduler = DDPMScheduler.from_pretrained(pretrained_model_or_path, subfolder="scheduler", **load_kwargs)
        return cls(unet, vae, text_encoder, tokenizer, scheduler)

    def _to(self, device):
        """Move the UNet, VAE and text encoder to ``device``."""
        self.unet.to(device)
        self.vae.to(device)
        self.text_encoder.to(device)

    @staticmethod
    def _processing_size(width, height, processing_res=768, multiple=8):
        """Return ``(new_w, new_h)`` scaled so the longer side is ``processing_res``.

        Both sides are rounded to a multiple of ``multiple`` and never drop below
        one multiple, so a very elongated image cannot produce a zero-sized side.
        """
        scale = processing_res / max(width, height)
        new_w = max(multiple, int(round(width * scale / multiple) * multiple))
        new_h = max(multiple, int(round(height * scale / multiple) * multiple))
        return new_w, new_h

    @staticmethod
    def _normalize_depth(depth):
        """Average channels of a float array (if 3-D), min-max normalize it, and return uint8 in 0..255."""
        if depth.ndim == 3:
            depth = depth.mean(axis=2)
        depth = depth - depth.min()
        depth_max = depth.max()
        # A constant map has zero range; leave it at zero instead of dividing by zero.
        if depth_max > 0:
            depth = depth / depth_max
        return (depth * 255.0).clip(0, 255).astype(np.uint8)

    def __call__(self, image, color_map="none", output_type="pil", **kwargs):
        """Estimate a depth map for ``image``.

        Args:
            image: PIL image, or a numpy array that ``Image.fromarray`` accepts.
            color_map: ``"none"`` for a single-channel result; otherwise a name looked up
                in ``masking.COLORMAP`` (names not found there fall back to ``"inferno"``).
            output_type: ``"pil"`` returns a PIL image; any other value returns a numpy array.
            **kwargs: Accepted for call-site compatibility; not used.

        Returns:
            The depth map at the original image size, as a PIL image or a uint8 numpy array.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # The tensor conversion below permutes an H x W x C array, so single-channel
        # input would fail there; the VAE is presumably configured for three input
        # channels, so alpha/palette images are converted as well.
        if image.mode != "RGB":
            image = image.convert("RGB")
        orig_w, orig_h = image.size

        # Resize to processing resolution (768) while maintaining aspect ratio
        new_w, new_h = self._processing_size(orig_w, orig_h, processing_res=768)
        image_resized = image.resize((new_w, new_h), Image.Resampling.BILINEAR)

        self._to(devices.device)
        try:
            # Convert to tensor [-1, 1], matching model dtype
            dtype = next(self.vae.parameters()).dtype
            rgb = torch.from_numpy(np.array(image_resized).astype(np.float32) / 127.5 - 1.0)
            rgb = rgb.permute(2, 0, 1).unsqueeze(0).to(devices.device, dtype=dtype)

            # Tokenize the prompt (empty string for unconditional depth)
            text_inputs = self.tokenizer("", padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")

            # Task embedding for depth prediction: sin/cos encoding of [1, 0] -> shape [1, 4]
            task_emb = torch.tensor([1, 0], dtype=dtype, device=devices.device).unsqueeze(0)
            task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)

            # Single-step direct prediction (Lotus-G)
            self.scheduler.set_timesteps(1, device=devices.device)
            timestep = self.scheduler.timesteps[0:1]

            with devices.inference_context():
                # Text encoding runs inside the inference context like the other model calls.
                prompt_embeds = self.text_encoder(text_inputs.input_ids.to(devices.device))[0]
                # Encode RGB to latent space
                rgb_latents = self.vae.encode(rgb).latent_dist.sample() * self.vae.config.scaling_factor
                # Random noise latents (half channels of UNet input)
                noise_latents = torch.randn(rgb_latents.shape, device=devices.device, dtype=dtype)
                # Concatenate along channel dim: [rgb_latents, noise_latents]
                latent_input = torch.cat([rgb_latents, noise_latents], dim=1)
                # UNet forward pass with task embedding as class_labels
                prediction = self.unet(latent_input, timestep, encoder_hidden_states=prompt_embeds, class_labels=task_emb).sample
                # Decode prediction
                prediction = prediction / self.vae.config.scaling_factor
                decoded = self.vae.decode(prediction).sample
        finally:
            # Offload even when inference raises, so a failed call does not leave the models on the device.
            if opts.control_move_processor:
                self._to("cpu")

        # Convert from [-1,1] to [0,1], then to a normalized 8-bit map
        depth = (decoded.squeeze(0).permute(1, 2, 0).float().cpu().numpy() + 1.0) * 0.5
        depth = self._normalize_depth(depth)

        # Resize back to original
        if depth.shape[:2] != (orig_h, orig_w):
            depth = np.array(Image.fromarray(depth).resize((orig_w, orig_h), Image.Resampling.BILINEAR))

        if color_map != "none":
            colormap_key = color_map if color_map in masking.COLORMAP else "inferno"
            # NOTE(review): the list position is used as the cv2 colormap id - assumes
            # masking.COLORMAP is ordered to match cv2's colormap constants; confirm.
            # The channel reversal flips the order cv2 returns (presumably BGR -> RGB).
            depth = cv2.applyColorMap(depth, masking.COLORMAP.index(colormap_key))[:, :, ::-1]

        if output_type == "pil":
            mode = "RGB" if depth.ndim == 3 else "L"
            depth = Image.fromarray(depth, mode=mode)
        return depth
|