minor torch opti
parent
b70ed89078
commit
3fe3e96543
|
|
@ -134,8 +134,8 @@ class CustomCFGDenoiser(sd_samplers.CFGDenoiser):
|
||||||
assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches"
|
assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches"
|
||||||
cond_stacked = cond.reshape((-1, int(conds_per_batch)) + uncond.shape[1:])
|
cond_stacked = cond.reshape((-1, int(conds_per_batch)) + uncond.shape[1:])
|
||||||
# conds_list shape is (batch, cond, 2)
|
# conds_list shape is (batch, cond, 2)
|
||||||
weights = torch.tensor(conds_list).select(2, 1)
|
weights = torch.tensor(conds_list, device=uncond.device).select(2, 1)
|
||||||
weights = weights.reshape(*weights.shape, 1, 1, 1).to(uncond.device)
|
weights = weights.reshape(*weights.shape, 1, 1, 1)
|
||||||
|
|
||||||
### Normal first part of the CFG Scale logic, basically
|
### Normal first part of the CFG Scale logic, basically
|
||||||
diff = cond_stacked - uncond.unsqueeze(1)
|
diff = cond_stacked - uncond.unsqueeze(1)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue