From 0491256d168cb97c247e754dad83f4e4efb71874 Mon Sep 17 00:00:00 2001 From: Funofabot <117096230+Funofabot@users.noreply.github.com> Date: Thu, 24 Nov 2022 05:56:08 -0700 Subject: [PATCH] Update depth.py fixing not able to use fp32 --- scripts/deforum_helpers/depth.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/scripts/deforum_helpers/depth.py b/scripts/deforum_helpers/depth.py index 6784627b..ee3dc6b4 100644 --- a/scripts/deforum_helpers/depth.py +++ b/scripts/deforum_helpers/depth.py @@ -57,9 +57,11 @@ class DepthModel(): if half_precision and self.device == torch.device("cuda"): self.midas_model = self.midas_model.to(memory_format=torch.channels_last) self.midas_model = self.midas_model.half() - self.midas_model.to(self.device) + self.midas_model.to(self.device) + else: + self.midas_model.to(self.device) - def predict(self, prev_img_cv2, anim_args) -> torch.Tensor: + def predict(self, prev_img_cv2, anim_args, half_precision) -> torch.Tensor: w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0] # predict depth with AdaBins @@ -109,7 +111,8 @@ class DepthModel(): sample = torch.from_numpy(img_midas_input).float().to(self.device).unsqueeze(0) if self.device == torch.device("cuda"): sample = sample.to(memory_format=torch.channels_last) - sample = sample.half() + if half_precision: + sample = sample.half() with torch.no_grad(): midas_depth = self.midas_model.forward(sample) midas_depth = torch.nn.functional.interpolate(