Merge pull request #21 from SmilingWolf/master

Update WD tagger
pull/6/head
Sangha Lee 2022-12-06 12:23:15 +00:00 committed by GitHub
commit c6b00a4100
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 40 deletions

View File

@@ -41,40 +41,10 @@ I didn't make any models, and most of the code was heavily borrowed from the [De
...
```
- #### *MrSmilingWolf's model (a.k.a. Waifu Diffusion 1.4 tagger)*
Please ask the original author MrSmilingWolf#5991 for questions related to the model or additional training.
Quote from MrSmilingWolf:
> Based on validation score I'd say this is pretty much production grade.
>
> I've launched a longer training run (50 epochs, ETA: 9 days), mainly to check how much more can be squeezed out of it, but I'm fairly confident this can be plugged into a real inference pipeline already.
>
> I'm also finetuning the ConvNext network, but so far ViT has always coped better with less popular classes, so I'm hedging my bets on this one.
> OTOH, ensembling seems to give a decent boost in validation metrics, so if we ever want to do that, I'll be ready.
1. Download the compressed model file.
1. Join the [SD Training Labs](https://discord.gg/zUDeSwMf2k) discord server
1. Click mega.nz link from [this message](https://discord.com/channels/1038249716149928046/1038249717001359402/1041160494150594671)
1. Unzip and move all files to the cloned repository.
1. The file structure should look like:
```
extensions/
└╴wd14-tagger/
├╴2022_0000_0899_6549/
│ └╴selected_tags.csv
├╴networks/
│ └╴ViTB16_11_03_2022_07h05m53s/
│ └╴ ...
├╴scripts/
│ └╴tagger.py
...
```
1. Start or restart the WebUI.
- or you can press the refresh button after the *Interrogator* dropdown box.

View File

@@ -30,12 +30,7 @@ def refresh_interrogators() -> List[str]:
interrogators = {}
# load waifu diffusion 1.4 tagger models
# TODO: temporary code, should use shared.models_path later
if os.path.isdir(Path(script_dir, '2022_0000_0899_6549')):
interrogators['wd14'] = WaifuDiffusionInterrogator(
Path(script_dir, 'networks', 'ViTB16_11_03_2022_07h05m53s'),
Path(script_dir, '2022_0000_0899_6549', 'selected_tags.csv')
)
interrogators['wd14'] = WaifuDiffusionInterrogator()
# load deepdanbooru project
os.makedirs(

View File

@@ -8,6 +8,9 @@ from typing import Tuple, List, Dict
from io import BytesIO
from PIL import Image
from pathlib import Path
from huggingface_hub import hf_hub_download
from modules import shared
from modules.deepbooru import re_special as tag_escape_pattern
@@ -161,16 +164,38 @@ class DeepDanbooruInterrogator(Interrogator):
class WaifuDiffusionInterrogator(Interrogator):
def __init__(self, model_path: os.PathLike, tags_path: os.PathLike) -> None:
print(f'Loading Waifu Diffusion tagger model from {str(model_path)}')
def __init__(self) -> None:
self.model = None
def download(self) -> Tuple[os.PathLike, os.PathLike]:
repo = "SmilingWolf/wd-v1-4-vit-tagger"
model_files = [
{"filename": "saved_model.pb", "subfolder": ""},
{"filename": "keras_metadata.pb", "subfolder": ""},
{"filename": "variables.index", "subfolder": "variables"},
{"filename": "variables.data-00000-of-00001", "subfolder": "variables"},
]
print(f'Downloading Waifu Diffusion tagger model files from {repo}')
model_file_paths = []
for elem in model_files:
model_file_paths.append(Path(hf_hub_download(repo, **elem)))
model_path = model_file_paths[0].parents[0]
tags_path = Path(hf_hub_download(repo, filename="selected_tags.csv"))
return model_path, tags_path
def load(self) -> None:
model_path, tags_path = self.download()
print(f'Loading Waifu Diffusion tagger model from {str(model_path)}')
with tf.device(device_name):
self.model = tf.keras.models.load_model(
model_path,
compile=False
)
self.tags = pd.read_csv(tags_path)
self.tags = pd.read_csv(tags_path)
def interrogate(
self,
@@ -188,6 +213,10 @@ class WaifuDiffusionInterrogator(Interrogator):
image = image.astype(np.float32)
image = np.expand_dims(image, 0)
# init model
if self.model is None:
self.load()
# evaluate model
confidents = self.model(image, training=False)