enable wd-v14-tagger for AMD GPU users (#43)
parent
77a4c42ab0
commit
c0aa72e100
|
|
@ -25,8 +25,30 @@ class WaifuDiffusionTagger():
|
|||
else:
|
||||
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
|
||||
def check_available_device():
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
return 'cuda'
|
||||
elif launch.is_installed("torch-directml"):
|
||||
# This code cannot detect DirectML available device without pytorch-directml
|
||||
try:
|
||||
import torch_directml
|
||||
torch_directml.device()
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
return 'directml'
|
||||
return 'cpu'
|
||||
|
||||
if not launch.is_installed("onnxruntime"):
|
||||
launch.run_pip("install onnxruntime-gpu", "requirements for dataset-tag-editor [onnxruntime-gpu]")
|
||||
dev = check_available_device()
|
||||
if dev == 'cuda':
|
||||
launch.run_pip("install -U onnxruntime-gpu", "requirements for dataset-tag-editor [onnxruntime-gpu]")
|
||||
elif dev == 'directml':
|
||||
launch.run_pip("install -U onnxruntime-directml", "requirements for dataset-tag-editor [onnxruntime-directml]")
|
||||
else:
|
||||
print('Your device is not compatible with onnx hardware acceleration. CPU only version will be installed and it may be very slow.')
|
||||
launch.run_pip("install -U onnxruntime", "requirements for dataset-tag-editor [onnxruntime for CPU]")
|
||||
import onnxruntime as ort
|
||||
self.model = ort.InferenceSession(path_model, providers=providers)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue