mirror of https://github.com/vladmandic/automatic
parent
4a297d70f2
commit
27205c295e
|
|
@ -15,7 +15,7 @@ torch_version[0], torch_version[1] = int(torch_version[0]), int(torch_version[1]
|
|||
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties(devices.device).has_fp64
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long, no-else-return
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long, no-else-return, unused-argument, redefined-builtin, keyword-arg-before-vararg
|
||||
|
||||
def return_false(*args, **kwargs):
|
||||
return False
|
||||
|
|
@ -357,12 +357,12 @@ class DeviceProperties():
|
|||
# torch.Generator has to be a class for isinstance checks
|
||||
original_torch_Generator = torch.Generator
|
||||
class torch_Generator(original_torch_Generator):
|
||||
def __new__(self, device=None):
|
||||
def __new__(cls, device=None):
|
||||
# can't hijack __init__ because of C override so use return super().__new__
|
||||
if check_cuda(device):
|
||||
return super().__new__(self, return_xpu(device))
|
||||
return super().__new__(cls, return_xpu(device))
|
||||
else:
|
||||
return super().__new__(self, device)
|
||||
return super().__new__(cls, device)
|
||||
|
||||
|
||||
# Hijack Functions:
|
||||
|
|
|
|||
Loading…
Reference in New Issue