diff --git a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py index 4a88bcd..80919c9 100644 --- a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py +++ b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py @@ -25,8 +25,30 @@ class WaifuDiffusionTagger(): else: providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + def check_available_device(): + import torch + if torch.cuda.is_available(): + return 'cuda' + elif launch.is_installed("torch-directml"): + # This code cannot detect DirectML available device without pytorch-directml + try: + import torch_directml + torch_directml.device() + except: + pass + else: + return 'directml' + return 'cpu' + if not launch.is_installed("onnxruntime"): - launch.run_pip("install onnxruntime-gpu", "requirements for dataset-tag-editor [onnxruntime-gpu]") + dev = check_available_device() + if dev == 'cuda': + launch.run_pip("install -U onnxruntime-gpu", "requirements for dataset-tag-editor [onnxruntime-gpu]") + elif dev == 'directml': + launch.run_pip("install -U onnxruntime-directml", "requirements for dataset-tag-editor [onnxruntime-directml]") + else: + print('Your device is not compatible with onnx hardware acceleration. CPU only version will be installed and it may be very slow.') + launch.run_pip("install -U onnxruntime", "requirements for dataset-tag-editor [onnxruntime for CPU]") import onnxruntime as ort self.model = ort.InferenceSession(path_model, providers=providers)