# Stable Horde for Web UI, a Stable Horde client for AUTOMATIC1111's Stable Diffusion Web UI
# Copyright (C) 2022 Natan Junges
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from modules import scripts, processing, shared, images, devices, ui, sd_models
import gradio
import requests
import time
import PIL.Image
import base64
import io
import os.path
import numpy
import itertools
import torch
import json
import contextlib

# User options (API endpoint, key, worker filters) are read from this file next to the script.
settings_file = os.path.join(scripts.basedir(), "settings.json")


class FakeModel:
    """Stand-in assigned to ``shared.sd_model`` while Web UI infotext helpers run.

    It carries only the Horde model name (via a ``CheckpointInfo``) and an empty hash,
    since no local checkpoint is involved when generating on the Horde.
    """

    sd_model_hash = ""

    def __init__(self, name):
        self.sd_checkpoint_info = sd_models.CheckpointInfo("", "", "", name)


@contextlib.contextmanager
def _temporary_sd_model(model):
    """Temporarily set ``shared.sd_model`` to *model*, restoring the original even on error."""
    old_model = shared.sd_model
    shared.sd_model = model
    try:
        yield
    finally:
        shared.sd_model = old_model


class StableHordeError(Exception):
    """Raised when the Stable Horde rejects or cannot complete a generation request."""


class Main(scripts.Script):
    """Web UI script that runs txt2img/img2img generations on the Stable Horde instead of locally."""

    TITLE = "Run on Stable Horde"
    # Web UI sampler name -> Horde sampler name. Karras variants map to the base sampler;
    # the Karras schedule is requested separately through the "karras" flag (see KARRAS).
    SAMPLERS = {
        "LMS": "k_lms",
        "LMS Karras": "k_lms",
        "Heun": "k_heun",
        "Euler": "k_euler",
        "Euler a": "k_euler_a",
        "DPM2": "k_dpm_2",
        "DPM2 Karras": "k_dpm_2",
        "DPM2 a": "k_dpm_2_a",
        "DPM2 a Karras": "k_dpm_2_a",
        "DPM fast": "k_dpm_fast",
        "DPM adaptive": "k_dpm_adaptive",
        "DPM++ 2S a": "k_dpmpp_2s_a",
        "DPM++ 2S a Karras": "k_dpmpp_2s_a",
        "DPM++ 2M": "k_dpmpp_2m",
        "DPM++ 2M Karras": "k_dpmpp_2m"
    }
    KARRAS = {"LMS Karras", "DPM2 Karras", "DPM2 a Karras", "DPM++ 2S a Karras", "DPM++ 2M Karras"}
    # Dropdown labels; the part before " (" is the name sent to the Horde.
    POST_PROCESSINGS = {"CodeFormers (Face restoration)", "GFPGAN (Face restoration)", "RealESRGAN_x4plus (Upscaling)"}
    # Seconds to wait for Horde responses (except the 1 s status polls).
    REQUEST_TIMEOUT = 60
    # Shorter timeout for the model list, which is fetched while the UI is being built.
    MODELS_TIMEOUT = 10

    def title(self):
        return self.TITLE

    def show(self, is_img2img):
        return True

    def load_settings(self):
        """Load connection options from ``settings.json``, or fall back to built-in defaults."""
        if os.path.exists(settings_file):
            with open(settings_file, encoding="utf8") as file:
                opts = json.load(file)

            self.api_endpoint = opts["api_endpoint"]
            self.api_key = opts["api_key"]
            self.censor_nsfw = opts["censor_nsfw"]
            self.trusted_workers = opts["trusted_workers"]
            self.workers = opts["workers"]
        else:
            self.api_endpoint = "https://stablehorde.net/api"
            self.api_key = "0000000000"
            self.censor_nsfw = True
            self.trusted_workers = True
            self.workers = []

    def load_models(self):
        """Refresh ``self.models`` with "name (worker count)" labels, most-served first.

        "Random" is always the first entry. If the Horde cannot be reached or replies with
        something unusable, only "Random" is offered instead of breaking the UI.
        """
        self.load_settings()

        try:
            response = requests.get("{}/v2/status/models".format(self.api_endpoint), timeout=self.MODELS_TIMEOUT)
            response.raise_for_status()
            data = response.json()
            data.sort(key=lambda m: (-m["count"], m["name"]))
            models = ["{} ({})".format(m["name"], m["count"]) for m in data]
        except (requests.RequestException, ValueError, KeyError, TypeError, AttributeError):
            # Network failure, HTTP error, non-JSON body or unexpected schema.
            models = []

        models.insert(0, "Random")
        self.models = models

    def ui(self, is_img2img):
        """Build the script's Gradio controls and return them in ``run()`` argument order."""
        with gradio.Box():
            with gradio.Row(elem_id="horde_model_row"):
                self.load_models()
                model = gradio.Dropdown(self.models, value="Random", label="Model")
                model.style(container=False)
                update = gradio.Button(ui.refresh_symbol, elem_id="horde_update_model")

        with gradio.Box():
            with gradio.Row():
                nsfw = gradio.Checkbox(False, label="NSFW")
                shared_laion = gradio.Checkbox(False, label="Share with LAION", interactive=not is_img2img)
                seed_variation = gradio.Slider(minimum=1, maximum=1000, value=1, step=1, label="Seed variation", elem_id="horde_seed_variation")

        with gradio.Box():
            with gradio.Row():
                post_processing_1 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #1")
                post_processing_1.style(container=False)
                post_processing_2 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #2", interactive=False)
                post_processing_2.style(container=False)
                post_processing_3 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #3", interactive=False)
                post_processing_3.style(container=False)

        def update_click():
            self.load_models()
            return gradio.update(choices=self.models, value="Random")

        # Each later post-processing slot is enabled only once the previous one is set,
        # and never offers a step already chosen in an earlier slot.
        def post_processing_1_change(value_1):
            return (gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=value_1 != "None"),
                    gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=False))

        def post_processing_2_change(value_1, value_2):
            return gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1, value_2}), value="None", interactive=value_2 != "None")

        update.click(fn=update_click, outputs=model)
        post_processing_1.change(fn=post_processing_1_change, inputs=post_processing_1, outputs=[post_processing_2, post_processing_3])
        post_processing_2.change(fn=post_processing_2_change, inputs=[post_processing_1, post_processing_2], outputs=post_processing_3)

        def model_infotext(d):
            # Map a bare model name from pasted infotext back to its "name (count)" label.
            if "Model" in d and d["Model"] != "Random":
                prefix = "{} (".format(d["Model"])
                return next((s for s in self.models if s.startswith(prefix)), "Random")
            return "Random"

        def post_processing_n_infotext(n):
            key = "Post processing {}".format(n)

            def post_processing_infotext(d):
                if key in d:
                    prefix = "{} (".format(d[key])
                    return next((s for s in self.POST_PROCESSINGS if s.startswith(prefix)), "None")
                return "None"

            return post_processing_infotext

        self.infotext_fields = [
            (model, model_infotext),
            (nsfw, "NSFW"),
            (shared_laion, "Share with LAION"),
            (seed_variation, "Seed variation"),
            (post_processing_1, post_processing_n_infotext(1)),
            (post_processing_2, post_processing_n_infotext(2)),
            (post_processing_3, post_processing_n_infotext(3))
        ]

        return [model, nsfw, shared_laion, seed_variation, post_processing_1, post_processing_2, post_processing_3]

    def run(self, p, model, nsfw, shared_laion, seed_variation, post_processing_1, post_processing_2, post_processing_3):
        """Translate the UI values into Horde names and run the generation."""
        # Labels have the form "name (extra)"; strip only the trailing " (...)" so names
        # that themselves contain parentheses survive intact.
        if model != "Random":
            model = model.rsplit(" (", 1)[0].rstrip()

        post_processing = [value.rsplit(" (", 1)[0].rstrip()
                           for value in (post_processing_1, post_processing_2, post_processing_3)
                           if value != "None"]

        return self.process_images(p, model, nsfw, shared_laion, int(seed_variation), post_processing)

    def process_images(self, p, model, nsfw, shared_laion, seed_variation, post_processing):
        """Apply ``p.override_settings`` around the generation and record the script's infotext params."""
        # Copyright (C) 2022 AUTOMATIC1111
        # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/d7aec59c4eb02f723b3d55c6f927a42e97acd679/modules/processing.py#L463-L490
        stored_opts = {k: shared.opts.data[k] for k in p.override_settings.keys()}

        try:
            for k, v in p.override_settings.items():
                setattr(shared.opts, k, v)

            p.extra_generation_params = {
                "Model": model,
                "NSFW": nsfw,
                "Share with LAION": shared_laion,
                "Seed variation": seed_variation,
                "Post processing 1": (post_processing[0] if len(post_processing) >= 1 else None),
                "Post processing 2": (post_processing[1] if len(post_processing) >= 2 else None),
                "Post processing 3": (post_processing[2] if len(post_processing) >= 3 else None)
            }

            res = self.process_images_inner(p, model, nsfw, shared_laion, seed_variation, post_processing)
        finally:
            if p.override_settings_restore_afterwards:
                for k, v in stored_opts.items():
                    setattr(shared.opts, k, v)

        return res

    def process_images_inner(self, p, model, nsfw, shared_laion, seed_variation, post_processing):
        """Horde version of the Web UI's processing loop.

        Builds per-image prompts and seeds (consecutive seeds in a batch differ by
        *seed_variation*), runs each batch on the Horde, then applies color correction,
        overlays, saving, infotext and grid handling like the local pipeline.
        ``shared.sd_model`` is temporarily replaced by a :class:`FakeModel` whenever
        infotext is generated, so the Horde model name is recorded.
        """
        # Copyright (C) 2022 AUTOMATIC1111
        # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/d7aec59c4eb02f723b3d55c6f927a42e97acd679/modules/processing.py#L493-L687
        fake_model = FakeModel(model)

        if isinstance(p.prompt, list):
            assert(len(p.prompt) > 0)
        else:
            assert p.prompt is not None

        devices.torch_gc()
        seed = processing.get_fixed_seed(p.seed)
        # Subseeds and seed resizing are not supported on the Horde.
        p.subseed = -1
        p.subseed_strength = 0
        p.seed_resize_from_h = 0
        p.seed_resize_from_w = 0

        if isinstance(p.prompt, list):
            p.all_prompts = list(itertools.chain.from_iterable((p.batch_size * [shared.prompt_styles.apply_styles_to_prompt(p.prompt[x * p.batch_size], p.styles)] for x in range(p.n_iter))))
        else:
            p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]

        if isinstance(p.negative_prompt, list):
            p.all_negative_prompts = list(itertools.chain.from_iterable((p.batch_size * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt[x * p.batch_size], p.styles)] for x in range(p.n_iter))))
        else:
            p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]

        if isinstance(seed, list):
            p.all_seeds = list(itertools.chain.from_iterable(([seed[x * p.batch_size] + y * seed_variation for y in range(p.batch_size)] for x in range(p.n_iter))))
        else:
            p.all_seeds = [int(seed) + x * seed_variation for x in range(len(p.all_prompts))]

        p.all_subseeds = [-1 for _ in range(len(p.all_prompts))]

        def infotext(iteration=0, position_in_batch=0):
            with _temporary_sd_model(fake_model):
                return processing.create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, {}, iteration, position_in_batch)

        with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
            with _temporary_sd_model(fake_model):
                processed = processing.Processed(p, [], p.seed, "")
                file.write(processed.infotext(p, 0))

        if p.scripts is not None:
            p.scripts.process(p)

        infotexts = []
        output_images = []

        with torch.no_grad():
            if shared.state.job_count == -1:
                shared.state.job_count = p.n_iter

            for n in range(p.n_iter):
                p.iteration = n

                if shared.state.skipped:
                    shared.state.skipped = False

                if shared.state.interrupted:
                    break

                prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
                negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
                seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
                subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

                if len(prompts) == 0:
                    break

                if p.scripts is not None:
                    p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)

                if p.n_iter > 1:
                    shared.state.job = f"Batch {n+1} out of {p.n_iter}"

                # One Horde request per batch: the Horde derives the batch's other seeds
                # from the first one plus seed_variation.
                x_samples_ddim, models = self.process_batch_horde(p, model, nsfw, shared_laion, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])

                if x_samples_ddim is None:
                    break

                if p.scripts is not None:
                    p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

                for i, x_sample in enumerate(x_samples_ddim):
                    x_sample = 255. * numpy.moveaxis(x_sample.cpu().numpy(), 0, 2)
                    x_sample = x_sample.astype(numpy.uint8)
                    image = PIL.Image.fromarray(x_sample)
                    # Record the model the Horde actually used for this image.
                    p.extra_generation_params["Model"] = models[i]

                    if p.color_corrections is not None and i < len(p.color_corrections):
                        # Honour "Always save all generated images" (samples_save); the
                        # previous `shared.opts.save` was the Options.save method and
                        # therefore always truthy.
                        if shared.opts.samples_save and not p.do_not_save_samples and shared.opts.save_images_before_color_correction:
                            image_without_cc = processing.apply_overlay(image, p.paste_to, i, p.overlay_images)
                            images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], shared.opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
                        image = processing.apply_color_correction(p.color_corrections[i], image)

                    image = processing.apply_overlay(image, p.paste_to, i, p.overlay_images)

                    if shared.opts.samples_save and not p.do_not_save_samples:
                        images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], shared.opts.samples_format, info=infotext(n, i), p=p)

                    text = infotext(n, i)
                    infotexts.append(text)
                    if shared.opts.enable_pnginfo:
                        image.info["parameters"] = text
                    output_images.append(image)

                del x_samples_ddim
                devices.torch_gc()
                p.extra_generation_params["Model"] = model
                shared.state.job_no += 1
                shared.state.sampling_step = 0
                shared.state.current_image_sampling_step = 0

            p.color_corrections = None

            index_of_first_image = 0
            unwanted_grid_because_of_img_count = len(output_images) < 2 and shared.opts.grid_only_if_multiple
            # A grid needs at least one image (e.g. interrupted before the first batch).
            if output_images and (shared.opts.return_grid or shared.opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
                grid = images.image_grid(output_images, p.batch_size)

                if shared.opts.return_grid:
                    text = infotext()
                    infotexts.insert(0, text)
                    if shared.opts.enable_pnginfo:
                        grid.info["parameters"] = text
                    output_images.insert(0, grid)
                    index_of_first_image = 1

                if shared.opts.grid_save:
                    images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], shared.opts.grid_format, info=infotext(), short_filename=not shared.opts.grid_extended_filename, p=p, grid=True)

        devices.torch_gc()

        with _temporary_sd_model(fake_model):
            res = processing.Processed(p, output_images, p.all_seeds[0], infotext(), subseed=-1, index_of_first_image=index_of_first_image, infotexts=infotexts)

        if p.scripts is not None:
            p.scripts.postprocess(p, res)

        return res

    def process_batch_horde(self, p, model, nsfw, shared_laion, seed_variation, post_processing, prompt, negative_prompt, seed):
        """Run one batch on the Stable Horde.

        Builds the generation payload from *p* and the script options, submits it
        asynchronously, then polls the request's status about once per second while
        mirroring the estimated progress in ``shared.state``. Skipping or interrupting
        while polling cancels the request and returns whatever was already finished.

        Returns:
            ``(images, models)``: ``images`` is a stacked, channel-first float tensor on
            ``shared.device`` (pixel values divided by 255) and ``models`` names the model
            that produced each image; ``(None, None)`` when the job was skipped or
            interrupted before submission or no image came back.

        Raises:
            StableHordeError: if the Horde rejects the request, the request faults, or no
                currently available worker can complete it.
        """
        payload = {
            # The Horde takes the negative prompt after a " ### " separator.
            "prompt": "{} ### {}".format(prompt, negative_prompt) if len(negative_prompt) > 0 else prompt,
            "params": {
                "sampler_name": self.SAMPLERS.get(p.sampler_name, "k_euler_a"),
                "cfg_scale": p.cfg_scale,
                "denoising_strength": p.denoising_strength if p.denoising_strength is not None else 0,
                "seed": str(seed),
                "height": p.height,
                "width": p.width,
                "seed_variation": seed_variation,
                "karras": p.sampler_name in self.KARRAS,
                "steps": p.steps,
                "n": p.batch_size
            }
        }

        self.load_settings()

        if nsfw:
            payload["nsfw"] = True
        elif self.censor_nsfw:
            payload["censor_nsfw"] = True

        if not self.trusted_workers:
            payload["trusted_workers"] = False

        if len(self.workers) > 0:
            payload["workers"] = self.workers

        if model != "Random":
            payload["models"] = [model]

        if self.is_img2img:
            buffer = io.BytesIO()
            p.init_images[0].save(buffer, format="WEBP")
            payload["source_image"] = base64.b64encode(buffer.getvalue()).decode()

            if p.image_mask is None:
                payload["source_processing"] = "img2img"
            else:
                payload["source_processing"] = "inpainting"
                buffer = io.BytesIO()
                p.image_mask.save(buffer, format="WEBP")
                payload["source_mask"] = base64.b64encode(buffer.getvalue()).decode()

        # Sharing with LAION is only sent for txt2img and when a non-default API key is set.
        if not self.is_img2img and self.api_key != "0000000000" and shared_laion:
            payload["shared"] = True

        if len(post_processing) > 0:
            payload["params"]["post_processing"] = post_processing

        if shared.state.skipped or shared.state.interrupted:
            return (None, None)

        with requests.Session() as session:
            response = session.post("{}/v2/generate/async".format(self.api_endpoint), headers={"apikey": self.api_key}, json=payload, timeout=self.REQUEST_TIMEOUT)
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if response.status_code != 202:
                raise StableHordeError(self._error_message(response, 202))

            request_id = response.json()["id"]
            check_url = "{}/v2/generate/check/{}".format(self.api_endpoint, request_id)
            status_url = "{}/v2/generate/status/{}".format(self.api_endpoint, request_id)
            shared.state.sampling_steps = 0
            start = time.time()

            while True:
                if shared.state.skipped or shared.state.interrupted:
                    return self.cancel_process_batch_horde(request_id)

                try:
                    status = session.get(check_url, timeout=1).json()
                    # Progress is expressed in seconds: elapsed vs. elapsed + estimated wait.
                    elapsed = int(time.time() - start)
                    shared.state.sampling_steps = elapsed + status["wait_time"]
                    shared.state.sampling_step = elapsed

                    if status["done"]:
                        shared.state.sampling_steps = shared.state.sampling_step
                        result = session.get(status_url, timeout=self.REQUEST_TIMEOUT).json()
                        return self._decode_generations(result["generations"])
                except requests.Timeout:
                    # Slow poll: wait and try again.
                    time.sleep(1)
                    continue

                if status["faulted"]:
                    raise StableHordeError("This request caused an internal server error and could not be completed.")

                if not status["is_possible"]:
                    raise StableHordeError("This request will not be able to be completed with the pool of workers currently available.")

                time.sleep(1)

    def cancel_process_batch_horde(self, id):
        """Cancel Horde request *id* and return the images it had already finished.

        Returns the same ``(images, models)`` pair as :meth:`process_batch_horde`, or
        ``(None, None)`` when nothing had been generated yet.
        """
        # The parameter keeps its original name `id` (shadowing the builtin) for compatibility.
        response = requests.delete("{}/v2/generate/status/{}".format(self.api_endpoint, id), timeout=self.REQUEST_TIMEOUT)
        return self._decode_generations(response.json()["generations"])

    @staticmethod
    def _decode_generations(generations):
        """Decode a Horde ``generations`` list into ``(tensor, model names)``, or ``(None, None)`` if empty."""
        if not generations:
            return (None, None)

        models = [generation["model"] for generation in generations]
        tensors = []
        for generation in generations:
            picture = PIL.Image.open(io.BytesIO(base64.b64decode(generation["img"])))
            # HWC pixel array -> CHW float array scaled by 1/255.
            array = numpy.moveaxis(numpy.array(picture).astype(numpy.float32) / 255.0, 2, 0)
            tensors.append(torch.from_numpy(array))

        return (torch.stack(tensors).to(shared.device), models)

    @staticmethod
    def _error_message(response, expected_status):
        """Return the Horde's error ``message`` from *response*, or a status-code description."""
        try:
            return response.json()["message"]
        except (ValueError, KeyError, TypeError):
            return "Status Code: {} (expected {})".format(response.status_code, expected_status)