fix: missing parenthesis in ipex autocast

pull/2680/head
KerfuffleV2 2024-01-04 00:30:31 -07:00
parent c60da70474
commit 324e728bba
1 changed files with 1 additions and 1 deletions

View File

@ -28,7 +28,7 @@ def return_xpu(device):
# Autocast
original_autocast = torch.autocast
def ipex_autocast(*args, **kwargs):
if len(args) > 0 and args[0] == "cuda" or args[0] == "xpu":
if len(args) > 0 and (args[0] == "cuda" or args[0] == "xpu"):
if "dtype" in kwargs:
return original_autocast("xpu", *args[1:], **kwargs)
else: