Fix dtype error in precision=half in macos

pull/402/head
Kohaku-Blueleaf 2024-08-07 18:28:04 +09:00 committed by GitHub
parent 8c3ead6913
commit d673f9cf8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -234,7 +234,7 @@ def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=
input_reshaped = input.contiguous().view( input_reshaped = input.contiguous().view(
1, int(b * num_groups), channel_in_group, *input.size()[2:]) 1, int(b * num_groups), channel_in_group, *input.size()[2:])
out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None, training=False, momentum=0, eps=eps) out = F.batch_norm(input_reshaped, mean.to(input), var.to(input), weight=None, bias=None, training=False, momentum=0, eps=eps)
out = out.view(b, c, *input.size()[2:]) out = out.view(b, c, *input.size()[2:])
# post affine transform # post affine transform