Use torch.t() instead of misused torch.transpose()

main
File_xor 2023-04-28 00:32:07 +09:00
parent 862d7aeaeb
commit 3042d9534a
1 changed files with 2 additions and 2 deletions

View File

@ -57,7 +57,7 @@ class Clip_IO(scripts.Script):
dir = os.path.dirname(filename)
if not os.path.exists(dir): os.makedirs(dir)
if not filename.endswith(".pt"): filename += ".pt"
torch.save(embeddings.transpose() if transpose else embeddings, filename)
torch.save(embeddings.t() if transpose else embeddings, filename)
pass
def on_save_embeddings_as_csv(prompt: str, filename: str, transpose: bool):
@ -68,7 +68,7 @@ class Clip_IO(scripts.Script):
dir = os.path.dirname(filename)
if not os.path.exists(dir): os.makedirs(dir)
if not filename.endswith(".csv"): filename += ".csv"
embeddings_numpy = embeddings[0].transpose().to("cpu").numpy() if transpose else embeddings[0].to("cpu").numpy()
embeddings_numpy = embeddings[0].t().to("cpu").numpy() if transpose else embeddings[0].to("cpu").numpy()
embeddings_dataframe = pandas.DataFrame(embeddings_numpy)
embeddings_dataframe.to_csv(filename)
pass