diff --git a/adetailer/ultralytics.py b/adetailer/ultralytics.py index 7c7a1a7..c63af40 100644 --- a/adetailer/ultralytics.py +++ b/adetailer/ultralytics.py @@ -22,7 +22,7 @@ def ultralytics_predict( device: str = "", classes: str = "", ) -> PredictOutput[float]: - from ultralytics import YOLO + from ultralytics import YOLO # noqa: PLC0415 model = YOLO(model_path) apply_classes(model, model_path, classes) @@ -61,11 +61,12 @@ def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image """ Parameters ---------- - masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W). - The device can be CUDA, but `to_pil_image` takes care of that. + masks: torch.Tensor, dtype=torch.float32 or torch.uint8, shape=(N, H, W). + uint8 tensor is expected to have values 0 or 1 (not 0-255). shape: tuple[int, int] (W, H) of the original image """ + masks = masks.float() n = masks.shape[0] return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]