parent
44070d0fb7
commit
0491256d16
|
|
@ -57,9 +57,11 @@ class DepthModel():
|
||||||
if half_precision and self.device == torch.device("cuda"):
|
if half_precision and self.device == torch.device("cuda"):
|
||||||
self.midas_model = self.midas_model.to(memory_format=torch.channels_last)
|
self.midas_model = self.midas_model.to(memory_format=torch.channels_last)
|
||||||
self.midas_model = self.midas_model.half()
|
self.midas_model = self.midas_model.half()
|
||||||
self.midas_model.to(self.device)
|
self.midas_model.to(self.device)
|
||||||
|
else:
|
||||||
|
self.midas_model.to(self.device)
|
||||||
|
|
||||||
def predict(self, prev_img_cv2, anim_args) -> torch.Tensor:
|
def predict(self, prev_img_cv2, anim_args, half_precision) -> torch.Tensor:
|
||||||
w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0]
|
w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0]
|
||||||
|
|
||||||
# predict depth with AdaBins
|
# predict depth with AdaBins
|
||||||
|
|
@ -109,7 +111,8 @@ class DepthModel():
|
||||||
sample = torch.from_numpy(img_midas_input).float().to(self.device).unsqueeze(0)
|
sample = torch.from_numpy(img_midas_input).float().to(self.device).unsqueeze(0)
|
||||||
if self.device == torch.device("cuda"):
|
if self.device == torch.device("cuda"):
|
||||||
sample = sample.to(memory_format=torch.channels_last)
|
sample = sample.to(memory_format=torch.channels_last)
|
||||||
sample = sample.half()
|
if half_precision:
|
||||||
|
sample = sample.half()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
midas_depth = self.midas_model.forward(sample)
|
midas_depth = self.midas_model.forward(sample)
|
||||||
midas_depth = torch.nn.functional.interpolate(
|
midas_depth = torch.nn.functional.interpolate(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue