feat: ultralytics yolo world model

pull/539/head
Dowon 2024-02-26 21:20:11 +09:00
parent e49667ddc7
commit 1506361898
1 changed files with 9 additions and 0 deletions

View File

@ -15,10 +15,12 @@ def ultralytics_predict(
image: Image.Image,
confidence: float = 0.3,
device: str = "",
classes: str = "",
) -> PredictOutput:
from ultralytics import YOLO
model = YOLO(model_path)
apply_classes(model, model_path, classes)
pred = model(image, conf=confidence, device=device)
bboxes = pred[0].boxes.xyxy.cpu().numpy()
@ -37,6 +39,13 @@ def ultralytics_predict(
return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
def apply_classes(model, model_path: str | Path, classes: str):
if not classes or not Path(model_path).stem.endswith("world"):
return
parsed = [c.strip() for c in classes.split(",")]
model.set_classes(parsed)
def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]:
"""
Parameters