From d673f9cf8c87292e5e18a541ebaaf81c916a1b49 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 7 Aug 2024 18:28:04 +0900 Subject: [PATCH] Fix dtype error in precision=half in macos --- scripts/tilevae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/tilevae.py b/scripts/tilevae.py index 199eb7c..5f4ed4a 100644 --- a/scripts/tilevae.py +++ b/scripts/tilevae.py @@ -234,7 +234,7 @@ def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps= input_reshaped = input.contiguous().view( 1, int(b * num_groups), channel_in_group, *input.size()[2:]) - out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None, training=False, momentum=0, eps=eps) + out = F.batch_norm(input_reshaped, mean.to(input), var.to(input), weight=None, bias=None, training=False, momentum=0, eps=eps) out = out.view(b, c, *input.size()[2:]) # post affine transform