pull/12/head
d8ahazard 2022-12-10 17:54:02 -06:00
parent 58bcff7f78
commit da90277e5e
3 changed files with 10 additions and 1 deletions

View File

@ -58,6 +58,7 @@ class ClipInterrogator(Interrogator):
model_name = "ViT-H-14/laion2b_s32b_b79k"
else:
model_name = "ViT-L-14/openai"
print(f"Loading CLIP model from {model_name}")
self.append_artist = append_artist
self.append_medium = append_medium
self.append_movement = append_movement

View File

@ -30,6 +30,7 @@ def on_ui_tabs():
sp_caption_clip = gr.Checkbox(label="Add CLIP results to Caption")
sp_clip_use_v2 = gr.Checkbox(label="Use v2 CLIP Model", value=True)
sp_clip_append_flavor = gr.Checkbox(label="Append Flavor tags from CLIP")
sp_clip_max_flavors = gr.Number(label="Max flavors to append.", value=4)
sp_clip_append_medium = gr.Checkbox(label="Append Medium tags from CLIP")
sp_clip_append_movement = gr.Checkbox(label="Append Movement tags from CLIP")
sp_clip_append_artist = gr.Checkbox(label="Append Artist tags from CLIP")
@ -91,6 +92,7 @@ def on_ui_tabs():
sp_caption_clip,
sp_clip_use_v2,
sp_clip_append_flavor,
sp_clip_max_flavors,
sp_clip_append_medium,
sp_clip_append_movement,
sp_clip_append_artist,

View File

@ -11,6 +11,7 @@ import modules.codeformer_model
import modules.gfpgan_model
import reallysafe
from clipcrop import CropClip
from extensions.sd_dreambooth_extension.dreambooth.utils import list_features, is_image
from extensions.sd_smartprocess.clipinterrogator import ClipInterrogator
from extensions.sd_smartprocess.interrogator import WaifuDiffusionInterrogator, BooruInterrogator
from modules import shared, images, safe
@ -38,6 +39,7 @@ def preprocess(rename,
caption_clip,
clip_use_v2,
clip_append_flavor,
clip_max_flavors,
clip_append_medium,
clip_append_movement,
clip_append_artist,
@ -124,7 +126,7 @@ def preprocess(rename,
out_tags = []
if clip_interrogator is not None:
if caption_clip:
tags = clip_interrogator.interrogate(img)
tags = clip_interrogator.interrogate(img, max_flavors=clip_max_flavors)
for tag in tags:
#print(f"CLIPTag: {tag}")
out_tags.append(tag)
@ -206,6 +208,7 @@ def preprocess(rename,
image_index = 0
# Enumerate images
pil_features = list_features()
for index, src_image in enumerate(tqdm.tqdm(files)):
# Quit on cancel
if shared.state.interrupted:
@ -213,6 +216,9 @@ def preprocess(rename,
return msg, msg
filename = os.path.join(src, src_image)
if not is_image(filename):
continue
try:
img = Image.open(filename).convert("RGB")
except Exception as e: