diff --git a/README.md b/README.md index f767ef1..2c014c1 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ This extension aim for connecting [AUTOMATIC1111 Stable Diffusion WebUI](https:/ - `2023/05/29`: [v1.4.2](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.4.2) You may now do SAM inference on CPU by checking "Use CPU for SAM". This is for some MAC users who are not able to do SAM inference on GPU. I discourage other users from using this feature because it is significantly slower than CUDA. - `2023/06/01`: [v1.5.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.5.0) You may now choose to use local GroundingDINO to bypass C++ problem. See [FAQ](#faq)-1 for more detail. - `2023/06/04`: [v1.5.1](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.5.1) `Upload Mask to ControlNet Inpainting` comes back in response to [ControlNet inpaint improvement](https://github.com/Mikubill/sd-webui-controlnet/discussions/1464). You should see a new tab beside `AutoSAM` after updating the extension. This feature will again be removed once ControlNet extension has its own uploading feature. +- `2023/06/04`: [v1.6.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.6.0) [SAM-HQ](https://github.com/SysCV/sam-hq) supported by [@SpenserCai](https://github.com/SpenserCai) and me. This is an "upgraded" SAM from researchers at ETH Zurich & HKUST. However, I cannot guarantee which one is better and you should make your own choice based on your own experiments. Go to [Installation](#installation) to get the link to the models. ## FAQ @@ -67,7 +68,7 @@ Choose one or more of the models below and put them to `${sd-webui}/models/sam` Three types of SAM models are available. 
[vit_h](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) is 2.56GB, [vit_l](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth) is 1.25GB, [vit_b](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) is 375MB. I myself tested vit_h on NVIDIA 3090 Ti which is good. If you encounter VRAM problem, you should switch to smaller models. -If you want use SAM-HQ,[hq_vit_h](https://drive.google.com/file/d/1qobFYrI4eyIANfBSmYcGuWRaSIXfMOQ8/view?usp=sharing),[hq_vit_l](https://drive.google.com/file/d/1Uk17tDKX1YAKas5knI4y9ZJCo0lRVL0G/view?usp=sharing),[hq_vit_b](https://drive.google.com/file/d/11yExZLOve38kRZPfRx_MRxfIAKmfMY47/view?usp=sharing) +You may also choose to use [SAM-HQ](https://github.com/SysCV/sam-hq). [hq_vit_h](https://drive.google.com/file/d/1qobFYrI4eyIANfBSmYcGuWRaSIXfMOQ8/view?usp=sharing) is 2.4GB, [hq_vit_l](https://drive.google.com/file/d/1Uk17tDKX1YAKas5knI4y9ZJCo0lRVL0G/view?usp=sharing) is 1.2GB, [hq_vit_b](https://drive.google.com/file/d/11yExZLOve38kRZPfRx_MRxfIAKmfMY47/view?usp=sharing) is 362MB. GroundingDINO packages, GroundingDINO models and ControlNet annotator models will be automatically installed the first time you use them. diff --git a/sam_hq/automatic.py b/sam_hq/automatic.py new file mode 100644 index 0000000..b563d36 --- /dev/null +++ b/sam_hq/automatic.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import numpy as np + +from typing import List, Optional + +from segment_anything import SamAutomaticMaskGenerator +from segment_anything.utils.amg import build_all_layer_point_grids +from .predictor import SamPredictorHQ + + +class SamAutomaticMaskGeneratorHQ(SamAutomaticMaskGenerator): + def __init__( + self, + model: SamPredictorHQ, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. 
+ crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." 
+ if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = model + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode diff --git a/segment_anything_hq/build_sam.py b/sam_hq/build_sam_hq.py similarity index 78% rename from segment_anything_hq/build_sam.py rename to sam_hq/build_sam_hq.py index b280cf4..df28532 100644 --- a/segment_anything_hq/build_sam.py +++ b/sam_hq/build_sam_hq.py @@ -8,11 +8,14 @@ import torch from functools import partial -from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer +from .modeling.mask_decoder_hq import MaskDecoderHQ +from .modeling.image_encoder import ImageEncoderViTHQ +from segment_anything.modeling import PromptEncoder, Sam, TwoWayTransformer +from segment_anything import build_sam_vit_h, build_sam_vit_l, build_sam_vit_b -def build_sam_vit_h(checkpoint=None): - return _build_sam( +def build_sam_hq_vit_h(checkpoint=None): + return _build_sam_hq( encoder_embed_dim=1280, encoder_depth=32, encoder_num_heads=16, @@ -21,11 +24,8 @@ def build_sam_vit_h(checkpoint=None): ) -build_sam = build_sam_vit_h - - -def build_sam_vit_l(checkpoint=None): - return _build_sam( +def build_sam_hq_vit_l(checkpoint=None): + return _build_sam_hq( encoder_embed_dim=1024, encoder_depth=24, encoder_num_heads=16, @@ -34,8 +34,8 @@ def build_sam_vit_l(checkpoint=None): ) -def build_sam_vit_b(checkpoint=None): - return _build_sam( 
+def build_sam_hq_vit_b(checkpoint=None): + return _build_sam_hq( encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12, @@ -45,14 +45,16 @@ def build_sam_vit_b(checkpoint=None): sam_model_registry = { - "default": build_sam_vit_h, - "vit_h": build_sam_vit_h, - "vit_l": build_sam_vit_l, - "vit_b": build_sam_vit_b, + "sam_vit_h": build_sam_vit_h, + "sam_vit_l": build_sam_vit_l, + "sam_vit_b": build_sam_vit_b, + "sam_hq_vit_h": build_sam_hq_vit_h, + "sam_hq_vit_l": build_sam_hq_vit_l, + "sam_hq_vit_b": build_sam_hq_vit_b, } -def _build_sam( +def _build_sam_hq( encoder_embed_dim, encoder_depth, encoder_num_heads, @@ -64,7 +66,7 @@ def _build_sam( vit_patch_size = 16 image_embedding_size = image_size // vit_patch_size sam = Sam( - image_encoder=ImageEncoderViT( + image_encoder=ImageEncoderViTHQ( depth=encoder_depth, embed_dim=encoder_embed_dim, img_size=image_size, @@ -100,7 +102,7 @@ def _build_sam( pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], ) - # sam.eval() + sam.eval() if checkpoint is not None: with open(checkpoint, "rb") as f: state_dict = torch.load(f) diff --git a/sam_hq/modeling/image_encoder.py b/sam_hq/modeling/image_encoder.py new file mode 100644 index 0000000..8933bc0 --- /dev/null +++ b/sam_hq/modeling/image_encoder.py @@ -0,0 +1,20 @@ +import torch +from segment_anything.modeling import ImageEncoderViT + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViTHQ(ImageEncoderViT): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + interm_embeddings=[] + for blk in self.blocks: + x = blk(x) + if blk.window_size == 0: + interm_embeddings.append(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x, interm_embeddings \ No newline at end of file diff --git 
a/segment_anything_hq/modeling/mask_decoder_hq.py b/sam_hq/modeling/mask_decoder_hq.py similarity index 89% rename from segment_anything_hq/modeling/mask_decoder_hq.py rename to sam_hq/modeling/mask_decoder_hq.py index 1e365e3..244133c 100644 --- a/segment_anything_hq/modeling/mask_decoder_hq.py +++ b/sam_hq/modeling/mask_decoder_hq.py @@ -11,7 +11,7 @@ from torch.nn import functional as F from typing import List, Tuple, Type -from .common import LayerNorm2d +from segment_anything.modeling.common import LayerNorm2d class MaskDecoderHQ(nn.Module): @@ -103,8 +103,8 @@ class MaskDecoderHQ(nn.Module): sparse_prompt_embeddings: torch.Tensor, dense_prompt_embeddings: torch.Tensor, multimask_output: bool, - hq_token_only: bool, - interm_embeddings: torch.Tensor, + hq_token_only: bool = False, + interm_embeddings: torch.Tensor = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Predict masks given image and prompt embeddings. @@ -124,7 +124,7 @@ class MaskDecoderHQ(nn.Module): vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features) - masks, iou_pred = self.predict_masks( + masks, iou_pred, masks_hq = self.predict_masks( image_embeddings=image_embeddings, image_pe=image_pe, sparse_prompt_embeddings=sparse_prompt_embeddings, @@ -133,21 +133,25 @@ class MaskDecoderHQ(nn.Module): ) # Select the correct mask or masks for output + # if multimask_output: + # # mask with highest score + # mask_slice = slice(1,self.num_mask_tokens-1) + # iou_pred = iou_pred[:, mask_slice] + # iou_pred, max_iou_idx = torch.max(iou_pred,dim=1) + # iou_pred = iou_pred.unsqueeze(1) + # masks_multi = masks[:, mask_slice, :, :] + # masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1) + # else: + # # single mask output, default + # mask_slice = slice(0, 1) + # iou_pred = iou_pred[:,mask_slice] + # masks_sam = 
masks[:,mask_slice] if multimask_output: - # mask with highest score - mask_slice = slice(1,self.num_mask_tokens-1) - iou_pred = iou_pred[:, mask_slice] - iou_pred, max_iou_idx = torch.max(iou_pred,dim=1) - iou_pred = iou_pred.unsqueeze(1) - masks_multi = masks[:, mask_slice, :, :] - masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1) + mask_slice = slice(1, None) else: - # singale mask output, default mask_slice = slice(0, 1) - iou_pred = iou_pred[:,mask_slice] - masks_sam = masks[:,mask_slice] - - masks_hq = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens)] + masks_sam = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] if hq_token_only: masks = masks_hq else: @@ -198,11 +202,11 @@ class MaskDecoderHQ(nn.Module): masks_sam = (hyper_in[:,:self.num_mask_tokens-1] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w) masks_sam_hq = (hyper_in[:,self.num_mask_tokens-1:] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w) - masks = torch.cat([masks_sam,masks_sam_hq],dim=1) + # masks = torch.cat([masks_sam,masks_sam_hq],dim=1) # Generate mask quality predictions iou_pred = self.iou_prediction_head(iou_token_out) - return masks, iou_pred + return masks_sam, iou_pred, masks_sam_hq # Lightly adapted from diff --git a/sam_hq/predictor.py b/sam_hq/predictor.py new file mode 100644 index 0000000..40e3162 --- /dev/null +++ b/sam_hq/predictor.py @@ -0,0 +1,145 @@ +from typing import Optional, Tuple +import numpy as np +import torch +from segment_anything import SamPredictor +from segment_anything.modeling import Sam + + +class SamPredictorHQ(SamPredictor): + + def __init__( + self, + sam_model: Sam, + sam_is_hq: bool = False, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. 
+ """ + super().__init__(sam_model=sam_model) + self.is_hq = sam_is_hq + + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + if self.is_hq: + self.features, self.interm_features = self.model.image_encoder(input_image) + else: + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. 
+ point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) 
before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + if self.is_hq: + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + hq_token_only=False, + interm_embeddings=self.interm_features, + ) + else: + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks diff --git a/scripts/auto.py b/scripts/auto.py index a879052..088ab13 100644 --- a/scripts/auto.py +++ b/scripts/auto.py @@ -7,13 +7,13 @@ from collections import OrderedDict import numpy as np import torch import cv2 -from segment_anything import SamAutomaticMaskGenerator +from sam_hq.automatic import SamAutomaticMaskGeneratorHQ from modules import scripts, shared from modules.paths import extensions_dir from modules.devices import torch_gc -global_sam: SamAutomaticMaskGenerator = None +global_sam: SamAutomaticMaskGeneratorHQ = None sem_seg_cache = OrderedDict() sam_annotator_dir = os.path.join(scripts.basedir(), "annotator") original_uniformer_inference_segmentor = None @@ -303,7 +303,7 @@ def register_auto_sam(sam, auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, 
auto_sam_crop_overlap_ratio, auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, auto_sam_output_mode): global global_sam - global_sam = SamAutomaticMaskGenerator( + global_sam = SamAutomaticMaskGeneratorHQ( sam, auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh, auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh, diff --git a/scripts/sam.py b/scripts/sam.py index cb035ca..a40a2ec 100644 --- a/scripts/sam.py +++ b/scripts/sam.py @@ -15,8 +15,8 @@ from modules.safe import unsafe_torch_load, load from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessing from modules.devices import device, torch_gc, cpu from modules.paths import models_path -from segment_anything import SamPredictor as SamPredictorBase, sam_model_registry -from segment_anything_hq import SamPredictor as SamPredictorHQ, sam_model_registry as sam_model_registry_hq +from sam_hq.predictor import SamPredictorHQ +from sam_hq.build_sam_hq import sam_model_registry from scripts.dino import dino_model_list, dino_predict_internal, show_boxes, clear_dino_cache, dino_install_issue_text from scripts.auto import clear_sem_sam_cache, register_auto_sam, semantic_segmentation, sem_sam_garbage_collect, image_layer_internal, categorical_mask_image from scripts.process_params import SAMProcessUnit, max_cn_num @@ -30,8 +30,6 @@ sam_model_dir = sd_sam_model_dir if os.path.exists(sd_sam_model_dir) else script sam_model_list = [f for f in os.listdir(sam_model_dir) if os.path.isfile(os.path.join(sam_model_dir, f)) and f.split('.')[-1] != 'txt'] sam_device = device -is_hq = False - txt2img_width: gr.Slider = None txt2img_height: gr.Slider = None @@ -57,12 +55,6 @@ def show_masks(image_np, masks: np.ndarray, alpha=0.5): image[mask] = image[mask] * (1 - alpha) + 255 * color.reshape(1, 1, -1) * alpha return image.astype(np.uint8) -def SamPredictor(sam_model): - if is_hq: - return SamPredictorHQ(sam_model) - 
else: - return SamPredictorBase(sam_model) - def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image): print("Dilation Amount: ", dilation_amt) @@ -80,17 +72,12 @@ def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image): def load_sam_model(sam_checkpoint): - global is_hq - model_type = '_'.join(sam_checkpoint.split('_')[1:-1]) - sam_checkpoint = os.path.join(sam_model_dir, sam_checkpoint) + model_type = sam_checkpoint.split('.')[0] + if 'hq' not in model_type: + model_type = '_'.join(model_type.split('_')[:-1]) + sam_checkpoint_path = os.path.join(sam_model_dir, sam_checkpoint) torch.load = unsafe_torch_load - # 如果包含hq,则使用hq版本的sam - if 'hq' in sam_checkpoint: - sam = sam_model_registry_hq[model_type.replace("hq_","")](checkpoint=sam_checkpoint) - is_hq = True - else: - sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) - is_hq = False + sam = sam_model_registry[model_type](checkpoint=sam_checkpoint_path) sam.to(device=sam_device) sam.eval() torch.load = load @@ -216,7 +203,7 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points, boxes_filt = boxes_filt[valid_indices] sam = init_sam_model(sam_model_name) print(f"Running SAM Inference {image_np_rgb.shape}") - predictor = SamPredictor(sam) + predictor = SamPredictorHQ(sam, 'hq' in sam_model_name) predictor.set_image(image_np_rgb) if dino_enabled and boxes_filt.shape[0] > 1: sam_predict_status = f"SAM inference with {boxes_filt.shape[0]} boxes, point prompts discarded" @@ -271,7 +258,7 @@ def dino_batch_process( return "Please add text prompts to generate masks" print("Start batch processing") sam = init_sam_model(batch_sam_model_name) - predictor = SamPredictor(sam) + predictor = SamPredictorHQ(sam, 'hq' in batch_sam_model_name) process_info = "" install_success = True @@ -323,7 +310,8 @@ def cnet_seg( print(f"Start semantic segmentation with processor {cnet_seg_processor}") auto_sam_output_mode = "coco_rle" if "seg" in cnet_seg_processor else 
"binary_mask" sam = load_sam_model(sam_model_name) - register_auto_sam(sam, auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh, + predictor = SamPredictorHQ(sam, 'hq' in sam_model_name) + register_auto_sam(predictor, auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh, auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh, auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio, auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, auto_sam_output_mode) @@ -342,7 +330,8 @@ def image_layout( auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area): print("Start processing image layout") sam = load_sam_model(sam_model_name) - register_auto_sam(sam, auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh, + predictor = SamPredictorHQ(sam, 'hq' in sam_model_name) + register_auto_sam(predictor, auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh, auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh, auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio, auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, "binary_mask") @@ -362,7 +351,8 @@ def categorical_mask( auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area): print("Start processing categorical mask") sam = load_sam_model(sam_model_name) - register_auto_sam(sam, auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh, + predictor = SamPredictorHQ(sam, 'hq' in sam_model_name) + register_auto_sam(predictor, auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh, auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh, auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio, auto_sam_crop_n_points_downscale_factor, 
auto_sam_min_mask_region_area, "coco_rle") @@ -389,7 +379,8 @@ def categorical_mask_batch( auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area): print("Start processing categorical mask in batch") sam = load_sam_model(sam_model_name) - register_auto_sam(sam, auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh, + predictor = SamPredictorHQ(sam, 'hq' in sam_model_name) + register_auto_sam(predictor, auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh, auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh, auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio, auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, "coco_rle") diff --git a/segment_anything_hq/__init__.py b/segment_anything_hq/__init__.py deleted file mode 100644 index d576507..0000000 --- a/segment_anything_hq/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from .build_sam import ( - build_sam, - build_sam_vit_h, - build_sam_vit_l, - build_sam_vit_b, - sam_model_registry, -) -from .build_sam_baseline import sam_model_registry_baseline -from .predictor import SamPredictor -from .automatic_mask_generator import SamAutomaticMaskGenerator diff --git a/segment_anything_hq/automatic_mask_generator.py b/segment_anything_hq/automatic_mask_generator.py deleted file mode 100644 index d5a8c96..0000000 --- a/segment_anything_hq/automatic_mask_generator.py +++ /dev/null @@ -1,372 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import numpy as np -import torch -from torchvision.ops.boxes import batched_nms, box_area # type: ignore - -from typing import Any, Dict, List, Optional, Tuple - -from .modeling import Sam -from .predictor import SamPredictor -from .utils.amg import ( - MaskData, - area_from_rle, - batch_iterator, - batched_mask_to_box, - box_xyxy_to_xywh, - build_all_layer_point_grids, - calculate_stability_score, - coco_encode_rle, - generate_crop_boxes, - is_box_near_crop_edge, - mask_to_rle_pytorch, - remove_small_regions, - rle_to_mask, - uncrop_boxes_xyxy, - uncrop_masks, - uncrop_points, -) - - -class SamAutomaticMaskGenerator: - def __init__( - self, - model: Sam, - points_per_side: Optional[int] = 32, - points_per_batch: int = 64, - pred_iou_thresh: float = 0.88, - stability_score_thresh: float = 0.95, - stability_score_offset: float = 1.0, - box_nms_thresh: float = 0.7, - crop_n_layers: int = 0, - crop_nms_thresh: float = 0.7, - crop_overlap_ratio: float = 512 / 1500, - crop_n_points_downscale_factor: int = 1, - point_grids: Optional[List[np.ndarray]] = None, - min_mask_region_area: int = 0, - output_mode: str = "binary_mask", - ) -> None: - """ - Using a SAM model, generates masks for the entire image. - Generates a grid of point prompts over the image, then filters - low quality and duplicate masks. The default settings are chosen - for SAM with a ViT-H backbone. - - Arguments: - model (Sam): The SAM model to use for mask prediction. - points_per_side (int or None): The number of points to be sampled - along one side of the image. The total number of points is - points_per_side**2. If None, 'point_grids' must provide explicit - point sampling. - points_per_batch (int): Sets the number of points run simultaneously - by the model. Higher numbers may be faster but use more GPU memory. - pred_iou_thresh (float): A filtering threshold in [0,1], using the - model's predicted mask quality. 
- stability_score_thresh (float): A filtering threshold in [0,1], using - the stability of the mask under changes to the cutoff used to binarize - the model's mask predictions. - stability_score_offset (float): The amount to shift the cutoff when - calculated the stability score. - box_nms_thresh (float): The box IoU cutoff used by non-maximal - suppression to filter duplicate masks. - crop_n_layers (int): If >0, mask prediction will be run again on - crops of the image. Sets the number of layers to run, where each - layer has 2**i_layer number of image crops. - crop_nms_thresh (float): The box IoU cutoff used by non-maximal - suppression to filter duplicate masks between different crops. - crop_overlap_ratio (float): Sets the degree to which crops overlap. - In the first crop layer, crops will overlap by this fraction of - the image length. Later layers with more crops scale down this overlap. - crop_n_points_downscale_factor (int): The number of points-per-side - sampled in layer n is scaled down by crop_n_points_downscale_factor**n. - point_grids (list(np.ndarray) or None): A list over explicit grids - of points used for sampling, normalized to [0,1]. The nth grid in the - list is used in the nth crop layer. Exclusive with points_per_side. - min_mask_region_area (int): If >0, postprocessing will be applied - to remove disconnected regions and holes in masks with area smaller - than min_mask_region_area. Requires opencv. - output_mode (str): The form masks are returned in. Can be 'binary_mask', - 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. - For large resolutions, 'binary_mask' may consume large amounts of - memory. - """ - - assert (points_per_side is None) != ( - point_grids is None - ), "Exactly one of points_per_side or point_grid must be provided." 
- if points_per_side is not None: - self.point_grids = build_all_layer_point_grids( - points_per_side, - crop_n_layers, - crop_n_points_downscale_factor, - ) - elif point_grids is not None: - self.point_grids = point_grids - else: - raise ValueError("Can't have both points_per_side and point_grid be None.") - - assert output_mode in [ - "binary_mask", - "uncompressed_rle", - "coco_rle", - ], f"Unknown output_mode {output_mode}." - if output_mode == "coco_rle": - from pycocotools import mask as mask_utils # type: ignore # noqa: F401 - - if min_mask_region_area > 0: - import cv2 # type: ignore # noqa: F401 - - self.predictor = SamPredictor(model) - self.points_per_batch = points_per_batch - self.pred_iou_thresh = pred_iou_thresh - self.stability_score_thresh = stability_score_thresh - self.stability_score_offset = stability_score_offset - self.box_nms_thresh = box_nms_thresh - self.crop_n_layers = crop_n_layers - self.crop_nms_thresh = crop_nms_thresh - self.crop_overlap_ratio = crop_overlap_ratio - self.crop_n_points_downscale_factor = crop_n_points_downscale_factor - self.min_mask_region_area = min_mask_region_area - self.output_mode = output_mode - - @torch.no_grad() - def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: - """ - Generates masks for the given image. - - Arguments: - image (np.ndarray): The image to generate masks for, in HWC uint8 format. - - Returns: - list(dict(str, any)): A list over records for masks. Each record is - a dict containing the following keys: - segmentation (dict(str, any) or np.ndarray): The mask. If - output_mode='binary_mask', is an array of shape HW. Otherwise, - is a dictionary containing the RLE. - bbox (list(float)): The box around the mask, in XYWH format. - area (int): The area in pixels of the mask. - predicted_iou (float): The model's own prediction of the mask's - quality. This is filtered by the pred_iou_thresh parameter. 
- point_coords (list(list(float))): The point coordinates input - to the model to generate this mask. - stability_score (float): A measure of the mask's quality. This - is filtered on using the stability_score_thresh parameter. - crop_box (list(float)): The crop of the image used to generate - the mask, given in XYWH format. - """ - - # Generate masks - mask_data = self._generate_masks(image) - - # Filter small disconnected regions and holes in masks - if self.min_mask_region_area > 0: - mask_data = self.postprocess_small_regions( - mask_data, - self.min_mask_region_area, - max(self.box_nms_thresh, self.crop_nms_thresh), - ) - - # Encode masks - if self.output_mode == "coco_rle": - mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] - elif self.output_mode == "binary_mask": - mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] - else: - mask_data["segmentations"] = mask_data["rles"] - - # Write mask records - curr_anns = [] - for idx in range(len(mask_data["segmentations"])): - ann = { - "segmentation": mask_data["segmentations"][idx], - "area": area_from_rle(mask_data["rles"][idx]), - "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), - "predicted_iou": mask_data["iou_preds"][idx].item(), - "point_coords": [mask_data["points"][idx].tolist()], - "stability_score": mask_data["stability_score"][idx].item(), - "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), - } - curr_anns.append(ann) - - return curr_anns - - def _generate_masks(self, image: np.ndarray) -> MaskData: - orig_size = image.shape[:2] - crop_boxes, layer_idxs = generate_crop_boxes( - orig_size, self.crop_n_layers, self.crop_overlap_ratio - ) - - # Iterate over image crops - data = MaskData() - for crop_box, layer_idx in zip(crop_boxes, layer_idxs): - crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) - data.cat(crop_data) - - # Remove duplicate masks between crops - if len(crop_boxes) > 1: - # Prefer masks 
from smaller crops - scores = 1 / box_area(data["crop_boxes"]) - scores = scores.to(data["boxes"].device) - keep_by_nms = batched_nms( - data["boxes"].float(), - scores, - torch.zeros_like(data["boxes"][:, 0]), # categories - iou_threshold=self.crop_nms_thresh, - ) - data.filter(keep_by_nms) - - data.to_numpy() - return data - - def _process_crop( - self, - image: np.ndarray, - crop_box: List[int], - crop_layer_idx: int, - orig_size: Tuple[int, ...], - ) -> MaskData: - # Crop the image and calculate embeddings - x0, y0, x1, y1 = crop_box - cropped_im = image[y0:y1, x0:x1, :] - cropped_im_size = cropped_im.shape[:2] - self.predictor.set_image(cropped_im) - - # Get points for this crop - points_scale = np.array(cropped_im_size)[None, ::-1] - points_for_image = self.point_grids[crop_layer_idx] * points_scale - - # Generate masks for this crop in batches - data = MaskData() - for (points,) in batch_iterator(self.points_per_batch, points_for_image): - batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) - data.cat(batch_data) - del batch_data - self.predictor.reset_image() - - # Remove duplicates within this crop. 
- keep_by_nms = batched_nms( - data["boxes"].float(), - data["iou_preds"], - torch.zeros_like(data["boxes"][:, 0]), # categories - iou_threshold=self.box_nms_thresh, - ) - data.filter(keep_by_nms) - - # Return to the original image frame - data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) - data["points"] = uncrop_points(data["points"], crop_box) - data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) - - return data - - def _process_batch( - self, - points: np.ndarray, - im_size: Tuple[int, ...], - crop_box: List[int], - orig_size: Tuple[int, ...], - ) -> MaskData: - orig_h, orig_w = orig_size - - # Run model on this batch - transformed_points = self.predictor.transform.apply_coords(points, im_size) - in_points = torch.as_tensor(transformed_points, device=self.predictor.device) - in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) - masks, iou_preds, _ = self.predictor.predict_torch( - in_points[:, None, :], - in_labels[:, None], - multimask_output=True, - return_logits=True, - ) - - # Serialize predictions and store in MaskData - data = MaskData( - masks=masks.flatten(0, 1), - iou_preds=iou_preds.flatten(0, 1), - points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), - ) - del masks - - # Filter by predicted IoU - if self.pred_iou_thresh > 0.0: - keep_mask = data["iou_preds"] > self.pred_iou_thresh - data.filter(keep_mask) - - # Calculate stability score - data["stability_score"] = calculate_stability_score( - data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset - ) - if self.stability_score_thresh > 0.0: - keep_mask = data["stability_score"] >= self.stability_score_thresh - data.filter(keep_mask) - - # Threshold masks and calculate boxes - data["masks"] = data["masks"] > self.predictor.model.mask_threshold - data["boxes"] = batched_mask_to_box(data["masks"]) - - # Filter boxes that touch crop boundaries - keep_mask = ~is_box_near_crop_edge(data["boxes"], 
crop_box, [0, 0, orig_w, orig_h]) - if not torch.all(keep_mask): - data.filter(keep_mask) - - # Compress to RLE - data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) - data["rles"] = mask_to_rle_pytorch(data["masks"]) - del data["masks"] - - return data - - @staticmethod - def postprocess_small_regions( - mask_data: MaskData, min_area: int, nms_thresh: float - ) -> MaskData: - """ - Removes small disconnected regions and holes in masks, then reruns - box NMS to remove any new duplicates. - - Edits mask_data in place. - - Requires open-cv as a dependency. - """ - if len(mask_data["rles"]) == 0: - return mask_data - - # Filter small disconnected regions and holes - new_masks = [] - scores = [] - for rle in mask_data["rles"]: - mask = rle_to_mask(rle) - - mask, changed = remove_small_regions(mask, min_area, mode="holes") - unchanged = not changed - mask, changed = remove_small_regions(mask, min_area, mode="islands") - unchanged = unchanged and not changed - - new_masks.append(torch.as_tensor(mask).unsqueeze(0)) - # Give score=0 to changed masks and score=1 to unchanged masks - # so NMS will prefer ones that didn't need postprocessing - scores.append(float(unchanged)) - - # Recalculate boxes and remove any new duplicates - masks = torch.cat(new_masks, dim=0) - boxes = batched_mask_to_box(masks) - keep_by_nms = batched_nms( - boxes.float(), - torch.as_tensor(scores), - torch.zeros_like(boxes[:, 0]), # categories - iou_threshold=nms_thresh, - ) - - # Only recalculate RLEs for masks that have changed - for i_mask in keep_by_nms: - if scores[i_mask] == 0.0: - mask_torch = masks[i_mask].unsqueeze(0) - mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] - mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly - mask_data.filter(keep_by_nms) - - return mask_data diff --git a/segment_anything_hq/build_sam_baseline.py b/segment_anything_hq/build_sam_baseline.py deleted file mode 100644 index 8f14970..0000000 --- 
a/segment_anything_hq/build_sam_baseline.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch - -from functools import partial - -from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer - - -def build_sam_vit_h(checkpoint=None): - return _build_sam( - encoder_embed_dim=1280, - encoder_depth=32, - encoder_num_heads=16, - encoder_global_attn_indexes=[7, 15, 23, 31], - checkpoint=checkpoint, - ) - - -build_sam = build_sam_vit_h - - -def build_sam_vit_l(checkpoint=None): - return _build_sam( - encoder_embed_dim=1024, - encoder_depth=24, - encoder_num_heads=16, - encoder_global_attn_indexes=[5, 11, 17, 23], - checkpoint=checkpoint, - ) - - -def build_sam_vit_b(checkpoint=None): - return _build_sam( - encoder_embed_dim=768, - encoder_depth=12, - encoder_num_heads=12, - encoder_global_attn_indexes=[2, 5, 8, 11], - checkpoint=checkpoint, - ) - - -sam_model_registry_baseline = { - "default": build_sam_vit_h, - "vit_h": build_sam_vit_h, - "vit_l": build_sam_vit_l, - "vit_b": build_sam_vit_b, -} - - -def _build_sam( - encoder_embed_dim, - encoder_depth, - encoder_num_heads, - encoder_global_attn_indexes, - checkpoint=None, -): - prompt_embed_dim = 256 - image_size = 1024 - vit_patch_size = 16 - image_embedding_size = image_size // vit_patch_size - sam = Sam( - image_encoder=ImageEncoderViT( - depth=encoder_depth, - embed_dim=encoder_embed_dim, - img_size=image_size, - mlp_ratio=4, - norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), - num_heads=encoder_num_heads, - patch_size=vit_patch_size, - qkv_bias=True, - use_rel_pos=True, - global_attn_indexes=encoder_global_attn_indexes, - window_size=14, - out_chans=prompt_embed_dim, - ), - prompt_encoder=PromptEncoder( - embed_dim=prompt_embed_dim, - image_embedding_size=(image_embedding_size, 
image_embedding_size), - input_image_size=(image_size, image_size), - mask_in_chans=16, - ), - mask_decoder=MaskDecoder( - num_multimask_outputs=3, - transformer=TwoWayTransformer( - depth=2, - embedding_dim=prompt_embed_dim, - mlp_dim=2048, - num_heads=8, - ), - transformer_dim=prompt_embed_dim, - iou_head_depth=3, - iou_head_hidden_dim=256, - ), - pixel_mean=[123.675, 116.28, 103.53], - pixel_std=[58.395, 57.12, 57.375], - ) - sam.eval() - if checkpoint is not None: - with open(checkpoint, "rb") as f: - state_dict = torch.load(f) - sam.load_state_dict(state_dict) - return sam \ No newline at end of file diff --git a/segment_anything_hq/modeling/__init__.py b/segment_anything_hq/modeling/__init__.py deleted file mode 100644 index 71172d2..0000000 --- a/segment_anything_hq/modeling/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from .sam import Sam -from .image_encoder import ImageEncoderViT -from .mask_decoder_hq import MaskDecoderHQ -from .mask_decoder import MaskDecoder -from .prompt_encoder import PromptEncoder -from .transformer import TwoWayTransformer diff --git a/segment_anything_hq/modeling/common.py b/segment_anything_hq/modeling/common.py deleted file mode 100644 index 2bf1523..0000000 --- a/segment_anything_hq/modeling/common.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -import torch.nn as nn - -from typing import Type - - -class MLPBlock(nn.Module): - def __init__( - self, - embedding_dim: int, - mlp_dim: int, - act: Type[nn.Module] = nn.GELU, - ) -> None: - super().__init__() - self.lin1 = nn.Linear(embedding_dim, mlp_dim) - self.lin2 = nn.Linear(mlp_dim, embedding_dim) - self.act = act() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.lin2(self.act(self.lin1(x))) - - -# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa -# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa -class LayerNorm2d(nn.Module): - def __init__(self, num_channels: int, eps: float = 1e-6) -> None: - super().__init__() - self.weight = nn.Parameter(torch.ones(num_channels)) - self.bias = nn.Parameter(torch.zeros(num_channels)) - self.eps = eps - - def forward(self, x: torch.Tensor) -> torch.Tensor: - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight[:, None, None] * x + self.bias[:, None, None] - return x diff --git a/segment_anything_hq/modeling/image_encoder.py b/segment_anything_hq/modeling/image_encoder.py deleted file mode 100644 index 7048651..0000000 --- a/segment_anything_hq/modeling/image_encoder.py +++ /dev/null @@ -1,398 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -import torch.nn as nn -import torch.nn.functional as F - -from typing import Optional, Tuple, Type - -from .common import LayerNorm2d, MLPBlock - - -# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa -class ImageEncoderViT(nn.Module): - def __init__( - self, - img_size: int = 1024, - patch_size: int = 16, - in_chans: int = 3, - embed_dim: int = 768, - depth: int = 12, - num_heads: int = 12, - mlp_ratio: float = 4.0, - out_chans: int = 256, - qkv_bias: bool = True, - norm_layer: Type[nn.Module] = nn.LayerNorm, - act_layer: Type[nn.Module] = nn.GELU, - use_abs_pos: bool = True, - use_rel_pos: bool = False, - rel_pos_zero_init: bool = True, - window_size: int = 0, - global_attn_indexes: Tuple[int, ...] = (), - ) -> None: - """ - Args: - img_size (int): Input image size. - patch_size (int): Patch size. - in_chans (int): Number of input image channels. - embed_dim (int): Patch embedding dimension. - depth (int): Depth of ViT. - num_heads (int): Number of attention heads in each ViT block. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool): If True, add a learnable bias to query, key, value. - norm_layer (nn.Module): Normalization layer. - act_layer (nn.Module): Activation layer. - use_abs_pos (bool): If True, use absolute positional embeddings. - use_rel_pos (bool): If True, add relative positional embeddings to the attention map. - rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. - window_size (int): Window size for window attention blocks. - global_attn_indexes (list): Indexes for blocks using global attention. 
- """ - super().__init__() - self.img_size = img_size - - self.patch_embed = PatchEmbed( - kernel_size=(patch_size, patch_size), - stride=(patch_size, patch_size), - in_chans=in_chans, - embed_dim=embed_dim, - ) - - self.pos_embed: Optional[nn.Parameter] = None - if use_abs_pos: - # Initialize absolute positional embedding with pretrain image size. - self.pos_embed = nn.Parameter( - torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) - ) - - self.blocks = nn.ModuleList() - for i in range(depth): - block = Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - norm_layer=norm_layer, - act_layer=act_layer, - use_rel_pos=use_rel_pos, - rel_pos_zero_init=rel_pos_zero_init, - window_size=window_size if i not in global_attn_indexes else 0, - input_size=(img_size // patch_size, img_size // patch_size), - ) - self.blocks.append(block) - - self.neck = nn.Sequential( - nn.Conv2d( - embed_dim, - out_chans, - kernel_size=1, - bias=False, - ), - LayerNorm2d(out_chans), - nn.Conv2d( - out_chans, - out_chans, - kernel_size=3, - padding=1, - bias=False, - ), - LayerNorm2d(out_chans), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.patch_embed(x) - if self.pos_embed is not None: - x = x + self.pos_embed - - interm_embeddings=[] - for blk in self.blocks: - x = blk(x) - if blk.window_size == 0: - interm_embeddings.append(x) - - x = self.neck(x.permute(0, 3, 1, 2)) - - return x, interm_embeddings - - -class Block(nn.Module): - """Transformer blocks with support of window attention and residual propagation blocks""" - - def __init__( - self, - dim: int, - num_heads: int, - mlp_ratio: float = 4.0, - qkv_bias: bool = True, - norm_layer: Type[nn.Module] = nn.LayerNorm, - act_layer: Type[nn.Module] = nn.GELU, - use_rel_pos: bool = False, - rel_pos_zero_init: bool = True, - window_size: int = 0, - input_size: Optional[Tuple[int, int]] = None, - ) -> None: - """ - Args: - dim (int): Number of input channels. 
- num_heads (int): Number of attention heads in each ViT block. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool): If True, add a learnable bias to query, key, value. - norm_layer (nn.Module): Normalization layer. - act_layer (nn.Module): Activation layer. - use_rel_pos (bool): If True, add relative positional embeddings to the attention map. - rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. - window_size (int): Window size for window attention blocks. If it equals 0, then - use global attention. - input_size (tuple(int, int) or None): Input resolution for calculating the relative - positional parameter size. - """ - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - use_rel_pos=use_rel_pos, - rel_pos_zero_init=rel_pos_zero_init, - input_size=input_size if window_size == 0 else (window_size, window_size), - ) - - self.norm2 = norm_layer(dim) - self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) - - self.window_size = window_size - - def forward(self, x: torch.Tensor) -> torch.Tensor: - shortcut = x - x = self.norm1(x) - # Window partition - if self.window_size > 0: - H, W = x.shape[1], x.shape[2] - x, pad_hw = window_partition(x, self.window_size) - - x = self.attn(x) - # Reverse window partition - if self.window_size > 0: - x = window_unpartition(x, self.window_size, pad_hw, (H, W)) - - x = shortcut + x - x = x + self.mlp(self.norm2(x)) - - return x - - -class Attention(nn.Module): - """Multi-head Attention block with relative position embeddings.""" - - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = True, - use_rel_pos: bool = False, - rel_pos_zero_init: bool = True, - input_size: Optional[Tuple[int, int]] = None, - ) -> None: - """ - Args: - dim (int): Number of input channels. - num_heads (int): Number of attention heads. 
- qkv_bias (bool): If True, add a learnable bias to query, key, value. - rel_pos (bool): If True, add relative positional embeddings to the attention map. - rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. - input_size (tuple(int, int) or None): Input resolution for calculating the relative - positional parameter size. - """ - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) - - self.use_rel_pos = use_rel_pos - if self.use_rel_pos: - assert ( - input_size is not None - ), "Input size must be provided if using relative positional encoding." - # initialize relative positional embeddings - self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) - self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - B, H, W, _ = x.shape - # qkv with shape (3, B, nHead, H * W, C) - qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - # q, k, v with shape (B * nHead, H * W, C) - q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) - - attn = (q * self.scale) @ k.transpose(-2, -1) - - if self.use_rel_pos: - attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) - - attn = attn.softmax(dim=-1) - x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) - x = self.proj(x) - - return x - - -def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: - """ - Partition into non-overlapping windows with padding if needed. - Args: - x (tensor): input tokens with [B, H, W, C]. - window_size (int): window size. - - Returns: - windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
- (Hp, Wp): padded height and width before partition - """ - B, H, W, C = x.shape - - pad_h = (window_size - H % window_size) % window_size - pad_w = (window_size - W % window_size) % window_size - if pad_h > 0 or pad_w > 0: - x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) - Hp, Wp = H + pad_h, W + pad_w - - x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows, (Hp, Wp) - - -def window_unpartition( - windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] -) -> torch.Tensor: - """ - Window unpartition into original sequences and removing padding. - Args: - windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. - window_size (int): window size. - pad_hw (Tuple): padded height and width (Hp, Wp). - hw (Tuple): original height and width (H, W) before padding. - - Returns: - x: unpartitioned sequences with [B, H, W, C]. - """ - Hp, Wp = pad_hw - H, W = hw - B = windows.shape[0] // (Hp * Wp // window_size // window_size) - x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) - - if Hp > H or Wp > W: - x = x[:, :H, :W, :].contiguous() - return x - - -def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: - """ - Get relative positional embeddings according to the relative positions of - query and key sizes. - Args: - q_size (int): size of query q. - k_size (int): size of key k. - rel_pos (Tensor): relative position embeddings (L, C). - - Returns: - Extracted positional embeddings according to relative positions. - """ - max_rel_dist = int(2 * max(q_size, k_size) - 1) - # Interpolate rel pos if needed. - if rel_pos.shape[0] != max_rel_dist: - # Interpolate rel pos. 
- rel_pos_resized = F.interpolate( - rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), - size=max_rel_dist, - mode="linear", - ) - rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) - else: - rel_pos_resized = rel_pos - - # Scale the coords with short length if shapes for q and k are different. - q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) - k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) - relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) - - return rel_pos_resized[relative_coords.long()] - - -def add_decomposed_rel_pos( - attn: torch.Tensor, - q: torch.Tensor, - rel_pos_h: torch.Tensor, - rel_pos_w: torch.Tensor, - q_size: Tuple[int, int], - k_size: Tuple[int, int], -) -> torch.Tensor: - """ - Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. - https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 - Args: - attn (Tensor): attention map. - q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). - rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. - rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. - q_size (Tuple): spatial sequence size of query q with (q_h, q_w). - k_size (Tuple): spatial sequence size of key k with (k_h, k_w). - - Returns: - attn (Tensor): attention map with added relative positional embeddings. 
- """ - q_h, q_w = q_size - k_h, k_w = k_size - Rh = get_rel_pos(q_h, k_h, rel_pos_h) - Rw = get_rel_pos(q_w, k_w, rel_pos_w) - - B, _, dim = q.shape - r_q = q.reshape(B, q_h, q_w, dim) - rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) - rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) - - attn = ( - attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] - ).view(B, q_h * q_w, k_h * k_w) - - return attn - - -class PatchEmbed(nn.Module): - """ - Image to Patch Embedding. - """ - - def __init__( - self, - kernel_size: Tuple[int, int] = (16, 16), - stride: Tuple[int, int] = (16, 16), - padding: Tuple[int, int] = (0, 0), - in_chans: int = 3, - embed_dim: int = 768, - ) -> None: - """ - Args: - kernel_size (Tuple): kernel size of the projection layer. - stride (Tuple): stride of the projection layer. - padding (Tuple): padding size of the projection layer. - in_chans (int): Number of input image channels. - embed_dim (int): Patch embedding dimension. - """ - super().__init__() - - self.proj = nn.Conv2d( - in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.proj(x) - # B C H W -> B H W C - x = x.permute(0, 2, 3, 1) - return x diff --git a/segment_anything_hq/modeling/mask_decoder.py b/segment_anything_hq/modeling/mask_decoder.py deleted file mode 100644 index 242ecb7..0000000 --- a/segment_anything_hq/modeling/mask_decoder.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -from torch import nn -from torch.nn import functional as F - -from typing import List, Tuple, Type - -from .common import LayerNorm2d - - -class MaskDecoder(nn.Module): - def __init__( - self, - *, - transformer_dim: int, - transformer: nn.Module, - num_multimask_outputs: int = 3, - activation: Type[nn.Module] = nn.GELU, - iou_head_depth: int = 3, - iou_head_hidden_dim: int = 256, - ) -> None: - """ - Predicts masks given an image and prompt embeddings, using a - transformer architecture. - - Arguments: - transformer_dim (int): the channel dimension of the transformer - transformer (nn.Module): the transformer used to predict masks - num_multimask_outputs (int): the number of masks to predict - when disambiguating masks - activation (nn.Module): the type of activation to use when - upscaling masks - iou_head_depth (int): the depth of the MLP used to predict - mask quality - iou_head_hidden_dim (int): the hidden dimension of the MLP - used to predict mask quality - """ - super().__init__() - self.transformer_dim = transformer_dim - self.transformer = transformer - - self.num_multimask_outputs = num_multimask_outputs - - self.iou_token = nn.Embedding(1, transformer_dim) - self.num_mask_tokens = num_multimask_outputs + 1 - self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) - - self.output_upscaling = nn.Sequential( - nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), - LayerNorm2d(transformer_dim // 4), - activation(), - nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), - activation(), - ) - self.output_hypernetworks_mlps = nn.ModuleList( - [ - MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) - for i in range(self.num_mask_tokens) - ] - ) - - self.iou_prediction_head = MLP( - transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth - ) - - def forward( - self, - image_embeddings: torch.Tensor, - image_pe: torch.Tensor, - 
sparse_prompt_embeddings: torch.Tensor, - dense_prompt_embeddings: torch.Tensor, - multimask_output: bool, - hq_token_only: bool, - interm_embeddings: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Predict masks given image and prompt embeddings. - - Arguments: - image_embeddings (torch.Tensor): the embeddings from the image encoder - image_pe (torch.Tensor): positional encoding with the shape of image_embeddings - sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes - dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs - multimask_output (bool): Whether to return multiple masks or a single - mask. - - Returns: - torch.Tensor: batched predicted masks - torch.Tensor: batched predictions of mask quality - """ - masks, iou_pred = self.predict_masks( - image_embeddings=image_embeddings, - image_pe=image_pe, - sparse_prompt_embeddings=sparse_prompt_embeddings, - dense_prompt_embeddings=dense_prompt_embeddings, - ) - - # Select the correct mask or masks for output - if multimask_output: - mask_slice = slice(1, None) - else: - mask_slice = slice(0, 1) - masks = masks[:, mask_slice, :, :] - iou_pred = iou_pred[:, mask_slice] - - # Prepare output - return masks, iou_pred - - def predict_masks( - self, - image_embeddings: torch.Tensor, - image_pe: torch.Tensor, - sparse_prompt_embeddings: torch.Tensor, - dense_prompt_embeddings: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Predicts masks. 
See 'forward' for more details.""" - # Concatenate output tokens - output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) - output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) - tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) - - # Expand per-image data in batch direction to be per-mask - src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) - src = src + dense_prompt_embeddings - pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) - b, c, h, w = src.shape - - # Run the transformer - hs, src = self.transformer(src, pos_src, tokens) - iou_token_out = hs[:, 0, :] - mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] - - # Upscale mask embeddings and predict masks using the mask tokens - src = src.transpose(1, 2).view(b, c, h, w) - upscaled_embedding = self.output_upscaling(src) - hyper_in_list: List[torch.Tensor] = [] - for i in range(self.num_mask_tokens): - hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) - hyper_in = torch.stack(hyper_in_list, dim=1) - b, c, h, w = upscaled_embedding.shape - masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) - - # Generate mask quality predictions - iou_pred = self.iou_prediction_head(iou_token_out) - - return masks, iou_pred - - -# Lightly adapted from -# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa -class MLP(nn.Module): - def __init__( - self, - input_dim: int, - hidden_dim: int, - output_dim: int, - num_layers: int, - sigmoid_output: bool = False, - ) -> None: - super().__init__() - self.num_layers = num_layers - h = [hidden_dim] * (num_layers - 1) - self.layers = nn.ModuleList( - nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) - ) - self.sigmoid_output = sigmoid_output - - def forward(self, x): - for i, layer in enumerate(self.layers): - x = 
F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) - if self.sigmoid_output: - x = F.sigmoid(x) - return x diff --git a/segment_anything_hq/modeling/prompt_encoder.py b/segment_anything_hq/modeling/prompt_encoder.py deleted file mode 100644 index c3143f4..0000000 --- a/segment_anything_hq/modeling/prompt_encoder.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import numpy as np -import torch -from torch import nn - -from typing import Any, Optional, Tuple, Type - -from .common import LayerNorm2d - - -class PromptEncoder(nn.Module): - def __init__( - self, - embed_dim: int, - image_embedding_size: Tuple[int, int], - input_image_size: Tuple[int, int], - mask_in_chans: int, - activation: Type[nn.Module] = nn.GELU, - ) -> None: - """ - Encodes prompts for input to SAM's mask decoder. - - Arguments: - embed_dim (int): The prompts' embedding dimension - image_embedding_size (tuple(int, int)): The spatial size of the - image embedding, as (H, W). - input_image_size (int): The padded size of the image as input - to the image encoder, as (H, W). - mask_in_chans (int): The number of hidden channels used for - encoding input masks. - activation (nn.Module): The activation to use when encoding - input masks. 
- """ - super().__init__() - self.embed_dim = embed_dim - self.input_image_size = input_image_size - self.image_embedding_size = image_embedding_size - self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) - - self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners - point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] - self.point_embeddings = nn.ModuleList(point_embeddings) - self.not_a_point_embed = nn.Embedding(1, embed_dim) - - self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) - self.mask_downscaling = nn.Sequential( - nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), - LayerNorm2d(mask_in_chans // 4), - activation(), - nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), - LayerNorm2d(mask_in_chans), - activation(), - nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), - ) - self.no_mask_embed = nn.Embedding(1, embed_dim) - - def get_dense_pe(self) -> torch.Tensor: - """ - Returns the positional encoding used to encode point prompts, - applied to a dense set of points the shape of the image encoding. 
- - Returns: - torch.Tensor: Positional encoding with shape - 1x(embed_dim)x(embedding_h)x(embedding_w) - """ - return self.pe_layer(self.image_embedding_size).unsqueeze(0) - - def _embed_points( - self, - points: torch.Tensor, - labels: torch.Tensor, - pad: bool, - ) -> torch.Tensor: - """Embeds point prompts.""" - points = points + 0.5 # Shift to center of pixel - if pad: - padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) - padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) - points = torch.cat([points, padding_point], dim=1) - labels = torch.cat([labels, padding_label], dim=1) - point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) - point_embedding[labels == -1] = 0.0 - point_embedding[labels == -1] += self.not_a_point_embed.weight - point_embedding[labels == 0] += self.point_embeddings[0].weight - point_embedding[labels == 1] += self.point_embeddings[1].weight - return point_embedding - - def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: - """Embeds box prompts.""" - boxes = boxes + 0.5 # Shift to center of pixel - coords = boxes.reshape(-1, 2, 2) - corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) - corner_embedding[:, 0, :] += self.point_embeddings[2].weight - corner_embedding[:, 1, :] += self.point_embeddings[3].weight - return corner_embedding - - def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: - """Embeds mask inputs.""" - mask_embedding = self.mask_downscaling(masks) - return mask_embedding - - def _get_batch_size( - self, - points: Optional[Tuple[torch.Tensor, torch.Tensor]], - boxes: Optional[torch.Tensor], - masks: Optional[torch.Tensor], - ) -> int: - """ - Gets the batch size of the output given the batch size of the input prompts. 
- """ - if points is not None: - return points[0].shape[0] - elif boxes is not None: - return boxes.shape[0] - elif masks is not None: - return masks.shape[0] - else: - return 1 - - def _get_device(self) -> torch.device: - return self.point_embeddings[0].weight.device - - def forward( - self, - points: Optional[Tuple[torch.Tensor, torch.Tensor]], - boxes: Optional[torch.Tensor], - masks: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Embeds different types of prompts, returning both sparse and dense - embeddings. - - Arguments: - points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates - and labels to embed. - boxes (torch.Tensor or none): boxes to embed - masks (torch.Tensor or none): masks to embed - - Returns: - torch.Tensor: sparse embeddings for the points and boxes, with shape - BxNx(embed_dim), where N is determined by the number of input points - and boxes. - torch.Tensor: dense embeddings for the masks, in the shape - Bx(embed_dim)x(embed_H)x(embed_W) - """ - bs = self._get_batch_size(points, boxes, masks) - sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) - if points is not None: - coords, labels = points - point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) - sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) - if boxes is not None: - box_embeddings = self._embed_boxes(boxes) - sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) - - if masks is not None: - dense_embeddings = self._embed_masks(masks) - else: - dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( - bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] - ) - - return sparse_embeddings, dense_embeddings - - -class PositionEmbeddingRandom(nn.Module): - """ - Positional encoding using random spatial frequencies. 
- """ - - def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: - super().__init__() - if scale is None or scale <= 0.0: - scale = 1.0 - self.register_buffer( - "positional_encoding_gaussian_matrix", - scale * torch.randn((2, num_pos_feats)), - ) - - def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: - """Positionally encode points that are normalized to [0,1].""" - # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape - coords = 2 * coords - 1 - coords = coords @ self.positional_encoding_gaussian_matrix - coords = 2 * np.pi * coords - # outputs d_1 x ... x d_n x C shape - return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) - - def forward(self, size: Tuple[int, int]) -> torch.Tensor: - """Generate positional encoding for a grid of the specified size.""" - h, w = size - device: Any = self.positional_encoding_gaussian_matrix.device - grid = torch.ones((h, w), device=device, dtype=torch.float32) - y_embed = grid.cumsum(dim=0) - 0.5 - x_embed = grid.cumsum(dim=1) - 0.5 - y_embed = y_embed / h - x_embed = x_embed / w - - pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) - return pe.permute(2, 0, 1) # C x H x W - - def forward_with_coords( - self, coords_input: torch.Tensor, image_size: Tuple[int, int] - ) -> torch.Tensor: - """Positionally encode points that are not normalized to [0,1].""" - coords = coords_input.clone() - coords[:, :, 0] = coords[:, :, 0] / image_size[1] - coords[:, :, 1] = coords[:, :, 1] / image_size[0] - return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/segment_anything_hq/modeling/sam.py b/segment_anything_hq/modeling/sam.py deleted file mode 100644 index b928dfd..0000000 --- a/segment_anything_hq/modeling/sam.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
- -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from torch import nn -from torch.nn import functional as F - -from typing import Any, Dict, List, Tuple - -from .image_encoder import ImageEncoderViT -from .mask_decoder import MaskDecoder -from .prompt_encoder import PromptEncoder - - -class Sam(nn.Module): - mask_threshold: float = 0.0 - image_format: str = "RGB" - - def __init__( - self, - image_encoder: ImageEncoderViT, - prompt_encoder: PromptEncoder, - mask_decoder: MaskDecoder, - pixel_mean: List[float] = [123.675, 116.28, 103.53], - pixel_std: List[float] = [58.395, 57.12, 57.375], - ) -> None: - """ - SAM predicts object masks from an image and input prompts. - - Arguments: - image_encoder (ImageEncoderViT): The backbone used to encode the - image into image embeddings that allow for efficient mask prediction. - prompt_encoder (PromptEncoder): Encodes various types of input prompts. - mask_decoder (MaskDecoder): Predicts masks from the image embeddings - and encoded prompts. - pixel_mean (list(float)): Mean values for normalizing pixels in the input image. - pixel_std (list(float)): Std values for normalizing pixels in the input image. - """ - super().__init__() - self.image_encoder = image_encoder - self.prompt_encoder = prompt_encoder - self.mask_decoder = mask_decoder - self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) - self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) - - @property - def device(self) -> Any: - return self.pixel_mean.device - - def forward( - self, - batched_input: List[Dict[str, Any]], - multimask_output: bool, - hq_token_only: bool =False, - ) -> List[Dict[str, torch.Tensor]]: - """ - Predicts masks end-to-end from provided images and prompts. - If prompts are not known in advance, using SamPredictor is - recommended over calling the model directly. 
- - Arguments: - batched_input (list(dict)): A list over input images, each a - dictionary with the following keys. A prompt key can be - excluded if it is not present. - 'image': The image as a torch tensor in 3xHxW format, - already transformed for input to the model. - 'original_size': (tuple(int, int)) The original size of - the image before transformation, as (H, W). - 'point_coords': (torch.Tensor) Batched point prompts for - this image, with shape BxNx2. Already transformed to the - input frame of the model. - 'point_labels': (torch.Tensor) Batched labels for point prompts, - with shape BxN. - 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. - Already transformed to the input frame of the model. - 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, - in the form Bx1xHxW. - multimask_output (bool): Whether the model should predict multiple - disambiguating masks, or return a single mask. - - Returns: - (list(dict)): A list over input images, where each element is - as dictionary with the following keys. - 'masks': (torch.Tensor) Batched binary mask predictions, - with shape BxCxHxW, where B is the number of input prompts, - C is determined by multimask_output, and (H, W) is the - original size of the image. - 'iou_predictions': (torch.Tensor) The model's predictions - of mask quality, in shape BxC. - 'low_res_logits': (torch.Tensor) Low resolution logits with - shape BxCxHxW, where H=W=256. Can be passed as mask input - to subsequent iterations of prediction. 
- """ - input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) - image_embeddings, interm_embeddings = self.image_encoder(input_images) - interm_embeddings = interm_embeddings[0] # early layer - - outputs = [] - for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings): - if "point_coords" in image_record: - points = (image_record["point_coords"], image_record["point_labels"]) - else: - points = None - sparse_embeddings, dense_embeddings = self.prompt_encoder( - points=points, - boxes=image_record.get("boxes", None), - masks=image_record.get("mask_inputs", None), - ) - low_res_masks, iou_predictions = self.mask_decoder( - image_embeddings=curr_embedding.unsqueeze(0), - image_pe=self.prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=multimask_output, - hq_token_only=hq_token_only, - interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0), - ) - masks = self.postprocess_masks( - low_res_masks, - input_size=image_record["image"].shape[-2:], - original_size=image_record["original_size"], - ) - masks = masks > self.mask_threshold - outputs.append( - { - "masks": masks, - "iou_predictions": iou_predictions, - "low_res_logits": low_res_masks, - } - ) - return outputs - - def postprocess_masks( - self, - masks: torch.Tensor, - input_size: Tuple[int, ...], - original_size: Tuple[int, ...], - ) -> torch.Tensor: - """ - Remove padding and upscale masks to the original image size. - - Arguments: - masks (torch.Tensor): Batched masks from the mask_decoder, - in BxCxHxW format. - input_size (tuple(int, int)): The size of the image input to the - model, in (H, W) format. Used to remove padding. - original_size (tuple(int, int)): The original size of the image - before resizing for input to the model, in (H, W) format. 
- - Returns: - (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) - is given by original_size. - """ - masks = F.interpolate( - masks, - (self.image_encoder.img_size, self.image_encoder.img_size), - mode="bilinear", - align_corners=False, - ) - masks = masks[..., : input_size[0], : input_size[1]] - masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) - return masks - - def preprocess(self, x: torch.Tensor) -> torch.Tensor: - """Normalize pixel values and pad to a square input.""" - # Normalize colors - x = (x - self.pixel_mean) / self.pixel_std - - # Pad - h, w = x.shape[-2:] - padh = self.image_encoder.img_size - h - padw = self.image_encoder.img_size - w - x = F.pad(x, (0, padw, 0, padh)) - return x diff --git a/segment_anything_hq/modeling/transformer.py b/segment_anything_hq/modeling/transformer.py deleted file mode 100644 index 28fafea..0000000 --- a/segment_anything_hq/modeling/transformer.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from torch import Tensor, nn - -import math -from typing import Tuple, Type - -from .common import MLPBlock - - -class TwoWayTransformer(nn.Module): - def __init__( - self, - depth: int, - embedding_dim: int, - num_heads: int, - mlp_dim: int, - activation: Type[nn.Module] = nn.ReLU, - attention_downsample_rate: int = 2, - ) -> None: - """ - A transformer decoder that attends to an input image using - queries whose positional embedding is supplied. - - Args: - depth (int): number of layers in the transformer - embedding_dim (int): the channel dimension for the input embeddings - num_heads (int): the number of heads for multihead attention. 
Must - divide embedding_dim - mlp_dim (int): the channel dimension internal to the MLP block - activation (nn.Module): the activation to use in the MLP block - """ - super().__init__() - self.depth = depth - self.embedding_dim = embedding_dim - self.num_heads = num_heads - self.mlp_dim = mlp_dim - self.layers = nn.ModuleList() - - for i in range(depth): - self.layers.append( - TwoWayAttentionBlock( - embedding_dim=embedding_dim, - num_heads=num_heads, - mlp_dim=mlp_dim, - activation=activation, - attention_downsample_rate=attention_downsample_rate, - skip_first_layer_pe=(i == 0), - ) - ) - - self.final_attn_token_to_image = Attention( - embedding_dim, num_heads, downsample_rate=attention_downsample_rate - ) - self.norm_final_attn = nn.LayerNorm(embedding_dim) - - def forward( - self, - image_embedding: Tensor, - image_pe: Tensor, - point_embedding: Tensor, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - image_embedding (torch.Tensor): image to attend to. Should be shape - B x embedding_dim x h x w for any h and w. - image_pe (torch.Tensor): the positional encoding to add to the image. Must - have the same shape as image_embedding. - point_embedding (torch.Tensor): the embedding to add to the query points. - Must have shape B x N_points x embedding_dim for any N_points. 
- - Returns: - torch.Tensor: the processed point_embedding - torch.Tensor: the processed image_embedding - """ - # BxCxHxW -> BxHWxC == B x N_image_tokens x C - bs, c, h, w = image_embedding.shape - image_embedding = image_embedding.flatten(2).permute(0, 2, 1) - image_pe = image_pe.flatten(2).permute(0, 2, 1) - - # Prepare queries - queries = point_embedding - keys = image_embedding - - # Apply transformer blocks and final layernorm - for layer in self.layers: - queries, keys = layer( - queries=queries, - keys=keys, - query_pe=point_embedding, - key_pe=image_pe, - ) - - # Apply the final attention layer from the points to the image - q = queries + point_embedding - k = keys + image_pe - attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) - queries = queries + attn_out - queries = self.norm_final_attn(queries) - - return queries, keys - - -class TwoWayAttentionBlock(nn.Module): - def __init__( - self, - embedding_dim: int, - num_heads: int, - mlp_dim: int = 2048, - activation: Type[nn.Module] = nn.ReLU, - attention_downsample_rate: int = 2, - skip_first_layer_pe: bool = False, - ) -> None: - """ - A transformer block with four layers: (1) self-attention of sparse - inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp - block on sparse inputs, and (4) cross attention of dense inputs to sparse - inputs. 
- - Arguments: - embedding_dim (int): the channel dimension of the embeddings - num_heads (int): the number of heads in the attention layers - mlp_dim (int): the hidden dimension of the mlp block - activation (nn.Module): the activation of the mlp block - skip_first_layer_pe (bool): skip the PE on the first layer - """ - super().__init__() - self.self_attn = Attention(embedding_dim, num_heads) - self.norm1 = nn.LayerNorm(embedding_dim) - - self.cross_attn_token_to_image = Attention( - embedding_dim, num_heads, downsample_rate=attention_downsample_rate - ) - self.norm2 = nn.LayerNorm(embedding_dim) - - self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) - self.norm3 = nn.LayerNorm(embedding_dim) - - self.norm4 = nn.LayerNorm(embedding_dim) - self.cross_attn_image_to_token = Attention( - embedding_dim, num_heads, downsample_rate=attention_downsample_rate - ) - - self.skip_first_layer_pe = skip_first_layer_pe - - def forward( - self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor - ) -> Tuple[Tensor, Tensor]: - # Self attention block - if self.skip_first_layer_pe: - queries = self.self_attn(q=queries, k=queries, v=queries) - else: - q = queries + query_pe - attn_out = self.self_attn(q=q, k=q, v=queries) - queries = queries + attn_out - queries = self.norm1(queries) - - # Cross attention block, tokens attending to image embedding - q = queries + query_pe - k = keys + key_pe - attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) - queries = queries + attn_out - queries = self.norm2(queries) - - # MLP block - mlp_out = self.mlp(queries) - queries = queries + mlp_out - queries = self.norm3(queries) - - # Cross attention block, image embedding attending to tokens - q = queries + query_pe - k = keys + key_pe - attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) - keys = keys + attn_out - keys = self.norm4(keys) - - return queries, keys - - -class Attention(nn.Module): - """ - An attention layer that allows for downscaling the size of 
the embedding - after projection to queries, keys, and values. - """ - - def __init__( - self, - embedding_dim: int, - num_heads: int, - downsample_rate: int = 1, - ) -> None: - super().__init__() - self.embedding_dim = embedding_dim - self.internal_dim = embedding_dim // downsample_rate - self.num_heads = num_heads - assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." - - self.q_proj = nn.Linear(embedding_dim, self.internal_dim) - self.k_proj = nn.Linear(embedding_dim, self.internal_dim) - self.v_proj = nn.Linear(embedding_dim, self.internal_dim) - self.out_proj = nn.Linear(self.internal_dim, embedding_dim) - - def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: - b, n, c = x.shape - x = x.reshape(b, n, num_heads, c // num_heads) - return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head - - def _recombine_heads(self, x: Tensor) -> Tensor: - b, n_heads, n_tokens, c_per_head = x.shape - x = x.transpose(1, 2) - return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C - - def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: - # Input projections - q = self.q_proj(q) - k = self.k_proj(k) - v = self.v_proj(v) - - # Separate into heads - q = self._separate_heads(q, self.num_heads) - k = self._separate_heads(k, self.num_heads) - v = self._separate_heads(v, self.num_heads) - - # Attention - _, _, _, c_per_head = q.shape - attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens - attn = attn / math.sqrt(c_per_head) - attn = torch.softmax(attn, dim=-1) - - # Get output - out = attn @ v - out = self._recombine_heads(out) - out = self.out_proj(out) - - return out diff --git a/segment_anything_hq/predictor.py b/segment_anything_hq/predictor.py deleted file mode 100644 index 31458fb..0000000 --- a/segment_anything_hq/predictor.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
- -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import numpy as np -import torch - -from .modeling import Sam - -from typing import Optional, Tuple - -from .utils.transforms import ResizeLongestSide - - -class SamPredictor: - def __init__( - self, - sam_model: Sam, - ) -> None: - """ - Uses SAM to calculate the image embedding for an image, and then - allow repeated, efficient mask prediction given prompts. - - Arguments: - sam_model (Sam): The model to use for mask prediction. - """ - super().__init__() - self.model = sam_model - self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) - self.reset_image() - - def set_image( - self, - image: np.ndarray, - image_format: str = "RGB", - ) -> None: - """ - Calculates the image embeddings for the provided image, allowing - masks to be predicted with the 'predict' method. - - Arguments: - image (np.ndarray): The image for calculating masks. Expects an - image in HWC uint8 format, with pixel values in [0, 255]. - image_format (str): The color format of the image, in ['RGB', 'BGR']. - """ - assert image_format in [ - "RGB", - "BGR", - ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." - # import pdb;pdb.set_trace() - if image_format != self.model.image_format: - image = image[..., ::-1] - - # Transform the image to the form expected by the model - # import pdb;pdb.set_trace() - input_image = self.transform.apply_image(image) - input_image_torch = torch.as_tensor(input_image, device=self.device) - input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] - - self.set_torch_image(input_image_torch, image.shape[:2]) - - @torch.no_grad() - def set_torch_image( - self, - transformed_image: torch.Tensor, - original_image_size: Tuple[int, ...], - ) -> None: - """ - Calculates the image embeddings for the provided image, allowing - masks to be predicted with the 'predict' method. 
Expects the input - image to be already transformed to the format expected by the model. - - Arguments: - transformed_image (torch.Tensor): The input image, with shape - 1x3xHxW, which has been transformed with ResizeLongestSide. - original_image_size (tuple(int, int)): The size of the image - before transformation, in (H, W) format. - """ - assert ( - len(transformed_image.shape) == 4 - and transformed_image.shape[1] == 3 - and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size - ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." - self.reset_image() - - self.original_size = original_image_size - self.input_size = tuple(transformed_image.shape[-2:]) - input_image = self.model.preprocess(transformed_image) - self.features, self.interm_features = self.model.image_encoder(input_image) - self.is_image_set = True - - def predict( - self, - point_coords: Optional[np.ndarray] = None, - point_labels: Optional[np.ndarray] = None, - box: Optional[np.ndarray] = None, - mask_input: Optional[np.ndarray] = None, - multimask_output: bool = True, - return_logits: bool = False, - hq_token_only: bool =False, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Predict masks for the given input prompts, using the currently set image. - - Arguments: - point_coords (np.ndarray or None): A Nx2 array of point prompts to the - model. Each point is in (X,Y) in pixels. - point_labels (np.ndarray or None): A length N array of labels for the - point prompts. 1 indicates a foreground point and 0 indicates a - background point. - box (np.ndarray or None): A length 4 array given a box prompt to the - model, in XYXY format. - mask_input (np.ndarray): A low resolution mask input to the model, typically - coming from a previous prediction iteration. Has form 1xHxW, where - for SAM, H=W=256. - multimask_output (bool): If true, the model will return three masks. 
- For ambiguous input prompts (such as a single click), this will often - produce better masks than a single prediction. If only a single - mask is needed, the model's predicted quality score can be used - to select the best mask. For non-ambiguous prompts, such as multiple - input prompts, multimask_output=False can give better results. - return_logits (bool): If true, returns un-thresholded masks logits - instead of a binary mask. - - Returns: - (np.ndarray): The output masks in CxHxW format, where C is the - number of masks, and (H, W) is the original image size. - (np.ndarray): An array of length C containing the model's - predictions for the quality of each mask. - (np.ndarray): An array of shape CxHxW, where C is the number - of masks and H=W=256. These low resolution logits can be passed to - a subsequent iteration as mask input. - """ - if not self.is_image_set: - raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") - - # Transform input prompts - coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None - if point_coords is not None: - assert ( - point_labels is not None - ), "point_labels must be supplied if point_coords is supplied." 
- point_coords = self.transform.apply_coords(point_coords, self.original_size) - coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) - labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) - coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] - if box is not None: - box = self.transform.apply_boxes(box, self.original_size) - box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) - box_torch = box_torch[None, :] - if mask_input is not None: - mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) - mask_input_torch = mask_input_torch[None, :, :, :] - - masks, iou_predictions, low_res_masks = self.predict_torch( - coords_torch, - labels_torch, - box_torch, - mask_input_torch, - multimask_output, - return_logits=return_logits, - hq_token_only=hq_token_only, - ) - - masks_np = masks[0].detach().cpu().numpy() - iou_predictions_np = iou_predictions[0].detach().cpu().numpy() - low_res_masks_np = low_res_masks[0].detach().cpu().numpy() - return masks_np, iou_predictions_np, low_res_masks_np - - @torch.no_grad() - def predict_torch( - self, - point_coords: Optional[torch.Tensor], - point_labels: Optional[torch.Tensor], - boxes: Optional[torch.Tensor] = None, - mask_input: Optional[torch.Tensor] = None, - multimask_output: bool = True, - return_logits: bool = False, - hq_token_only: bool =False, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Predict masks for the given input prompts, using the currently set image. - Input prompts are batched torch tensors and are expected to already be - transformed to the input frame using ResizeLongestSide. - - Arguments: - point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the - model. Each point is in (X,Y) in pixels. - point_labels (torch.Tensor or None): A BxN array of labels for the - point prompts. 1 indicates a foreground point and 0 indicates a - background point. 
- boxes (np.ndarray or None): A Bx4 array given a box prompt to the - model, in XYXY format. - mask_input (np.ndarray): A low resolution mask input to the model, typically - coming from a previous prediction iteration. Has form Bx1xHxW, where - for SAM, H=W=256. Masks returned by a previous iteration of the - predict method do not need further transformation. - multimask_output (bool): If true, the model will return three masks. - For ambiguous input prompts (such as a single click), this will often - produce better masks than a single prediction. If only a single - mask is needed, the model's predicted quality score can be used - to select the best mask. For non-ambiguous prompts, such as multiple - input prompts, multimask_output=False can give better results. - return_logits (bool): If true, returns un-thresholded masks logits - instead of a binary mask. - - Returns: - (torch.Tensor): The output masks in BxCxHxW format, where C is the - number of masks, and (H, W) is the original image size. - (torch.Tensor): An array of shape BxC containing the model's - predictions for the quality of each mask. - (torch.Tensor): An array of shape BxCxHxW, where C is the number - of masks and H=W=256. These low res logits can be passed to - a subsequent iteration as mask input. - """ - if not self.is_image_set: - raise RuntimeError("An image must be set with .set_image(...) 
before mask prediction.") - - if point_coords is not None: - points = (point_coords, point_labels) - else: - points = None - - # Embed prompts - sparse_embeddings, dense_embeddings = self.model.prompt_encoder( - points=points, - boxes=boxes, - masks=mask_input, - ) - - # Predict masks - low_res_masks, iou_predictions = self.model.mask_decoder( - image_embeddings=self.features, - image_pe=self.model.prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=multimask_output, - hq_token_only=hq_token_only, - interm_embeddings=self.interm_features, - ) - - # Upscale the masks to the original image resolution - masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) - - if not return_logits: - masks = masks > self.model.mask_threshold - - return masks, iou_predictions, low_res_masks - - def get_image_embedding(self) -> torch.Tensor: - """ - Returns the image embeddings for the currently set image, with - shape 1xCxHxW, where C is the embedding dimension and (H,W) are - the embedding spatial dimension of SAM (typically C=256, H=W=64). - """ - if not self.is_image_set: - raise RuntimeError( - "An image must be set with .set_image(...) to generate an embedding." - ) - assert self.features is not None, "Features must exist if an image has been set." - return self.features - - @property - def device(self) -> torch.device: - return self.model.device - - def reset_image(self) -> None: - """Resets the currently set image.""" - self.is_image_set = False - self.features = None - self.orig_h = None - self.orig_w = None - self.input_h = None - self.input_w = None diff --git a/segment_anything_hq/utils/__init__.py b/segment_anything_hq/utils/__init__.py deleted file mode 100644 index 5277f46..0000000 --- a/segment_anything_hq/utils/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
- -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. diff --git a/segment_anything_hq/utils/amg.py b/segment_anything_hq/utils/amg.py deleted file mode 100644 index be06407..0000000 --- a/segment_anything_hq/utils/amg.py +++ /dev/null @@ -1,346 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import numpy as np -import torch - -import math -from copy import deepcopy -from itertools import product -from typing import Any, Dict, Generator, ItemsView, List, Tuple - - -class MaskData: - """ - A structure for storing masks and their related data in batched format. - Implements basic filtering and concatenation. - """ - - def __init__(self, **kwargs) -> None: - for v in kwargs.values(): - assert isinstance( - v, (list, np.ndarray, torch.Tensor) - ), "MaskData only supports list, numpy arrays, and torch tensors." - self._stats = dict(**kwargs) - - def __setitem__(self, key: str, item: Any) -> None: - assert isinstance( - item, (list, np.ndarray, torch.Tensor) - ), "MaskData only supports list, numpy arrays, and torch tensors." 
- self._stats[key] = item - - def __delitem__(self, key: str) -> None: - del self._stats[key] - - def __getitem__(self, key: str) -> Any: - return self._stats[key] - - def items(self) -> ItemsView[str, Any]: - return self._stats.items() - - def filter(self, keep: torch.Tensor) -> None: - for k, v in self._stats.items(): - if v is None: - self._stats[k] = None - elif isinstance(v, torch.Tensor): - self._stats[k] = v[torch.as_tensor(keep, device=v.device)] - elif isinstance(v, np.ndarray): - self._stats[k] = v[keep.detach().cpu().numpy()] - elif isinstance(v, list) and keep.dtype == torch.bool: - self._stats[k] = [a for i, a in enumerate(v) if keep[i]] - elif isinstance(v, list): - self._stats[k] = [v[i] for i in keep] - else: - raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") - - def cat(self, new_stats: "MaskData") -> None: - for k, v in new_stats.items(): - if k not in self._stats or self._stats[k] is None: - self._stats[k] = deepcopy(v) - elif isinstance(v, torch.Tensor): - self._stats[k] = torch.cat([self._stats[k], v], dim=0) - elif isinstance(v, np.ndarray): - self._stats[k] = np.concatenate([self._stats[k], v], axis=0) - elif isinstance(v, list): - self._stats[k] = self._stats[k] + deepcopy(v) - else: - raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") - - def to_numpy(self) -> None: - for k, v in self._stats.items(): - if isinstance(v, torch.Tensor): - self._stats[k] = v.detach().cpu().numpy() - - -def is_box_near_crop_edge( - boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 -) -> torch.Tensor: - """Filter masks at the edge of a crop, but not at the edge of the original image.""" - crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) - orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) - boxes = uncrop_boxes_xyxy(boxes, crop_box).float() - near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 
- near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) - near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) - return torch.any(near_crop_edge, dim=1) - - -def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: - box_xywh = deepcopy(box_xyxy) - box_xywh[2] = box_xywh[2] - box_xywh[0] - box_xywh[3] = box_xywh[3] - box_xywh[1] - return box_xywh - - -def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: - assert len(args) > 0 and all( - len(a) == len(args[0]) for a in args - ), "Batched iteration must have inputs of all the same size." - n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) - for b in range(n_batches): - yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] - - -def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: - """ - Encodes masks to an uncompressed RLE, in the format expected by - pycoco tools. - """ - # Put in fortran order and flatten h,w - b, h, w = tensor.shape - tensor = tensor.permute(0, 2, 1).flatten(1) - - # Compute change indices - diff = tensor[:, 1:] ^ tensor[:, :-1] - change_indices = diff.nonzero() - - # Encode run length - out = [] - for i in range(b): - cur_idxs = change_indices[change_indices[:, 0] == i, 1] - cur_idxs = torch.cat( - [ - torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), - cur_idxs + 1, - torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), - ] - ) - btw_idxs = cur_idxs[1:] - cur_idxs[:-1] - counts = [] if tensor[i, 0] == 0 else [0] - counts.extend(btw_idxs.detach().cpu().tolist()) - out.append({"size": [h, w], "counts": counts}) - return out - - -def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: - """Compute a binary mask from an uncompressed RLE.""" - h, w = rle["size"] - mask = np.empty(h * w, dtype=bool) - idx = 0 - parity = False - for count in rle["counts"]: - mask[idx : idx + count] = parity - idx += count - parity ^= True - mask = 
mask.reshape(w, h) - return mask.transpose() # Put in C order - - -def area_from_rle(rle: Dict[str, Any]) -> int: - return sum(rle["counts"][1::2]) - - -def calculate_stability_score( - masks: torch.Tensor, mask_threshold: float, threshold_offset: float -) -> torch.Tensor: - """ - Computes the stability score for a batch of masks. The stability - score is the IoU between the binary masks obtained by thresholding - the predicted mask logits at high and low values. - """ - # One mask is always contained inside the other. - # Save memory by preventing unnecessary cast to torch.int64 - intersections = ( - (masks > (mask_threshold + threshold_offset)) - .sum(-1, dtype=torch.int16) - .sum(-1, dtype=torch.int32) - ) - unions = ( - (masks > (mask_threshold - threshold_offset)) - .sum(-1, dtype=torch.int16) - .sum(-1, dtype=torch.int32) - ) - return intersections / unions - - -def build_point_grid(n_per_side: int) -> np.ndarray: - """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" - offset = 1 / (2 * n_per_side) - points_one_side = np.linspace(offset, 1 - offset, n_per_side) - points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) - points_y = np.tile(points_one_side[:, None], (1, n_per_side)) - points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) - return points - - -def build_all_layer_point_grids( - n_per_side: int, n_layers: int, scale_per_layer: int -) -> List[np.ndarray]: - """Generates point grids for all crop layers.""" - points_by_layer = [] - for i in range(n_layers + 1): - n_points = int(n_per_side / (scale_per_layer**i)) - points_by_layer.append(build_point_grid(n_points)) - return points_by_layer - - -def generate_crop_boxes( - im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float -) -> Tuple[List[List[int]], List[int]]: - """ - Generates a list of crop boxes of different sizes. Each layer - has (2**i)**2 boxes for the ith layer. 
- """ - crop_boxes, layer_idxs = [], [] - im_h, im_w = im_size - short_side = min(im_h, im_w) - - # Original image - crop_boxes.append([0, 0, im_w, im_h]) - layer_idxs.append(0) - - def crop_len(orig_len, n_crops, overlap): - return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) - - for i_layer in range(n_layers): - n_crops_per_side = 2 ** (i_layer + 1) - overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) - - crop_w = crop_len(im_w, n_crops_per_side, overlap) - crop_h = crop_len(im_h, n_crops_per_side, overlap) - - crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] - crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] - - # Crops in XYWH format - for x0, y0 in product(crop_box_x0, crop_box_y0): - box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] - crop_boxes.append(box) - layer_idxs.append(i_layer + 1) - - return crop_boxes, layer_idxs - - -def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: - x0, y0, _, _ = crop_box - offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) - # Check if boxes has a channel dimension - if len(boxes.shape) == 3: - offset = offset.unsqueeze(1) - return boxes + offset - - -def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: - x0, y0, _, _ = crop_box - offset = torch.tensor([[x0, y0]], device=points.device) - # Check if points has a channel dimension - if len(points.shape) == 3: - offset = offset.unsqueeze(1) - return points + offset - - -def uncrop_masks( - masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int -) -> torch.Tensor: - x0, y0, x1, y1 = crop_box - if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: - return masks - # Coordinate transform masks - pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) - pad = (x0, pad_x - x0, y0, pad_y - y0) - return torch.nn.functional.pad(masks, pad, value=0) - - -def remove_small_regions( - mask: np.ndarray, 
area_thresh: float, mode: str -) -> Tuple[np.ndarray, bool]: - """ - Removes small disconnected regions and holes in a mask. Returns the - mask and an indicator of if the mask has been modified. - """ - import cv2 # type: ignore - - assert mode in ["holes", "islands"] - correct_holes = mode == "holes" - working_mask = (correct_holes ^ mask).astype(np.uint8) - n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) - sizes = stats[:, -1][1:] # Row 0 is background label - small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] - if len(small_regions) == 0: - return mask, False - fill_labels = [0] + small_regions - if not correct_holes: - fill_labels = [i for i in range(n_labels) if i not in fill_labels] - # If every region is below threshold, keep largest - if len(fill_labels) == 0: - fill_labels = [int(np.argmax(sizes)) + 1] - mask = np.isin(regions, fill_labels) - return mask, True - - -def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: - from pycocotools import mask as mask_utils # type: ignore - - h, w = uncompressed_rle["size"] - rle = mask_utils.frPyObjects(uncompressed_rle, h, w) - rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json - return rle - - -def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: - """ - Calculates boxes in XYXY format around masks. Return [0,0,0,0] for - an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 
- """ - # torch.max below raises an error on empty inputs, just skip in this case - if torch.numel(masks) == 0: - return torch.zeros(*masks.shape[:-2], 4, device=masks.device) - - # Normalize shape to CxHxW - shape = masks.shape - h, w = shape[-2:] - if len(shape) > 2: - masks = masks.flatten(0, -3) - else: - masks = masks.unsqueeze(0) - - # Get top and bottom edges - in_height, _ = torch.max(masks, dim=-1) - in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] - bottom_edges, _ = torch.max(in_height_coords, dim=-1) - in_height_coords = in_height_coords + h * (~in_height) - top_edges, _ = torch.min(in_height_coords, dim=-1) - - # Get left and right edges - in_width, _ = torch.max(masks, dim=-2) - in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] - right_edges, _ = torch.max(in_width_coords, dim=-1) - in_width_coords = in_width_coords + w * (~in_width) - left_edges, _ = torch.min(in_width_coords, dim=-1) - - # If the mask is empty the right edge will be to the left of the left edge. - # Replace these boxes with [0, 0, 0, 0] - empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) - out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) - out = out * (~empty_filter).unsqueeze(-1) - - # Return to original shape - if len(shape) > 2: - out = out.reshape(*shape[:-2], 4) - else: - out = out[0] - - return out diff --git a/segment_anything_hq/utils/onnx.py b/segment_anything_hq/utils/onnx.py deleted file mode 100644 index 3196bdf..0000000 --- a/segment_anything_hq/utils/onnx.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -import torch.nn as nn -from torch.nn import functional as F - -from typing import Tuple - -from ..modeling import Sam -from .amg import calculate_stability_score - - -class SamOnnxModel(nn.Module): - """ - This model should not be called directly, but is used in ONNX export. - It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, - with some functions modified to enable model tracing. Also supports extra - options controlling what information. See the ONNX export script for details. - """ - - def __init__( - self, - model: Sam, - return_single_mask: bool, - use_stability_score: bool = False, - return_extra_metrics: bool = False, - ) -> None: - super().__init__() - self.mask_decoder = model.mask_decoder - self.model = model - self.img_size = model.image_encoder.img_size - self.return_single_mask = return_single_mask - self.use_stability_score = use_stability_score - self.stability_score_offset = 1.0 - self.return_extra_metrics = return_extra_metrics - - @staticmethod - def resize_longest_image_size( - input_image_size: torch.Tensor, longest_side: int - ) -> torch.Tensor: - input_image_size = input_image_size.to(torch.float32) - scale = longest_side / torch.max(input_image_size) - transformed_size = scale * input_image_size - transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) - return transformed_size - - def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: - point_coords = point_coords + 0.5 - point_coords = point_coords / self.img_size - point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) - point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) - - point_embedding = point_embedding * (point_labels != -1) - point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( - point_labels == -1 - ) - - for i in range(self.model.prompt_encoder.num_point_embeddings): - point_embedding = 
point_embedding + self.model.prompt_encoder.point_embeddings[ - i - ].weight * (point_labels == i) - - return point_embedding - - def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: - mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) - mask_embedding = mask_embedding + ( - 1 - has_mask_input - ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) - return mask_embedding - - def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: - masks = F.interpolate( - masks, - size=(self.img_size, self.img_size), - mode="bilinear", - align_corners=False, - ) - - prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) - masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore - - orig_im_size = orig_im_size.to(torch.int64) - h, w = orig_im_size[0], orig_im_size[1] - masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) - return masks - - def select_masks( - self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Determine if we should return the multiclick mask or not from the number of points. - # The reweighting is used to avoid control flow. 
- score_reweight = torch.tensor( - [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] - ).to(iou_preds.device) - score = iou_preds + (num_points - 2.5) * score_reweight - best_idx = torch.argmax(score, dim=1) - masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) - iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) - - return masks, iou_preds - - @torch.no_grad() - def forward( - self, - image_embeddings: torch.Tensor, - point_coords: torch.Tensor, - point_labels: torch.Tensor, - mask_input: torch.Tensor, - has_mask_input: torch.Tensor, - orig_im_size: torch.Tensor, - ): - sparse_embedding = self._embed_points(point_coords, point_labels) - dense_embedding = self._embed_masks(mask_input, has_mask_input) - - masks, scores = self.model.mask_decoder.predict_masks( - image_embeddings=image_embeddings, - image_pe=self.model.prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embedding, - dense_prompt_embeddings=dense_embedding, - ) - - if self.use_stability_score: - scores = calculate_stability_score( - masks, self.model.mask_threshold, self.stability_score_offset - ) - - if self.return_single_mask: - masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) - - upscaled_masks = self.mask_postprocessing(masks, orig_im_size) - - if self.return_extra_metrics: - stability_scores = calculate_stability_score( - upscaled_masks, self.model.mask_threshold, self.stability_score_offset - ) - areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) - return upscaled_masks, scores, stability_scores, areas, masks - - return upscaled_masks, scores, masks diff --git a/segment_anything_hq/utils/transforms.py b/segment_anything_hq/utils/transforms.py deleted file mode 100644 index c08ba1e..0000000 --- a/segment_anything_hq/utils/transforms.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
- -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import numpy as np -import torch -from torch.nn import functional as F -from torchvision.transforms.functional import resize, to_pil_image # type: ignore - -from copy import deepcopy -from typing import Tuple - - -class ResizeLongestSide: - """ - Resizes images to the longest side 'target_length', as well as provides - methods for resizing coordinates and boxes. Provides methods for - transforming both numpy array and batched torch tensors. - """ - - def __init__(self, target_length: int) -> None: - self.target_length = target_length - - def apply_image(self, image: np.ndarray) -> np.ndarray: - """ - Expects a numpy array with shape HxWxC in uint8 format. - """ - target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) - return np.array(resize(to_pil_image(image), target_size)) - - def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: - """ - Expects a numpy array of length 2 in the final dimension. Requires the - original image size in (H, W) format. - """ - old_h, old_w = original_size - new_h, new_w = self.get_preprocess_shape( - original_size[0], original_size[1], self.target_length - ) - coords = deepcopy(coords).astype(float) - coords[..., 0] = coords[..., 0] * (new_w / old_w) - coords[..., 1] = coords[..., 1] * (new_h / old_h) - return coords - - def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: - """ - Expects a numpy array shape Bx4. Requires the original image size - in (H, W) format. - """ - boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) - return boxes.reshape(-1, 4) - - def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: - """ - Expects batched images with shape BxCxHxW and float format. This - transformation may not exactly match apply_image. 
apply_image is - the transformation expected by the model. - """ - # Expects an image in BCHW format. May not exactly match apply_image. - target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) - return F.interpolate( - image, target_size, mode="bilinear", align_corners=False, antialias=True - ) - - def apply_coords_torch( - self, coords: torch.Tensor, original_size: Tuple[int, ...] - ) -> torch.Tensor: - """ - Expects a torch tensor with length 2 in the last dimension. Requires the - original image size in (H, W) format. - """ - old_h, old_w = original_size - new_h, new_w = self.get_preprocess_shape( - original_size[0], original_size[1], self.target_length - ) - coords = deepcopy(coords).to(torch.float) - coords[..., 0] = coords[..., 0] * (new_w / old_w) - coords[..., 1] = coords[..., 1] * (new_h / old_h) - return coords - - def apply_boxes_torch( - self, boxes: torch.Tensor, original_size: Tuple[int, ...] - ) -> torch.Tensor: - """ - Expects a torch tensor with shape Bx4. Requires the original image - size in (H, W) format. - """ - boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) - return boxes.reshape(-1, 4) - - @staticmethod - def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: - """ - Compute the output size given input size and target long side length. - """ - scale = long_side_length * 1.0 / max(oldh, oldw) - newh, neww = oldh * scale, oldw * scale - neww = int(neww + 0.5) - newh = int(newh + 0.5) - return (newh, neww)