enable wd-v14-tagger for AMD GPU users (#43)

pull/48/head
toshiaki1729 2023-02-26 12:24:27 +09:00
parent 77a4c42ab0
commit c0aa72e100
1 changed files with 23 additions and 1 deletions

View File

@ -25,8 +25,30 @@ class WaifuDiffusionTagger():
else:
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
def check_available_device():
import torch
if torch.cuda.is_available():
return 'cuda'
elif launch.is_installed("torch-directml"):
# This code cannot detect DirectML available device without pytorch-directml
try:
import torch_directml
torch_directml.device()
except:
pass
else:
return 'directml'
return 'cpu'
if not launch.is_installed("onnxruntime"):
launch.run_pip("install onnxruntime-gpu", "requirements for dataset-tag-editor [onnxruntime-gpu]")
dev = check_available_device()
if dev == 'cuda':
launch.run_pip("install -U onnxruntime-gpu", "requirements for dataset-tag-editor [onnxruntime-gpu]")
elif dev == 'directml':
launch.run_pip("install -U onnxruntime-directml", "requirements for dataset-tag-editor [onnxruntime-directml]")
else:
print('Your device is not compatible with onnx hardware acceleration. CPU only version will be installed and it may be very slow.')
launch.run_pip("install -U onnxruntime", "requirements for dataset-tag-editor [onnxruntime for CPU]")
import onnxruntime as ort
self.model = ort.InferenceSession(path_model, providers=providers)