diff --git a/ia_sam_manager.py b/ia_sam_manager.py index 0c27f7e..20a42ee 100644 --- a/ia_sam_manager.py +++ b/ia_sam_manager.py @@ -21,6 +21,20 @@ from segment_anything_hq import SamPredictor as SamPredictorHQ from segment_anything_hq import sam_model_registry as sam_model_registry_hq +def check_bfloat16_support() -> bool: + if torch.cuda.is_available(): + compute_capability = torch.cuda.get_device_capability(torch.cuda.current_device()) + if compute_capability[0] >= 8: + ia_logging.debug("The CUDA device supports bfloat16") + return True + else: + ia_logging.debug("The CUDA device does not support bfloat16") + return False + else: + ia_logging.debug("CUDA is not available") + return False + + def partial_from_end(func, /, *fixed_args, **fixed_kwargs): def wrapper(*args, **kwargs): updated_kwargs = {**fixed_kwargs, **kwargs} diff --git a/inpalib/masklib.py b/inpalib/masklib.py index 8f7fc9f..9d887f3 100644 --- a/inpalib/masklib.py +++ b/inpalib/masklib.py @@ -24,7 +24,7 @@ def check_inputs_create_mask_image( mask: Union[np.ndarray, Image.Image], sam_masks: List[Dict[str, Any]], ignore_black_chk: bool = True, - ) -> None: +) -> None: """Check create mask image inputs. Args: @@ -70,7 +70,7 @@ def create_mask_image( mask: Union[np.ndarray, Image.Image], sam_masks: List[Dict[str, Any]], ignore_black_chk: bool = True, - ) -> np.ndarray: +) -> np.ndarray: """Create mask image. 
Args: diff --git a/inpalib/samlib.py b/inpalib/samlib.py index e1921ec..898b60d 100644 --- a/inpalib/samlib.py +++ b/inpalib/samlib.py @@ -16,7 +16,7 @@ if inpa_basedir not in sys.path: from ia_file_manager import ia_file_manager # noqa: E402 from ia_get_dataset_colormap import create_pascal_label_colormap # noqa: E402 from ia_logging import ia_logging # noqa: E402 -from ia_sam_manager import get_sam_mask_generator # noqa: E402 +from ia_sam_manager import check_bfloat16_support, get_sam_mask_generator # noqa: E402 from ia_ui_items import get_sam_model_ids # noqa: E402 @@ -139,7 +139,8 @@ def generate_sam_masks( if "sam2_" in sam_id: device = "cuda" if torch.cuda.is_available() else "cpu" - with torch.inference_mode(), torch.autocast(device, dtype=torch.bfloat16): + torch_dtype = torch.float16 if device == "cuda" and not check_bfloat16_support() else torch.bfloat16 + with torch.inference_mode(), torch.autocast(device, dtype=torch_dtype): sam_masks = sam_mask_generator.generate(input_image) else: sam_masks = sam_mask_generator.generate(input_image)