mirror of https://github.com/vladmandic/automatic
92 lines
4.6 KiB
Python
92 lines
4.6 KiB
Python
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
from modules import devices
|
|
from modules.shared import opts
|
|
from modules.control.util import HWC3, resize_image
|
|
|
|
|
|
class Sam2Detector:
    """Control pre-processor that renders a colorized SAM2 segmentation map.

    The detector prompts the wrapped mask-generation model with a regular grid
    of single-point prompts, keeps the best-scoring mask for each prompt,
    drops low-confidence masks and paints the survivors onto one RGB canvas.
    """

    def __init__(self, model, processor):
        """Store the mask-generation ``model`` and its matching ``processor``."""
        self.model = model
        self.processor = processor

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path="facebook/sam2.1-hiera-large", cache_dir=None, local_files_only=False):
        """Build a detector from a pretrained checkpoint.

        Args:
            pretrained_model_or_path: Hub id or local path handed to the
                ``transformers`` auto classes.
            cache_dir: Optional cache directory forwarded to ``from_pretrained``.
            local_files_only: Forwarded to ``from_pretrained``.

        Returns:
            A ``Sam2Detector`` whose model was moved to ``devices.device`` and
            switched to eval mode.
        """
        # Imported lazily so that importing this module does not pull in transformers.
        from transformers import AutoProcessor, AutoModelForMaskGeneration
        processor = AutoProcessor.from_pretrained(pretrained_model_or_path, cache_dir=cache_dir, local_files_only=local_files_only)
        model = AutoModelForMaskGeneration.from_pretrained(pretrained_model_or_path, cache_dir=cache_dir, local_files_only=local_files_only, use_safetensors=True).to(devices.device).eval()
        return cls(model, processor)

    def _generate_grid_points(self, h, w, n_points_per_side=32):
        """Return a regular grid of positive point prompts inside an image.

        Args:
            h: Image height in pixels.
            w: Image width in pixels.
            n_points_per_side: Number of grid rows and of grid columns.

        Returns:
            ``(points, labels)`` where ``points`` is a float64 array of shape
            ``(n_points_per_side ** 2, 2)`` holding ``[x, y]`` pairs in
            row-major order (y outer, x inner) and ``labels`` is an int64
            array of ones with one entry per point.
        """
        # Two extra samples are generated and trimmed so that no point lies
        # on the image border itself.
        ys = np.linspace(0, h, n_points_per_side + 2)[1:-1]
        xs = np.linspace(0, w, n_points_per_side + 2)[1:-1]
        grid_x, grid_y = np.meshgrid(xs, ys)
        # column_stack keeps the (N, 2) shape even when the grid is empty.
        points = np.column_stack((grid_x.ravel(), grid_y.ravel())).astype(np.float64, copy=False)
        labels = np.ones(len(points), dtype=np.int64)
        return points, labels

    def _colorize_masks(self, masks, h, w):
        """Paint masks onto a black ``(h, w, 3)`` uint8 canvas.

        Masks are painted from largest to smallest area so that small regions
        stay visible on top of large ones. Colors come from a generator with
        a fixed seed, so the same mask list always yields the same image.

        Args:
            masks: Iterable of ``(h, w)`` masks; any non-zero value counts as
                foreground.
            h: Canvas height in pixels.
            w: Canvas width in pixels.

        Returns:
            The painted canvas; all zeros when ``masks`` is empty.
        """
        from numpy.random import default_rng
        gen = default_rng(42)
        canvas = np.zeros((h, w, 3), dtype=np.uint8)
        # Coerce to bool: indexing with an integer 0/1 mask would be treated
        # as row indices rather than as a pixel selection.
        bool_masks = [np.asarray(mask, dtype=bool) for mask in masks]
        bool_masks.sort(key=np.count_nonzero, reverse=True)
        for mask in bool_masks:
            canvas[mask] = gen.integers(50, 256, size=3, dtype=np.uint8)
        return canvas

    def _predict_masks(self, pil_image, points, labels, h, w, batch_size, score_threshold):
        """Run the model over all point prompts and return the accepted masks.

        Args:
            pil_image: PIL image handed to the processor.
            points: ``(N, 2)`` array of ``[x, y]`` prompts.
            labels: ``(N,)`` array of prompt labels.
            h: Height the masks are upscaled to.
            w: Width the masks are upscaled to.
            batch_size: Number of prompts per forward pass.
            score_threshold: A mask is kept only when its score is strictly
                greater than this value.

        Returns:
            List of boolean ``(h, w)`` numpy masks in prompt order.
        """
        import torch.nn.functional as F
        kept = []
        # Process grid points in batches to avoid OOM
        for start in range(0, len(points), batch_size):
            batch_pts = points[start:start + batch_size]
            batch_lbl = labels[start:start + batch_size]
            # SAM2 expects 4-level nesting: [image, object, point, coords]
            pts_nested = [[[pt] for pt in batch_pts.tolist()]]
            lbl_nested = [[[lb] for lb in batch_lbl.tolist()]]
            inputs = self.processor(images=pil_image, input_points=pts_nested, input_labels=lbl_nested, return_tensors="pt")
            inputs = {k: v.to(devices.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
            with devices.inference_context():
                outputs = self.model(**inputs)
            # pred_masks: [1, N_objects, N_masks_per_obj, mask_h, mask_w] (low-res logits)
            pred_masks = outputs.pred_masks[0]  # [N_objects, N_masks, mask_h, mask_w]
            iou_scores = getattr(outputs, "iou_scores", None)
            if iou_scores is None:
                # No scores available: treat every candidate as equally good.
                iou_scores = torch.ones(pred_masks.shape[:2], device=pred_masks.device)
            else:
                iou_scores = iou_scores[0]  # [N_objects, N_masks]
            # Pick the best candidate per object and fetch all scores with a
            # single host transfer instead of one sync per object.
            best = iou_scores.argmax(dim=1)
            best_scores = iou_scores.gather(1, best.unsqueeze(1)).squeeze(1).float().cpu().tolist()
            best_indices = best.cpu().tolist()
            for obj_idx, (mask_idx, score) in enumerate(zip(best_indices, best_scores)):
                # Filter by IoU score before upscaling so rejected masks cost nothing.
                if not score > score_threshold:
                    continue
                # Upscale low-res mask logits to original image size, one mask
                # at a time to keep peak memory bounded.
                mask_logits = pred_masks[obj_idx, mask_idx][None, None].float()  # [1,1,mh,mw]
                mask_upscaled = F.interpolate(mask_logits, size=(h, w), mode="bilinear", align_corners=False)
                # Explicit [0, 0] (not squeeze()) so a size-1 height or width survives.
                kept.append(mask_upscaled[0, 0].cpu().numpy() > 0.0)
        return kept

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", *, points_per_side=16, score_threshold=0.7, batch_size=64, **kwargs):
        """Segment ``input_image`` and return a colorized segmentation map.

        Args:
            input_image: Numpy array or any object ``np.array`` can convert
                (for example a PIL image).
            detect_resolution: Resolution passed to ``resize_image`` before
                detection.
            image_resolution: Resolution passed to ``resize_image`` to derive
                the output size.
            output_type: ``"pil"`` returns a PIL image; any other value
                returns the numpy array.
            points_per_side: Number of prompt points per grid side.
            score_threshold: Masks whose best score is not strictly greater
                than this value are dropped.
            batch_size: Number of prompts per forward pass; values below 1
                are treated as 1.
            **kwargs: Accepted and ignored.

        Returns:
            The segmentation map as a PIL image or a uint8 numpy array.
        """
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)
        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)
        h, w = input_image.shape[:2]
        pil_image = Image.fromarray(input_image)
        points, labels = self._generate_grid_points(h, w, n_points_per_side=points_per_side)
        batch_size = max(1, int(batch_size))

        self.model.to(devices.device)
        try:
            masks_np = self._predict_masks(pil_image, points, labels, h, w, batch_size, score_threshold)
        finally:
            # Offload even when inference fails so the model does not stay
            # resident on the accelerator.
            if opts.control_move_processor:
                self.model.to("cpu")

        detected_map = self._colorize_masks(masks_np, h, w)
        detected_map = HWC3(detected_map)
        img = resize_image(input_image, image_resolution)
        out_h, out_w = img.shape[:2]
        detected_map = cv2.resize(detected_map, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map
|