From c0aa72e100ae5647a24f9bd4163f02835d9d7aef Mon Sep 17 00:00:00 2001
From: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com>
Date: Sun, 26 Feb 2023 12:24:27 +0900
Subject: [PATCH] enable wd-v14-tagger for AMD GPU users (#43)

---
 .../interrogators/waifu_diffusion_tagger.py   | 24 ++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py
index 4a88bcd..80919c9 100644
--- a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py
+++ b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py
@@ -25,8 +25,30 @@ class WaifuDiffusionTagger():
         else:
             providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
 
+        def check_available_device():
+            import torch
+            if torch.cuda.is_available():
+                return 'cuda'
+            elif launch.is_installed("torch-directml"):
+                # This code cannot detect DirectML available device without pytorch-directml
+                try:
+                    import torch_directml
+                    torch_directml.device()
+                except:
+                    pass
+                else:
+                    return 'directml'
+            return 'cpu'
+
         if not launch.is_installed("onnxruntime"):
-            launch.run_pip("install onnxruntime-gpu", "requirements for dataset-tag-editor [onnxruntime-gpu]")
+            dev = check_available_device()
+            if dev == 'cuda':
+                launch.run_pip("install -U onnxruntime-gpu", "requirements for dataset-tag-editor [onnxruntime-gpu]")
+            elif dev == 'directml':
+                launch.run_pip("install -U onnxruntime-directml", "requirements for dataset-tag-editor [onnxruntime-directml]")
+            else:
+                print('Your device is not compatible with onnx hardware acceleration. CPU only version will be installed and it may be very slow.')
+                launch.run_pip("install -U onnxruntime", "requirements for dataset-tag-editor [onnxruntime for CPU]")
         import onnxruntime as ort
         self.model = ort.InferenceSession(path_model, providers=providers)