diff --git a/ia_sam_manager.py b/ia_sam_manager.py index d6ab1a3..0c27f7e 100644 --- a/ia_sam_manager.py +++ b/ia_sam_manager.py @@ -98,6 +98,8 @@ def get_sam_mask_generator(sam_checkpoint, anime_style_chk=False): crop_n_layers=1, box_nms_thresh=0.7, crop_n_points_downscale_factor=2) + if platform.system() == "Darwin": + sam2_gen_kwargs.update(dict(points_per_side=32, points_per_batch=64, crop_n_points_downscale_factor=1)) if os.path.isfile(sam_checkpoint): sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint) diff --git a/sam2/automatic_mask_generator.py b/sam2/automatic_mask_generator.py index 67668b2..fadcdd0 100644 --- a/sam2/automatic_mask_generator.py +++ b/sam2/automatic_mask_generator.py @@ -284,7 +284,8 @@ class SAM2AutomaticMaskGenerator: orig_h, orig_w = orig_size # Run model on this batch - points = torch.as_tensor(points, device=self.predictor.device) + # points = torch.as_tensor(points, device=self.predictor.device) + points = torch.as_tensor(points.astype(np.float32), device=self.predictor.device) in_points = self.predictor._transforms.transform_coords( points, normalize=normalize, orig_hw=im_size ) diff --git a/sam2/utils/amg.py b/sam2/utils/amg.py index 9868429..a8642f3 100644 --- a/sam2/utils/amg.py +++ b/sam2/utils/amg.py @@ -103,7 +103,7 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: ), "Batched iteration must have inputs of all the same size." n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) for b in range(n_batches): - yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + yield [arg[b * batch_size: (b + 1) * batch_size] for arg in args] def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: @@ -144,7 +144,7 @@ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: idx = 0 parity = False for count in rle["counts"]: - mask[idx : idx + count] = parity + mask[idx: idx + count] = parity idx += count parity ^= True mask = mask.reshape(w, h)