parent
94cc5a2a6b
commit
da5053bc2e
|
|
@ -81,7 +81,7 @@ class WaifuDiffusionTaggerTimm:
|
|||
|
||||
with torch.inference_mode():
|
||||
features = self.model.forward(image_t)
|
||||
probs = F.sigmoid(features).detach().cpu()
|
||||
probs = F.sigmoid(features).detach().cpu().numpy()
|
||||
|
||||
labels: list[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float)))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue