Fix missing attribute in FastSAM's 'Upsample'

pull/129/merge
Uminosachi 2024-02-18 16:34:30 +09:00
parent a08995a7d2
commit ab57640de4
1 changed files with 9 additions and 1 deletions

View File

@ -5,7 +5,12 @@ from typing import Any, Dict, List
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import ultralytics
if hasattr(ultralytics, "FastSAM"):
from ultralytics import FastSAM as YOLO
else:
from ultralytics import YOLO
class FastSAM:
@ -16,6 +21,9 @@ class FastSAM:
self.model_path = checkpoint
self.model = YOLO(self.model_path)
if not hasattr(torch.nn.Upsample, "recompute_scale_factor"):
torch.nn.Upsample.recompute_scale_factor = None
def to(self, device) -> None:
self.model.to(device)