diff --git a/scripts/sd_tag_batch.py b/scripts/sd_tag_batch.py index ce4614a..c2b8897 100644 --- a/scripts/sd_tag_batch.py +++ b/scripts/sd_tag_batch.py @@ -1,21 +1,7 @@ import gradio as gr -from modules import scripts, devices, lowvram, shared -from clip_interrogator import Config, Interrogator +from modules import scripts, shared, deepbooru from modules.processing import process_images -ci = None - - -def unload(): - global ci - if ci is not None: - print("Offloading CLIP Interrogator...") - ci.caption_model = ci.caption_model.to(devices.cpu) - ci.clip_model = ci.clip_model.to(devices.cpu) - ci.caption_offloaded = True - ci.clip_offloaded = True - devices.torch_gc() - class Script(scripts.Script): def title(self): @@ -29,34 +15,18 @@ class Script(scripts.Script): prompt_weight = gr.Slider( 0.0, 1.0, value=0.5, step=0.1, label="interrogator weight" ) - mode = gr.Dropdown(["classic", "fast"], label="mode", value="classic") - btn = gr.Button(value="unload models") - btn.click(unload) - return [in_front, mode, prompt_weight] + use_deepbooru = gr.Checkbox(label="Use deepbooru") + return [in_front, prompt_weight, use_deepbooru] - def run(self, p, in_front, mode, prompt_weight): - global ci - if ci is None: - config = Config( - device=devices.get_optimal_device(), - cache_path="models/clip-interrogator", - clip_model_name="ViT-L-14/openai", - caption_model_name="blip-base", - ) - ci = Interrogator(config) - try: - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.send_everything_to_cpu() - devices.torch_gc() - if mode == "classic": - prompt = ci.interrogate_classic(p.init_images[0]) - elif mode == "fast": - prompt = ci.interrogate_fast(p.init_images[0]) - if in_front: - p.prompt = f"{p.prompt}, ({prompt}:{prompt_weight})" - else: - p.prompt = f"({prompt}:{prompt_weight}), {p.prompt}" - print(prompt) - except RuntimeError as e: - print(e) + def run(self, p, in_front, prompt_weight, use_deepbooru): + prompt = "" + if use_deepbooru: + prompt = deepbooru.model.tag(p.init_images[0]) + else: + prompt = shared.interrogator.interrogate(p.init_images[0]) + if in_front: + p.prompt = f"{p.prompt}, ({prompt}:{prompt_weight})" + else: + p.prompt = f"({prompt}:{prompt_weight}), {p.prompt}" + print(f"Prompt: {p.prompt}") return process_images(p)