import math, os, subprocess import cv2 import hashlib import numpy as np import torch import gc import torchvision.transforms as T from einops import rearrange, repeat from PIL import Image from infer import InferenceHelper from midas.dpt_depth import DPTDepthModel from midas.transforms import Resize, NormalizeImage, PrepareForNet import torchvision.transforms.functional as TF from .general_utils import checksum from modules import lowvram, devices from modules.shared import opts DEBUG_MODE = opts.data.get("deforum_debug_mode_enabled", False) class MidasModel: _instance = None def __new__(cls, *args, **kwargs): keep_in_vram = kwargs.get('keep_in_vram', False) if not cls._instance: cls._instance = super().__new__(cls) cls._instance._initialize(*args, **kwargs) elif not keep_in_vram or not hasattr(cls._instance, 'midas_model'): cls._instance._initialize(*args, **kwargs) return cls._instance def _initialize(self, models_path, device, half_precision=True, keep_in_vram=False): self.keep_in_vram = keep_in_vram self.adabins_helper = None self.depth_min = 1000 self.depth_max = -1000 self.device = device model_file = os.path.join(models_path, 'dpt_large-midas-2f21e586.pt') if not os.path.exists(model_file): from basicsr.utils.download_util import load_file_from_url load_file_from_url(r"https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt", models_path) if checksum(model_file) != "fcc4829e65d00eeed0a38e9001770676535d2e95c8a16965223aba094936e1316d569563552a852d471f310f83f597e8a238987a26a950d667815e08adaebc06": raise Exception(r"Error while downloading dpt_large-midas-2f21e586.pt. Please download from here: https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt and place in: " + models_path) if not self.keep_in_vram or not hasattr(self, 'midas_model'): self.midas_model = DPTDepthModel( path=model_file, backbone="vitl16_384", non_negative=True, ) normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) self.midas_transform = T.Compose([ Resize(384, 384, resize_target=None, keep_aspect_ratio=True, ensure_multiple_of=32, resize_method="minimal", image_interpolation_method=cv2.INTER_CUBIC), normalization, PrepareForNet() ]) self.midas_model.eval().to(self.device, memory_format=torch.channels_last if self.device == torch.device("cuda") else None) if half_precision: self.midas_model = self.midas_model.half() def predict(self, prev_img_cv2, midas_weight, half_precision) -> torch.Tensor: w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0] img_pil = Image.fromarray(cv2.cvtColor(prev_img_cv2.astype(np.uint8), cv2.COLOR_RGB2BGR)) use_adabins = midas_weight < 1.0 and self.adabins_helper is not None if use_adabins: MAX_ADABINS_AREA = 500000 MIN_ADABINS_AREA = 448 * 448 image_pil_area = w * h scale = math.sqrt(MIN_ADABINS_AREA) / math.sqrt(image_pil_area) depth_input = img_pil.resize((int(w * scale), int(h * scale)), Image.LANCZOS if image_pil_area > MAX_ADABINS_AREA else Image.BICUBIC) try: with torch.no_grad(): _, adabins_depth = self.adabins_helper.predict_pil(depth_input) adabins_depth = adabins_depth.squeeze().cpu().numpy() if image_pil_area != MAX_ADABINS_AREA: adabins_depth = TF.resize(torch.from_numpy(adabins_depth), torch.Size([h, w]), interpolation=TF.InterpolationMode.BICUBIC).numpy() except: print(" exception encountered, falling back to pure MiDaS") use_adabins = False torch.cuda.empty_cache() if self.midas_model is not None: img_midas = prev_img_cv2.astype(np.float32) / 255.0 img_midas_input = self.midas_transform({"image": img_midas})["image"] sample = torch.from_numpy(img_midas_input).float().to(self.device).unsqueeze(0) if self.device.type == "cuda" or self.device.type == "mps": sample = sample.to(memory_format=torch.channels_last) if half_precision: sample = sample.half() with torch.no_grad(): midas_depth = self.midas_model.forward(sample) midas_depth = torch.nn.functional.interpolate( midas_depth.unsqueeze(1), size=img_midas.shape[:2], mode="bicubic", align_corners=False, ).squeeze().cpu().numpy() torch.cuda.empty_cache() midas_depth = np.subtract(50.0, midas_depth) / 19.0 depth_map = (midas_depth * midas_weight + adabins_depth * (1.0 - midas_weight)) if use_adabins else midas_depth depth_tensor = torch.from_numpy(np.expand_dims(depth_map, axis=0)).squeeze().to(self.device) else: depth_tensor = torch.ones((h, w), device=self.device) return depth_tensor def to_image(self, depth: torch.Tensor): depth = depth.cpu().numpy() depth = np.expand_dims(depth, axis=0) if len(depth.shape) == 2 else depth self.depth_min = min(self.depth_min, depth.min()) self.depth_max = max(self.depth_max, depth.max()) denom = max(1e-8, self.depth_max - self.depth_min) temp = rearrange((depth - self.depth_min) / denom * 255, 'c h w -> h w c') temp = repeat(temp, 'h w 1 -> h w c', c=3) return Image.fromarray(temp.astype(np.uint8)) def save(self, filename: str, depth: torch.Tensor): self.to_image(depth).save(filename) def to(self, device): self.device = device self.midas_model.to(device) if self.adabins_helper is not None: self.adabins_helper.to(device) gc.collect() torch.cuda.empty_cache() def delete_model(self): del self.midas_model torch.cuda.empty_cache() class AdaBinsModel: _instance = None def __new__(cls, *args, **kwargs): keep_in_vram = kwargs.get('keep_in_vram', True) if cls._instance is None or not keep_in_vram: cls._instance = super().__new__(cls) cls._instance._initialize(*args, **kwargs) return cls._instance def _initialize(self, models_path, keep_in_vram=False): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.keep_in_vram = keep_in_vram if self.keep_in_vram or not hasattr(self, 'adabins_helper'): if not os.path.exists(os.path.join(models_path, 'AdaBins_nyu.pt')): from basicsr.utils.download_util import load_file_from_url load_file_from_url( r"https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt", models_path) if checksum(os.path.join(models_path, 'AdaBins_nyu.pt')) != "643db9785c663aca72f66739427642726b03acc6c4c1d3755a4587aa2239962746410d63722d87b49fc73581dbc98ed8e3f7e996ff7b9c0d56d0fbc98e23e41a": raise Exception( r"Error while downloading AdaBins_nyu.pt. Please download from here: https://drive.google.com/file/d/1lvyZZbC9NLcS8a__YPcUP7rDiIpbRpoF and place in: " + models_path) self.adabins_helper = InferenceHelper(models_path=models_path, dataset='nyu', device=self.device) def delete_model(self): del self.adabins_helper torch.cuda.empty_cache() devices.torch_gc()