Fix IPEX with torch.version.cuda hijack

pull/37/head
Disty0 2023-09-29 14:13:07 +03:00
parent 34ada0db1b
commit f8274cdb08
1 changed files with 13 additions and 11 deletions

View File

@ -86,7 +86,16 @@ def get_gpu():
return {}
else:
try:
if torch.version.cuda:
if hasattr(torch, "xpu") and torch.xpu.is_available():
try:
import intel_extension_for_pytorch # pylint: disable=import-error, unused-import
return {
'device': f'{torch.xpu.get_device_name(torch.xpu.current_device())} ({str(torch.xpu.device_count())})',
'ipex': get_package_version('intel-extension-for-pytorch'),
}
except Exception as e:
return { 'error': e }
elif torch.version.cuda:
return {
'device': f'{torch.cuda.get_device_name(torch.cuda.current_device())} ({str(torch.cuda.device_count())}) ({torch.cuda.get_arch_list()[-1]}) {str(torch.cuda.get_device_capability(shared.device))}',
'cuda': torch.version.cuda,
@ -99,16 +108,9 @@ def get_gpu():
'hip': torch.version.hip,
}
else:
try:
import intel_extension_for_pytorch # pylint: disable=import-error, unused-import
return {
'device': f'{torch.xpu.get_device_name(torch.xpu.current_device())} ({str(torch.xpu.device_count())})',
'ipex': get_package_version('intel-extension-for-pytorch'),
}
except Exception:
return {
'device': 'unknown'
}
return {
'device': 'unknown'
}
except Exception as e:
return { 'error': e }