diff --git a/scripts/npw.py b/scripts/npw.py index 19c315a..16b6b5f 100644 --- a/scripts/npw.py +++ b/scripts/npw.py @@ -85,6 +85,8 @@ class Script(scripts.Script): def denoiser_callback(self, params): def concat_and_lerp(empty, tensor, weight): + if empty.shape[0] != tensor.shape[0]: + empty = empty.expand(tensor.shape[0], *empty.shape[1:]) if tensor.shape[1] > empty.shape[1]: num_concatenations = tensor.shape[1] // empty.shape[1] empty_concat = torch.cat([empty] * num_concatenations, dim=1)