diff --git a/modules/intel/ipex/hijacks.py b/modules/intel/ipex/hijacks.py index 4a44920e2..0dbec3739 100644 --- a/modules/intel/ipex/hijacks.py +++ b/modules/intel/ipex/hijacks.py @@ -28,7 +28,7 @@ def return_xpu(device): # Autocast original_autocast = torch.autocast def ipex_autocast(*args, **kwargs): - if len(args) > 0 and args[0] == "cuda" or args[0] == "xpu": + if len(args) > 0 and (args[0] == "cuda" or args[0] == "xpu"): if "dtype" in kwargs: return original_autocast("xpu", *args[1:], **kwargs) else: