diff --git a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger_timm.py b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger_timm.py index 0e7599b..6bffc51 100644 --- a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger_timm.py +++ b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger_timm.py @@ -81,7 +81,7 @@ class WaifuDiffusionTaggerTimm: with torch.inference_mode(): features = self.model.forward(image_t) - probs = F.sigmoid(features).detach().cpu() + probs = F.sigmoid(features).detach().cpu().numpy() labels: list[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float)))