Fix dtype error in precision=half in macos
parent
8c3ead6913
commit
d673f9cf8c
|
|
@ -234,7 +234,7 @@ def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=
|
||||||
input_reshaped = input.contiguous().view(
|
input_reshaped = input.contiguous().view(
|
||||||
1, int(b * num_groups), channel_in_group, *input.size()[2:])
|
1, int(b * num_groups), channel_in_group, *input.size()[2:])
|
||||||
|
|
||||||
out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None, training=False, momentum=0, eps=eps)
|
out = F.batch_norm(input_reshaped, mean.to(input), var.to(input), weight=None, bias=None, training=False, momentum=0, eps=eps)
|
||||||
out = out.view(b, c, *input.size()[2:])
|
out = out.view(b, c, *input.size()[2:])
|
||||||
|
|
||||||
# post affine transform
|
# post affine transform
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue