From da5053bc2e688ed624625e9bb2987f870930e03e Mon Sep 17 00:00:00 2001 From: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com> Date: Sun, 26 May 2024 13:17:17 +0900 Subject: [PATCH] Update waifu_diffusion_tagger_timm.py fix #100 --- .../interrogators/waifu_diffusion_tagger_timm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger_timm.py b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger_timm.py index 0e7599b..6bffc51 100644 --- a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger_timm.py +++ b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger_timm.py @@ -81,7 +81,7 @@ class WaifuDiffusionTaggerTimm: with torch.inference_mode(): features = self.model.forward(image_t) - probs = F.sigmoid(features).detach().cpu() + probs = F.sigmoid(features).detach().cpu().numpy() labels: list[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float)))