diff --git a/scripts/sd_tag_batch.py b/scripts/sd_tag_batch.py index ca1776b..e693cea 100644 --- a/scripts/sd_tag_batch.py +++ b/scripts/sd_tag_batch.py @@ -142,6 +142,7 @@ class Script(scripts.Script): wd_api_model = gr.Dropdown(choices=[], value='wd-v1-4-moat-tagger.v2', label="WD API Model") wd_underscore_fix = gr.Checkbox(label="Remove Underscores from Tags", value=True) wd_threshold = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Threshold") + unload_wd_models_afterwords = gr.Checkbox(label="Unload WD Model After Use", value=True) unload_wd_models_button = gr.Button(value="Unload WD Models") # Function to load custom filter from file @@ -171,7 +172,7 @@ class Script(scripts.Script): use_weight.change(fn=update_prompt_weight_visibility, inputs=[use_weight], outputs=[prompt_weight]) - return [in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_api_model, clip_api_mode, wd_api_model, wd_threshold, wd_underscore_fix] + return [in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_api_model, clip_api_mode, wd_api_model, wd_threshold, wd_underscore_fix, unload_wd_models_afterwords] # Required to parse information from a string that is between () or has :##.## suffix def remove_attention(self, words): @@ -212,7 +213,7 @@ class Script(scripts.Script): return filtered_prompt - def run(self, p, in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_api_model, clip_api_mode, wd_api_model, wd_threshold, wd_underscore_fix): + def run(self, p, in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_api_model, clip_api_mode, wd_api_model, wd_threshold, wd_underscore_fix, unload_wd_models_afterwords): raw_prompt = p.prompt interrogator = "" @@ -262,7 +263,10 @@ class Script(scripts.Script): p.prompt = f"{p.prompt}, {interrogator}" else: p.prompt = f"{interrogator}, {p.prompt}" - + + if unload_wd_models_afterwords and "WD (API)" in model_selection: + self.post_wd_api_unload() + print(f"Prompt: {p.prompt}") processed = process_images(p)