mirror of https://github.com/vladmandic/automatic
merge: modules/control/proc/marigold_normals/__init__.py
parent
de5b9f27f2
commit
0bdbf300ac
|
|
@ -0,0 +1,39 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from modules import devices
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
class MarigoldNormalsDetector:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_or_path="prs-eth/marigold-normals-v1-1", cache_dir=None, **load_config):
|
||||
from diffusers import MarigoldNormalsPipeline
|
||||
# Load in float32 to avoid NaN from SD.Next global fp16 precision settings
|
||||
model = MarigoldNormalsPipeline.from_pretrained(pretrained_model_or_path, torch_dtype=torch.float32, cache_dir=cache_dir, **load_config)
|
||||
return cls(model)
|
||||
|
||||
def to(self, device):
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def __call__(self, input_image, denoising_steps=4, ensemble_size=4, processing_res=768, match_input_res=True, output_type=None):
|
||||
if isinstance(input_image, np.ndarray):
|
||||
input_image = Image.fromarray(input_image)
|
||||
self.model.to(device=devices.device)
|
||||
res = self.model(
|
||||
input_image,
|
||||
num_inference_steps=denoising_steps,
|
||||
ensemble_size=ensemble_size,
|
||||
processing_resolution=processing_res,
|
||||
match_input_resolution=match_input_res,
|
||||
batch_size=1,
|
||||
output_type="pt",
|
||||
)
|
||||
normal_images = self.model.image_processor.visualize_normals(res.prediction)
|
||||
if opts.control_move_processor:
|
||||
self.model.to("cpu")
|
||||
return normal_images[0]
|
||||
Loading…
Reference in New Issue