Merge pull request #37 from Disty0/main

Fix IPEX with torch.version.cuda hijack
pull/39/head
Vladimir Mandic 2023-09-29 08:25:34 -04:00 committed by GitHub
commit 2647efca74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 11 deletions

View File

@ -86,7 +86,12 @@ def get_gpu():
return {}
else:
try:
if torch.version.cuda:
if hasattr(torch, "xpu") and torch.xpu.is_available():
return {
'device': f'{torch.xpu.get_device_name(torch.xpu.current_device())} ({str(torch.xpu.device_count())})',
'ipex': get_package_version('intel-extension-for-pytorch'),
}
elif torch.version.cuda:
return {
'device': f'{torch.cuda.get_device_name(torch.cuda.current_device())} ({str(torch.cuda.device_count())}) ({torch.cuda.get_arch_list()[-1]}) {str(torch.cuda.get_device_capability(shared.device))}',
'cuda': torch.version.cuda,
@ -99,16 +104,9 @@ def get_gpu():
'hip': torch.version.hip,
}
else:
try:
import intel_extension_for_pytorch # pylint: disable=import-error, unused-import
return {
'device': f'{torch.xpu.get_device_name(torch.xpu.current_device())} ({str(torch.xpu.device_count())})',
'ipex': get_package_version('intel-extension-for-pytorch'),
}
except Exception:
return {
'device': 'unknown'
}
return {
'device': 'unknown'
}
except Exception as e:
return { 'error': e }