diff --git a/README.md b/README.md index 2676a6b..12f6acc 100644 --- a/README.md +++ b/README.md @@ -41,40 +41,10 @@ I didn't make any models, and most of the code was heavily borrowed from the [De │ ... ``` - + - #### *MrSmilingWolf's model (a.k.a. Waifu Diffusion 1.4 tagger)* Please ask the original author MrSmilingWolf#5991 for questions related to model or additional training. - - Quote from MrSmilingWolf: - > Based on validation score I'd say this is pretty much production grade. - > - > I've launched a longer training run (50 epochs, ETA: 9 days), mainly to check how much more can be squeezed out of it, but I'm fairly confident this can be plugged into a real inference pipeline already. - > - > I'm also finetuning the ConvNext network, but so far ViT has always coped better with less popular classes, so I'm edging my bets on this one. - OTOH, ensembling seems to give a decent boost in validation metrics, so if we ever want to do that, I'll be ready." - 1. Download the compressed model file. - 1. Join the [SD Training Labs](https://discord.gg/zUDeSwMf2k) discord server - 1. Click mega.nz link from [this message](https://discord.com/channels/1038249716149928046/1038249717001359402/1041160494150594671) - - 1. Unzip and move all files to the cloned repository. - - 1. The file structure should look like: - ``` - extensions/ - └╴wd14-tagger/ - ├╴2022_0000_0899_6549/ - │ └╴selected_tags.csv - │ - ├╴networks/ - │ └╴ViTB16_11_03_2022_07h05m53s/ - │ └╴ ... - │ - ├╴scripts/ - │ └╴tagger.py - │ - ... - ``` 1. Start or restart the WebUI. - or you can press refresh button after *Interrogator* dropdown box. diff --git a/scripts/tagger.py b/scripts/tagger.py index 97a373d..da7ad8a 100644 --- a/scripts/tagger.py +++ b/scripts/tagger.py @@ -30,12 +30,7 @@ def refresh_interrogators() -> List[str]: interrogators = {} # load waifu diffusion 1.4 tagger models - # TODO: temporary code, should use shared.models_path later - if os.path.isdir(Path(script_dir, '2022_0000_0899_6549')): - interrogators['wd14'] = WaifuDiffusionInterrogator( - Path(script_dir, 'networks', 'ViTB16_11_03_2022_07h05m53s'), - Path(script_dir, '2022_0000_0899_6549', 'selected_tags.csv') - ) + interrogators['wd14'] = WaifuDiffusionInterrogator() # load deepdanbooru project os.makedirs( diff --git a/tagger/interrogator.py b/tagger/interrogator.py index 23deaa1..bfc6ec1 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -8,6 +8,9 @@ from typing import Tuple, List, Dict from io import BytesIO from PIL import Image +from pathlib import Path +from huggingface_hub import hf_hub_download + from modules import shared from modules.deepbooru import re_special as tag_escape_pattern @@ -161,16 +164,38 @@ class DeepDanbooruInterrogator(Interrogator): class WaifuDiffusionInterrogator(Interrogator): - def __init__(self, model_path: os.PathLike, tags_path: os.PathLike) -> None: - print(f'Loading Waifu Diffusion tagger model from {str(model_path)}') + def __init__(self) -> None: + self.model = None + def download(self) -> Tuple[os.PathLike, os.PathLike]: + repo = "SmilingWolf/wd-v1-4-vit-tagger" + model_files = [ + {"filename": "saved_model.pb", "subfolder": ""}, + {"filename": "keras_metadata.pb", "subfolder": ""}, + {"filename": "variables.index", "subfolder": "variables"}, + {"filename": "variables.data-00000-of-00001", "subfolder": "variables"}, + ] + + print(f'Downloading Waifu Diffusion tagger model files from {repo}') + model_file_paths = [] + for elem in model_files: + model_file_paths.append(Path(hf_hub_download(repo, **elem))) + + model_path = model_file_paths[0].parents[0] + tags_path = Path(hf_hub_download(repo, filename="selected_tags.csv")) + return model_path, tags_path + + def load(self) -> None: + model_path, tags_path = self.download() + + print(f'Loading Waifu Diffusion tagger model from {str(model_path)}') with tf.device(device_name): self.model = tf.keras.models.load_model( model_path, compile=False ) - self.tags = pd.read_csv(tags_path) + self.tags = pd.read_csv(tags_path) def interrogate( self, @@ -188,6 +213,10 @@ class WaifuDiffusionInterrogator(Interrogator): image = image.astype(np.float32) image = np.expand_dims(image, 0) + # init model + if self.model is None: + self.load() + # evaluate model confidents = self.model(image, training=False)