diff --git a/scripts/dynamic_thresholding.py b/scripts/dynamic_thresholding.py index 073fab1..8217531 100644 --- a/scripts/dynamic_thresholding.py +++ b/scripts/dynamic_thresholding.py @@ -134,8 +134,8 @@ class CustomCFGDenoiser(sd_samplers.CFGDenoiser): assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches" cond_stacked = cond.reshape((-1, int(conds_per_batch)) + uncond.shape[1:]) # conds_list shape is (batch, cond, 2) - weights = torch.tensor(conds_list).select(2, 1) - weights = weights.reshape(*weights.shape, 1, 1, 1).to(uncond.device) + weights = torch.tensor(conds_list, device=uncond.device).select(2, 1) + weights = weights.reshape(*weights.shape, 1, 1, 1) ### Normal first part of the CFG Scale logic, basically diff = cond_stacked - uncond.unsqueeze(1)