Fix IPEX with torch.version.cuda hijack
parent
34ada0db1b
commit
f8274cdb08
|
|
@ -86,7 +86,16 @@ def get_gpu():
|
|||
return {}
|
||||
else:
|
||||
try:
|
||||
if torch.version.cuda:
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
try:
|
||||
import intel_extension_for_pytorch # pylint: disable=import-error, unused-import
|
||||
return {
|
||||
'device': f'{torch.xpu.get_device_name(torch.xpu.current_device())} ({str(torch.xpu.device_count())})',
|
||||
'ipex': get_package_version('intel-extension-for-pytorch'),
|
||||
}
|
||||
except Exception as e:
|
||||
return { 'error': e }
|
||||
elif torch.version.cuda:
|
||||
return {
|
||||
'device': f'{torch.cuda.get_device_name(torch.cuda.current_device())} ({str(torch.cuda.device_count())}) ({torch.cuda.get_arch_list()[-1]}) {str(torch.cuda.get_device_capability(shared.device))}',
|
||||
'cuda': torch.version.cuda,
|
||||
|
|
@ -99,16 +108,9 @@ def get_gpu():
|
|||
'hip': torch.version.hip,
|
||||
}
|
||||
else:
|
||||
try:
|
||||
import intel_extension_for_pytorch # pylint: disable=import-error, unused-import
|
||||
return {
|
||||
'device': f'{torch.xpu.get_device_name(torch.xpu.current_device())} ({str(torch.xpu.device_count())})',
|
||||
'ipex': get_package_version('intel-extension-for-pytorch'),
|
||||
}
|
||||
except Exception:
|
||||
return {
|
||||
'device': 'unknown'
|
||||
}
|
||||
return {
|
||||
'device': 'unknown'
|
||||
}
|
||||
except Exception as e:
|
||||
return { 'error': e }
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue