diff --git a/scripts/dataset_tag_editor/interrogators/blip2_captioning.py b/scripts/dataset_tag_editor/interrogators/blip2_captioning.py index 4d37ab3..b500efe 100644 --- a/scripts/dataset_tag_editor/interrogators/blip2_captioning.py +++ b/scripts/dataset_tag_editor/interrogators/blip2_captioning.py @@ -13,10 +13,10 @@ class BLIP2Captioning: def load(self): if self.model is None or self.processor is None: self.processor = Blip2Processor.from_pretrained( - self.MODEL_REPO, cache_dir=paths.model_path / "aesthetic" + self.MODEL_REPO, cache_dir=paths.model_path ) self.model = Blip2ForConditionalGeneration.from_pretrained( - self.MODEL_REPO, cache_dir=paths.model_path / "aesthetic" + self.MODEL_REPO, cache_dir=paths.model_path ).to(devices.device) def unload(self): diff --git a/scripts/dataset_tag_editor/interrogators/git_large_captioning.py b/scripts/dataset_tag_editor/interrogators/git_large_captioning.py index 4c8680a..bf9f223 100644 --- a/scripts/dataset_tag_editor/interrogators/git_large_captioning.py +++ b/scripts/dataset_tag_editor/interrogators/git_large_captioning.py @@ -1,6 +1,8 @@ from transformers import AutoProcessor, AutoModelForCausalLM from modules import shared, devices, lowvram +from scripts.paths import paths + # brought from https://huggingface.co/docs/transformers/main/en/model_doc/git and modified class GITLargeCaptioning: @@ -12,10 +14,12 @@ class GITLargeCaptioning: def load(self): if self.model is None or self.processor is None: - self.processor = AutoProcessor.from_pretrained(self.MODEL_REPO) - self.model = AutoModelForCausalLM.from_pretrained(self.MODEL_REPO).to( - shared.device + self.processor = AutoProcessor.from_pretrained( + self.MODEL_REPO, cache_dir=paths.model_path ) + self.model = AutoModelForCausalLM.from_pretrained( + self.MODEL_REPO, cache_dir=paths.model_path + ).to(shared.device) lowvram.send_everything_to_cpu() def unload(self): diff --git a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py index 75156d5..b66d41b 100644 --- a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py +++ b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py @@ -4,6 +4,8 @@ from typing import List, Tuple from modules import shared, devices import launch +from scripts.paths import paths + class WaifuDiffusionTagger: # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified @@ -24,7 +26,7 @@ class WaifuDiffusionTagger: if not self.model: path_model = huggingface_hub.hf_hub_download( - self.MODEL_REPO, self.MODEL_FILENAME + self.MODEL_REPO, self.MODEL_FILENAME, cache_dir=paths.model_path ) if ( "all" in shared.cmd_opts.use_cpu