Update waifu_diffusion_tagger_timm.py

fix #100
main
toshiaki1729 2024-05-26 13:17:17 +09:00
parent 94cc5a2a6b
commit da5053bc2e
1 changed files with 1 additions and 1 deletions

View File

@ -81,7 +81,7 @@ class WaifuDiffusionTaggerTimm:
with torch.inference_mode():
features = self.model.forward(image_t)
probs = F.sigmoid(features).detach().cpu()
probs = F.sigmoid(features).detach().cpu().numpy()
labels: list[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float)))