Enable bfloat16 for autocast only when CUDA is 8+
parent
de9f5f5c71
commit
39d647db54
|
|
@ -21,6 +21,20 @@ from segment_anything_hq import SamPredictor as SamPredictorHQ
|
||||||
from segment_anything_hq import sam_model_registry as sam_model_registry_hq
|
from segment_anything_hq import sam_model_registry as sam_model_registry_hq
|
||||||
|
|
||||||
|
|
||||||
|
def check_bfloat16_support() -> bool:
    """Return True when the active CUDA device supports bfloat16.

    Support is determined by the device's compute capability: major
    version 8 or newer (Ampere and later) handles bfloat16 natively.
    Logs the outcome at debug level and returns False when CUDA is
    unavailable.
    """
    if not torch.cuda.is_available():
        ia_logging.debug("CUDA is not available")
        return False

    # Only the major compute-capability version matters for bfloat16.
    major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    if major >= 8:
        ia_logging.debug("The CUDA device supports bfloat16")
        return True

    ia_logging.debug("The CUDA device does not support bfloat16")
    return False
|
||||||
|
|
||||||
|
|
||||||
def partial_from_end(func, /, *fixed_args, **fixed_kwargs):
|
def partial_from_end(func, /, *fixed_args, **fixed_kwargs):
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
updated_kwargs = {**fixed_kwargs, **kwargs}
|
updated_kwargs = {**fixed_kwargs, **kwargs}
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ def check_inputs_create_mask_image(
|
||||||
mask: Union[np.ndarray, Image.Image],
|
mask: Union[np.ndarray, Image.Image],
|
||||||
sam_masks: List[Dict[str, Any]],
|
sam_masks: List[Dict[str, Any]],
|
||||||
ignore_black_chk: bool = True,
|
ignore_black_chk: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Check create mask image inputs.
|
"""Check create mask image inputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -70,7 +70,7 @@ def create_mask_image(
|
||||||
mask: Union[np.ndarray, Image.Image],
|
mask: Union[np.ndarray, Image.Image],
|
||||||
sam_masks: List[Dict[str, Any]],
|
sam_masks: List[Dict[str, Any]],
|
||||||
ignore_black_chk: bool = True,
|
ignore_black_chk: bool = True,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Create mask image.
|
"""Create mask image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ if inpa_basedir not in sys.path:
|
||||||
from ia_file_manager import ia_file_manager # noqa: E402
|
from ia_file_manager import ia_file_manager # noqa: E402
|
||||||
from ia_get_dataset_colormap import create_pascal_label_colormap # noqa: E402
|
from ia_get_dataset_colormap import create_pascal_label_colormap # noqa: E402
|
||||||
from ia_logging import ia_logging # noqa: E402
|
from ia_logging import ia_logging # noqa: E402
|
||||||
from ia_sam_manager import get_sam_mask_generator # noqa: E402
|
from ia_sam_manager import check_bfloat16_support, get_sam_mask_generator # noqa: E402
|
||||||
from ia_ui_items import get_sam_model_ids # noqa: E402
|
from ia_ui_items import get_sam_model_ids # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -139,7 +139,8 @@ def generate_sam_masks(
|
||||||
|
|
||||||
if "sam2_" in sam_id:
|
if "sam2_" in sam_id:
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
with torch.inference_mode(), torch.autocast(device, dtype=torch.bfloat16):
|
torch_dtype = torch.bfloat16 if check_bfloat16_support() else torch.float16
|
||||||
|
with torch.inference_mode(), torch.autocast(device, dtype=torch_dtype):
|
||||||
sam_masks = sam_mask_generator.generate(input_image)
|
sam_masks = sam_mask_generator.generate(input_image)
|
||||||
else:
|
else:
|
||||||
sam_masks = sam_mask_generator.generate(input_image)
|
sam_masks = sam_mask_generator.generate(input_image)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue