diff --git a/scripts/main.py b/scripts/main.py index d529e60..ba140ac 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from modules import scripts, processing, shared, images, devices +from modules import scripts, processing, shared, images, devices, ui import gradio import requests import time @@ -49,6 +49,7 @@ class Main(scripts.Script): "DPM++ 2M Karras": "k_dpmpp_2m" } KARRAS = {"LMS Karras", "DPM2 Karras", "DPM2 a Karras", "DPM++ 2S a Karras", "DPM++ 2M Karras"} + POST_PROCESSINGS = {"CodeFormers (Restore faces)", "GFPGAN (Restore faces)", "RealESRGAN_x4plus (Upscale)"} #settings tab api_endpoint = "https://stablehorde.net/api" api_key = "0000000000" @@ -62,8 +63,7 @@ class Main(scripts.Script): def show(self, is_img2img): return True - def ui(self, is_img2img): - nsfw = gradio.Checkbox(False, label="NSFW") + def load_models(self): models = requests.get("{}/v2/status/models".format(self.api_endpoint)) if models.status_code == 200: @@ -74,17 +74,62 @@ class Main(scripts.Script): models = [] models.insert(0, "Random") - model = gradio.Dropdown(models, value="Random", label="Model") - seed_variation = gradio.Number(value=1, label="Seed variation", precision=0) - return [nsfw, model, seed_variation] + return models - def run(self, p, nsfw, model, seed_variation): + def ui(self, is_img2img): + nsfw = gradio.Checkbox(False, label="NSFW") + + with gradio.Box(): + with gradio.Row(elem_id="model_row"): + model = gradio.Dropdown(self.load_models(), value="Random", label="Model") + model.style(container=False) + update = gradio.Button(ui.refresh_symbol, elem_id="update_model") + + with gradio.Box(): + seed_variation = gradio.Number(value=1, label="Seed variation", precision=0) + seed_variation.style(container=False) + + with gradio.Box(): + with gradio.Row(elem_id="post_processing_row"): + post_processing_1 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #1") + post_processing_1.style(container=False) + post_processing_2 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #2", interactive=False) + post_processing_2.style(container=False) + post_processing_3 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #3", interactive=False) + post_processing_3.style(container=False) + + def update_click(): + return gradio.update(choices=self.load_models(), value="Random") + + def post_processing_1_change(value_1): + return (gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=value_1 != "None"), gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=value_1 != "None")) + + def post_processing_2_change(value_1, value_2): + return gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1, value_2}), value="None", interactive=value_2 != "None") + + update.click(fn=update_click, outputs=model) + post_processing_1.change(fn=post_processing_1_change, inputs=post_processing_1, outputs=[post_processing_2, post_processing_3]) + post_processing_2.change(fn=post_processing_2_change, inputs=[post_processing_1, post_processing_2], outputs=post_processing_3) + return [nsfw, model, seed_variation, post_processing_1, post_processing_2, post_processing_3] + + def run(self, p, nsfw, model, seed_variation, post_processing_1, post_processing_2, post_processing_3): if model != "Random": model = model.split("(")[0].rstrip() - return self.process_images(p, nsfw, model, seed_variation) + post_processing = [] - def process_images(self, p, nsfw, model, seed_variation): + if post_processing_1 != "None": + post_processing.append(post_processing_1.split("(")[0].rstrip()) + + if post_processing_2 != "None": + post_processing.append(post_processing_2.split("(")[0].rstrip()) + + if post_processing_3 != "None": + post_processing.append(post_processing_3.split("(")[0].rstrip()) + + return self.process_images(p, nsfw, model, seed_variation, post_processing) + + def process_images(self, p, nsfw, model, seed_variation, post_processing): # Copyright (C) 2022 AUTOMATIC1111 stored_opts = {k: shared.opts.data[k] for k in p.override_settings.keys()} @@ -93,7 +138,7 @@ class Main(scripts.Script): for k, v in p.override_settings.items(): setattr(shared.opts, k, v) - res = self.process_images_inner(p, nsfw, model, seed_variation) + res = self.process_images_inner(p, nsfw, model, seed_variation, post_processing) finally: if p.override_settings_restore_afterwards: for k, v in stored_opts.items(): @@ -101,7 +146,7 @@ class Main(scripts.Script): return res - def process_images_inner(self, p, nsfw, model, seed_variation): + def process_images_inner(self, p, nsfw, model, seed_variation, post_processing): # Copyright (C) 2022 AUTOMATIC1111 if type(p.prompt) == list: @@ -178,7 +223,7 @@ class Main(scripts.Script): if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - x_samples_ddim = self.process_batch_horde(p, nsfw, model, seed_variation, prompts[0], negative_prompts[0], seeds[0]) + x_samples_ddim = self.process_batch_horde(p, nsfw, model, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0]) if x_samples_ddim is None: del x_samples_ddim @@ -250,7 +295,7 @@ class Main(scripts.Script): return res - def process_batch_horde(self, p, nsfw, model, seed_variation, prompt, negative_prompt, seed): + def process_batch_horde(self, p, nsfw, model, seed_variation, post_processing, prompt, negative_prompt, seed): payload = { "prompt": "{} ### {}".format(prompt, negative_prompt) if len(negative_prompt) > 0 else prompt, "params": { @@ -283,14 +328,6 @@ class Main(scripts.Script): #img2img/inpainting - post_processing = [] - - #upscale RealESRGAN_x4plus - - if p.restore_faces: - #CodeFormers - post_processing.append("GFPGAN") - if len(post_processing) > 0: payload["params"]["post_processing"] = post_processing diff --git a/style.css b/style.css new file mode 100644 index 0000000..6c2ce9d --- /dev/null +++ b/style.css @@ -0,0 +1,10 @@ +#update_model { + min-width: auto; + flex-grow: 0; + padding-left: 0.25em; + padding-right: 0.25em; +} + +#model_row { + gap: 0.5rem; +}