From 3fe3e96543d59b2b5713f9ec003116137c7331d9 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Sat, 28 Jan 2023 05:52:11 -0800 Subject: [PATCH] minor torch opti --- scripts/dynamic_thresholding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/dynamic_thresholding.py b/scripts/dynamic_thresholding.py index 073fab1..8217531 100644 --- a/scripts/dynamic_thresholding.py +++ b/scripts/dynamic_thresholding.py @@ -134,8 +134,8 @@ class CustomCFGDenoiser(sd_samplers.CFGDenoiser): assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches" cond_stacked = cond.reshape((-1, int(conds_per_batch)) + uncond.shape[1:]) # conds_list shape is (batch, cond, 2) - weights = torch.tensor(conds_list).select(2, 1) - weights = weights.reshape(*weights.shape, 1, 1, 1).to(uncond.device) + weights = torch.tensor(conds_list, device=uncond.device).select(2, 1) + weights = weights.reshape(*weights.shape, 1, 1, 1) ### Normal first part of the CFG Scale logic, basically diff = cond_stacked - uncond.unsqueeze(1)