56 lines
2.4 KiB
Python
56 lines
2.4 KiB
Python
import torch
|
|
import cv2
|
|
import os
|
|
import numpy as np
|
|
import torchvision.transforms as transforms
|
|
from .general_utils import download_file_with_checksum
|
|
from PIL import Image
|
|
from leres.lib.multi_depth_model_woauxi import RelDepthModel
|
|
from leres.lib.net_tools import load_ckpt
|
|
|
|
class LeReSDepth:
    """Load a LeReS ``RelDepthModel`` checkpoint and run depth inference with it.

    The constructor downloads the checkpoint (if the download helper decides
    it is needed), builds the network, moves it to CUDA when available
    (otherwise CPU) and loads the weights.
    """

    def __init__(
        self,
        width=448,
        height=448,
        models_path=None,
        checkpoint_name='res101.pth',
        backbone='resnext101',
    ):
        """Download the checkpoint, build the model and load its weights.

        Args:
            width: Width the input image is resized to before inference.
            height: Height the input image is resized to before inference.
            models_path: Directory that holds (or will receive) the checkpoint.
                Required despite the ``None`` default.
            checkpoint_name: File name of the checkpoint inside ``models_path``.
            backbone: Backbone identifier forwarded to ``RelDepthModel``.

        Raises:
            TypeError: If ``models_path`` is ``None``.
        """
        # os.path.join(None, ...) below would raise an opaque TypeError only
        # after the download and model construction; fail early and clearly
        # with the same exception type instead.
        if models_path is None:
            raise TypeError("models_path must be a directory path, not None")

        self.width = width
        self.height = height
        self.models_path = models_path
        self.checkpoint_name = checkpoint_name
        self.backbone = backbone

        # NOTE(review): URL and checksum are fixed, so they describe one
        # specific checkpoint regardless of checkpoint_name/backbone.
        download_file_with_checksum(url='https://cloudstor.aarnet.edu.au/plus/s/lTIJF4vrvHCAI31/download', expected_checksum='7fdc870ae6568cb28d56700d0be8fc45541e09cea7c4f84f01ab47de434cfb7463cacae699ad19fe40ee921849f9760dedf5e0dec04a62db94e169cf203f55b1', dest_folder=models_path, dest_filename=self.checkpoint_name)

        self.depth_model = RelDepthModel(backbone=self.backbone)
        self.depth_model.eval()
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.depth_model.to(self.DEVICE)
        load_ckpt(os.path.join(self.models_path, self.checkpoint_name), self.depth_model, None, None)

    @staticmethod
    def scale_torch(img):
        """Convert a numpy image or map into a torch tensor.

        Args:
            img: Numpy array. A 2-D array is treated as a single-channel map;
                a 3-D array whose last axis has length 3 is treated as a
                colour image.

        Returns:
            For 2-D input, a float32 tensor with a new leading axis of
            length 1. For 3-D input with 3 channels in the last axis, the
            result of ``ToTensor`` followed by ``Normalize``. Any other 3-D
            input is returned as a float32 tensor with its layout unchanged.
        """
        if len(img.shape) == 2:
            # Decide on the original rank: a 2-D map that happens to be
            # 3 pixels wide must not be mistaken for a 3-channel image once
            # the leading axis has been added.
            return torch.from_numpy(img[np.newaxis, :, :].astype(np.float32))

        if img.shape[2] == 3:
            # Mean/std values are the ones commonly used for ImageNet-trained
            # backbones.
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ])
            return transform(img)

        return torch.from_numpy(img.astype(np.float32))

    def predict(self, image):
        """Predict a depth map for ``image`` at the image's own resolution.

        Args:
            image: Numpy image; ``image.shape[0]`` and ``image.shape[1]`` are
                used as the output height and width.

        Returns:
            A torch tensor on ``self.DEVICE`` holding the depth map resized
            back to the input size, with one extra leading axis.
        """
        # cv2.resize takes the target size as (width, height).
        resized_image = cv2.resize(image, (self.width, self.height))
        # Prepend the batch axis expected by the 4-D indexing below.
        img_torch = self.scale_torch(resized_image)[None, :, :, :]
        # NOTE(review): the input is not moved to self.DEVICE here;
        # presumably inference() handles device placement - confirm.
        pred_depth = self.depth_model.inference(img_torch).cpu().numpy().squeeze()
        pred_depth_ori = cv2.resize(pred_depth, (image.shape[1], image.shape[0]))
        return torch.from_numpy(pred_depth_ori).unsqueeze(0).to(self.DEVICE)

    def save_raw_depth(self, depth, filepath):
        """Write ``depth`` to ``filepath`` as a 16-bit image.

        The map is scaled so that its maximum value becomes 60000 and then
        cast to ``uint16``.

        Args:
            depth: Numpy array, or a torch tensor such as the one returned by
                ``predict`` (a leading axis of length 1 is dropped).
            filepath: Destination path handed to ``cv2.imwrite``.
        """
        if torch.is_tensor(depth):
            # predict() returns a (1, H, W) tensor, possibly on the GPU;
            # bring it to a plain 2-D numpy array first.
            depth = depth.detach().cpu().numpy()
            if depth.ndim == 3 and depth.shape[0] == 1:
                depth = depth[0]

        peak = depth.max()
        if peak == 0:
            # Avoid 0/0 -> NaN being cast to uint16; an all-zero map stays zero.
            depth_normalized = np.zeros(depth.shape, dtype=np.uint16)
        else:
            depth_normalized = (depth / peak * 60000).astype(np.uint16)
        cv2.imwrite(filepath, depth_normalized)

    def to(self, device):
        """Move the model to ``device`` and remember it as ``self.DEVICE``."""
        self.DEVICE = device
        self.depth_model = self.depth_model.to(device)

    def delete(self):
        """Drop this wrapper's reference to the model."""
        del self.depth_model