Update depth.py

fixing not able to use fp32
pull/111/head
Funofabot 2022-11-24 05:56:08 -07:00 committed by GitHub
parent 44070d0fb7
commit 0491256d16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 3 deletions

View File

@ -57,9 +57,11 @@ class DepthModel():
if half_precision and self.device == torch.device("cuda"):
self.midas_model = self.midas_model.to(memory_format=torch.channels_last)
self.midas_model = self.midas_model.half()
self.midas_model.to(self.device)
self.midas_model.to(self.device)
else:
self.midas_model.to(self.device)
def predict(self, prev_img_cv2, anim_args) -> torch.Tensor:
def predict(self, prev_img_cv2, anim_args, half_precision) -> torch.Tensor:
w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0]
# predict depth with AdaBins
@ -109,7 +111,8 @@ class DepthModel():
sample = torch.from_numpy(img_midas_input).float().to(self.device).unsqueeze(0)
if self.device == torch.device("cuda"):
sample = sample.to(memory_format=torch.channels_last)
sample = sample.half()
if half_precision:
sample = sample.half()
with torch.no_grad():
midas_depth = self.midas_model.forward(sample)
midas_depth = torch.nn.functional.interpolate(