minor torch opti

pull/9/head
Alex "mcmonkey" Goodwin 2023-01-28 05:52:11 -08:00
parent b70ed89078
commit 3fe3e96543
1 changed files with 2 additions and 2 deletions

View File

@ -134,8 +134,8 @@ class CustomCFGDenoiser(sd_samplers.CFGDenoiser):
assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches"
cond_stacked = cond.reshape((-1, int(conds_per_batch)) + uncond.shape[1:])
# conds_list shape is (batch, cond, 2)
weights = torch.tensor(conds_list).select(2, 1)
weights = weights.reshape(*weights.shape, 1, 1, 1).to(uncond.device)
weights = torch.tensor(conds_list, device=uncond.device).select(2, 1)
weights = weights.reshape(*weights.shape, 1, 1, 1)
### Normal first part of the CFG Scale logic, basically
diff = cond_stacked - uncond.unsqueeze(1)