Enable bfloat16 for autocast only when the CUDA device's compute capability is 8.0 or higher

main
Uminosachi 2024-08-02 10:03:11 +09:00
parent de9f5f5c71
commit 39d647db54
3 changed files with 19 additions and 4 deletions

View File

@@ -21,6 +21,20 @@ from segment_anything_hq import SamPredictor as SamPredictorHQ
from segment_anything_hq import sam_model_registry as sam_model_registry_hq
def check_bfloat16_support() -> bool:
    """Return True when the active CUDA device supports bfloat16.

    bfloat16 autocast is only reliable on NVIDIA GPUs with compute
    capability 8.0 (Ampere) or newer; anything older — or a machine
    without CUDA — reports False.
    """
    if not torch.cuda.is_available():
        ia_logging.debug("CUDA is not available")
        return False

    major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    if major >= 8:
        ia_logging.debug("The CUDA device supports bfloat16")
        return True

    ia_logging.debug("The CUDA device does not support bfloat16")
    return False
def partial_from_end(func, /, *fixed_args, **fixed_kwargs):
def wrapper(*args, **kwargs):
updated_kwargs = {**fixed_kwargs, **kwargs}

View File

@@ -24,7 +24,7 @@ def check_inputs_create_mask_image(
mask: Union[np.ndarray, Image.Image],
sam_masks: List[Dict[str, Any]],
ignore_black_chk: bool = True,
) -> None:
) -> None:
"""Check create mask image inputs.
Args:
@@ -70,7 +70,7 @@ def create_mask_image(
mask: Union[np.ndarray, Image.Image],
sam_masks: List[Dict[str, Any]],
ignore_black_chk: bool = True,
) -> np.ndarray:
) -> np.ndarray:
"""Create mask image.
Args:

View File

@@ -16,7 +16,7 @@ if inpa_basedir not in sys.path:
from ia_file_manager import ia_file_manager # noqa: E402
from ia_get_dataset_colormap import create_pascal_label_colormap # noqa: E402
from ia_logging import ia_logging # noqa: E402
from ia_sam_manager import get_sam_mask_generator # noqa: E402
from ia_sam_manager import check_bfloat16_support, get_sam_mask_generator # noqa: E402
from ia_ui_items import get_sam_model_ids # noqa: E402
@@ -139,7 +139,8 @@ def generate_sam_masks(
if "sam2_" in sam_id:
device = "cuda" if torch.cuda.is_available() else "cpu"
with torch.inference_mode(), torch.autocast(device, dtype=torch.bfloat16):
torch_dtype = torch.bfloat16 if check_bfloat16_support() else torch.float16
with torch.inference_mode(), torch.autocast(device, dtype=torch_dtype):
sam_masks = sam_mask_generator.generate(input_image)
else:
sam_masks = sam_mask_generator.generate(input_image)