Signed-off-by: vladmandic <mandic00@live.com>
pull/4713/head
vladmandic 2026-03-26 08:11:42 +01:00
parent 4a297d70f2
commit 27205c295e
1 changed files with 4 additions and 4 deletions

View File

@ -15,7 +15,7 @@ torch_version[0], torch_version[1] = int(torch_version[0]), int(torch_version[1]
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties(devices.device).has_fp64
# pylint: disable=protected-access, missing-function-docstring, line-too-long, no-else-return
# pylint: disable=protected-access, missing-function-docstring, line-too-long, no-else-return, unused-argument, redefined-builtin, keyword-arg-before-vararg
def return_false(*args, **kwargs):
return False
@ -357,12 +357,12 @@ class DeviceProperties():
# torch.Generator has to be a class for isinstance checks
original_torch_Generator = torch.Generator
class torch_Generator(original_torch_Generator):
def __new__(self, device=None):
def __new__(cls, device=None):
# can't hijack __init__ because of C override so use return super().__new__
if check_cuda(device):
return super().__new__(self, return_xpu(device))
return super().__new__(cls, return_xpu(device))
else:
return super().__new__(self, device)
return super().__new__(cls, device)
# Hijack Functions: