mirror of https://github.com/vladmandic/automatic
45 lines
1.7 KiB
Python
45 lines
1.7 KiB
Python
import os
|
|
import torch
|
|
import numpy as np
|
|
from PIL import Image
|
|
from modules import devices
|
|
from modules.shared import opts
|
|
|
|
|
|
class DSINEDetector:
    """Callable wrapper around a DSINE predictor loaded through ``torch.hub``.

    The wrapped predictor is expected to expose ``device``, ``model`` and
    ``infer_pil`` -- those are the only members this class touches.
    """

    def __init__(self, predictor):
        """Store the already-constructed predictor object."""
        self.predictor = predictor

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path="hugoycj/DSINE-hub", cache_dir=None, local_files_only=False):
        """Load the ``DSINE`` hub entry point and wrap it in a detector.

        Args:
            pretrained_model_or_path: GitHub repo passed to ``torch.hub.load``.
            cache_dir: Optional base folder; when set, the hub cache is
                redirected to ``<cache_dir>/torch_hub`` for the duration of
                the load and restored afterwards.
            local_files_only: Accepted for signature compatibility; not used
                by this method.

        Returns:
            A new instance holding the loaded predictor, with its model moved
            to the project device and switched to eval mode.
        """
        from installer import install

        install('geffnet', quiet=True)

        scoped_hub_dir = None
        if cache_dir:
            scoped_hub_dir = os.path.join(cache_dir, 'torch_hub')
        previous_hub_dir = torch.hub.get_dir()
        if scoped_hub_dir is not None:
            os.makedirs(scoped_hub_dir, exist_ok=True)
            torch.hub.set_dir(scoped_hub_dir)
        try:
            loaded = torch.hub.load(pretrained_model_or_path, "DSINE", trust_repo=True, source="github")
        finally:
            # Always put the global hub directory back, even if the load failed.
            torch.hub.set_dir(previous_hub_dir)

        # The hub predictor hardcodes a cuda device; replace it with the project device.
        target_device = devices.device
        loaded.device = target_device
        relocated_model = loaded.model.to(target_device)
        loaded.model = relocated_model.eval()
        return cls(loaded)

    def __call__(self, image, output_type="pil", **kwargs):
        """Run the predictor on ``image`` and return the result as an 8-bit image.

        Args:
            image: A numpy array (converted with ``Image.fromarray``) or any
                object ``infer_pil`` accepts directly.
            output_type: ``"pil"`` returns a PIL image; any other value
                returns the ``uint8`` numpy array.
            **kwargs: Accepted and ignored.

        Returns:
            PIL image or ``uint8`` numpy array, depending on ``output_type``.
        """
        target_device = devices.device
        self.predictor.device = target_device
        self.predictor.model.to(target_device)

        pil_input = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        with devices.inference_context():
            prediction = self.predictor.infer_pil(pil_input)

        # Optionally offload the model once inference is done.
        if opts.control_move_processor:
            self.predictor.model.to("cpu")

        # First batch item, axes reordered -- presumably (C, H, W) -> (H, W, C); confirm against infer_pil.
        channels_last = prediction[0].permute(1, 2, 0).cpu().numpy()
        # Affine map (x + 1) * 0.5 * 255 followed by a clip into the uint8 range.
        rescaled = (channels_last + 1.0) * 0.5 * 255.0
        pixels = rescaled.clip(0, 255).astype(np.uint8)

        if output_type == "pil":
            return Image.fromarray(pixels)
        return pixels
|