merge: modules/control/proc/marigold_normals/__init__.py

pull/4678/head
vladmandic 2026-03-12 14:16:50 +01:00
parent de5b9f27f2
commit 0bdbf300ac
1 changed files with 39 additions and 0 deletions

View File

@ -0,0 +1,39 @@
import torch
import numpy as np
from PIL import Image
from modules import devices
from modules.shared import opts
class MarigoldNormalsDetector:
def __init__(self, model):
self.model = model
@classmethod
def from_pretrained(cls, pretrained_model_or_path="prs-eth/marigold-normals-v1-1", cache_dir=None, **load_config):
from diffusers import MarigoldNormalsPipeline
# Load in float32 to avoid NaN from SD.Next global fp16 precision settings
model = MarigoldNormalsPipeline.from_pretrained(pretrained_model_or_path, torch_dtype=torch.float32, cache_dir=cache_dir, **load_config)
return cls(model)
def to(self, device):
self.model.to(device)
return self
def __call__(self, input_image, denoising_steps=4, ensemble_size=4, processing_res=768, match_input_res=True, output_type=None):
if isinstance(input_image, np.ndarray):
input_image = Image.fromarray(input_image)
self.model.to(device=devices.device)
res = self.model(
input_image,
num_inference_steps=denoising_steps,
ensemble_size=ensemble_size,
processing_resolution=processing_res,
match_input_resolution=match_input_res,
batch_size=1,
output_type="pt",
)
normal_images = self.model.image_processor.visualize_normals(res.prediction)
if opts.control_move_processor:
self.model.to("cpu")
return normal_images[0]