merge: modules/control/proc/sam2/__init__.py

pull/4678/head
vladmandic 2026-03-12 14:16:50 +01:00
parent 643256f21d
commit 954d079472
1 changed files with 91 additions and 0 deletions

View File

@ -0,0 +1,91 @@
import cv2
import numpy as np
import torch
from PIL import Image
from modules import devices
from modules.shared import opts
from modules.control.util import HWC3, resize_image
class Sam2Detector:
def __init__(self, model, processor):
self.model = model
self.processor = processor
@classmethod
def from_pretrained(cls, pretrained_model_or_path="facebook/sam2.1-hiera-large", cache_dir=None, local_files_only=False):
from transformers import AutoProcessor, AutoModelForMaskGeneration
processor = AutoProcessor.from_pretrained(pretrained_model_or_path, cache_dir=cache_dir, local_files_only=local_files_only)
model = AutoModelForMaskGeneration.from_pretrained(pretrained_model_or_path, cache_dir=cache_dir, local_files_only=local_files_only, use_safetensors=True).to(devices.device).eval()
return cls(model, processor)
def _generate_grid_points(self, h, w, n_points_per_side=32):
ys = np.linspace(0, h, n_points_per_side + 2)[1:-1]
xs = np.linspace(0, w, n_points_per_side + 2)[1:-1]
points = np.array([[x, y] for y in ys for x in xs], dtype=np.float64)
labels = np.ones(len(points), dtype=np.int64)
return points, labels
def _colorize_masks(self, masks, h, w):
from numpy.random import default_rng
gen = default_rng(42)
canvas = np.zeros((h, w, 3), dtype=np.uint8)
if len(masks) == 0:
return canvas
sorted_masks = sorted(enumerate(masks), key=lambda x: x[1].sum(), reverse=True)
for _idx, mask in sorted_masks:
color = gen.integers(50, 256, size=3, dtype=np.uint8)
canvas[mask] = color
return canvas
def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
import torch.nn.functional as F
self.model.to(devices.device)
if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image, dtype=np.uint8)
input_image = HWC3(input_image)
input_image = resize_image(input_image, detect_resolution)
h, w = input_image.shape[:2]
pil_image = Image.fromarray(input_image)
points, labels = self._generate_grid_points(h, w, n_points_per_side=16)
# Process grid points in batches to avoid OOM
all_masks = []
all_scores = []
batch_size = 64
for i in range(0, len(points), batch_size):
batch_pts = points[i:i + batch_size]
batch_lbl = labels[i:i + batch_size]
# SAM2 expects 4-level nesting: [image, object, point, coords]
pts_nested = [[[pt] for pt in batch_pts.tolist()]]
lbl_nested = [[[lb] for lb in batch_lbl.tolist()]]
inputs = self.processor(images=pil_image, input_points=pts_nested, input_labels=lbl_nested, return_tensors="pt")
inputs = {k: v.to(devices.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
with devices.inference_context():
outputs = self.model(**inputs)
# pred_masks: [1, N_objects, N_masks_per_obj, mask_h, mask_w] (low-res logits)
pred_masks = outputs.pred_masks[0] # [N_objects, N_masks, mask_h, mask_w]
iou_scores = outputs.iou_scores[0] if hasattr(outputs, 'iou_scores') else torch.ones(pred_masks.shape[:2], device=pred_masks.device)
for obj_idx in range(pred_masks.shape[0]):
best = iou_scores[obj_idx].argmax()
# Upscale low-res mask logits to original image size
mask_logits = pred_masks[obj_idx, best].unsqueeze(0).unsqueeze(0).float() # [1,1,mh,mw]
mask_upscaled = F.interpolate(mask_logits, size=(h, w), mode="bilinear", align_corners=False)
all_masks.append(mask_upscaled.squeeze().cpu().numpy() > 0.0)
all_scores.append(iou_scores[obj_idx, best].item())
if opts.control_move_processor:
self.model.to("cpu")
# Filter by IoU score
if len(all_masks) > 0:
scores = np.array(all_scores)
good = scores > 0.7
masks_np = [m for m, g in zip(all_masks, good) if g]
else:
masks_np = []
detected_map = self._colorize_masks(masks_np, h, w)
detected_map = HWC3(detected_map)
img = resize_image(input_image, image_resolution)
out_h, out_w = img.shape[:2]
detected_map = cv2.resize(detected_map, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map