stable-diffusion-webui-stab.../scripts/main.py

468 lines
20 KiB
Python

# Stable Horde for Web UI, a Stable Horde client for AUTOMATIC1111's Stable Diffusion Web UI
# Copyright (C) 2022 Natan Junges <natanajunges@gmail.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from modules import scripts, processing, shared, images, devices, ui
import gradio
import requests
import time
import PIL.Image
import base64
import io
import os.path
import numpy
import itertools
import torch
import json
settings_file = os.path.join(scripts.basedir(), "settings.json")
class FakeModel:
sd_model_hash=""
class StableHordeError(Exception):
pass
class Main(scripts.Script):
TITLE = "Run on Stable Horde"
SAMPLERS = {
"LMS": "k_lms",
"LMS Karras": "k_lms",
"Heun": "k_heun",
"Euler": "k_euler",
"Euler a": "k_euler_a",
"DPM2": "k_dpm_2",
"DPM2 Karras": "k_dpm_2",
"DPM2 a": "k_dpm_2_a",
"DPM2 a Karras": "k_dpm_2_a",
"DPM fast": "k_dpm_fast",
"DPM adaptive": "k_dpm_adaptive",
"DPM++ 2S a": "k_dpmpp_2s_a",
"DPM++ 2S a Karras": "k_dpmpp_2s_a",
"DPM++ 2M": "k_dpmpp_2m",
"DPM++ 2M Karras": "k_dpmpp_2m"
}
KARRAS = {"LMS Karras", "DPM2 Karras", "DPM2 a Karras", "DPM++ 2S a Karras", "DPM++ 2M Karras"}
POST_PROCESSINGS = {"CodeFormers (Face restoration)", "GFPGAN (Face restoration)", "RealESRGAN_x4plus (Upscaling)"}
def title(self):
return self.TITLE
def show(self, is_img2img):
return True
def load_settings(self):
if os.path.exists(settings_file):
with open(settings_file) as file:
opts = json.load(file)
self.api_endpoint = opts["api_endpoint"]
self.api_key = opts["api_key"]
self.censor_nsfw = opts["censor_nsfw"]
self.trusted_workers = opts["trusted_workers"]
self.workers = opts["workers"]
else:
self.api_endpoint = "https://stablehorde.net/api"
self.api_key = "0000000000"
self.censor_nsfw = True
self.trusted_workers = True
self.workers = []
def load_models(self):
self.load_settings()
try:
models = requests.get("{}/v2/status/models".format(self.api_endpoint))
models = models.json()
models.sort(key=lambda m: (-m["count"], m["name"]))
models = ["{} ({})".format(m["name"], m["count"]) for m in models]
except requests.ConnectionError:
models = []
models.insert(0, "Random")
self.models = models
def ui(self, is_img2img):
with gradio.Box():
with gradio.Row(elem_id="horde_model_row"):
self.load_models()
model = gradio.Dropdown(self.models, value="Random", label="Model")
model.style(container=False)
update = gradio.Button(ui.refresh_symbol, elem_id="horde_update_model")
with gradio.Box():
with gradio.Row():
nsfw = gradio.Checkbox(False, label="NSFW")
shared_laion = gradio.Checkbox(False, label="Share with LAION", interactive=not is_img2img)
seed_variation = gradio.Slider(minimum=1, maximum=1000, value=1, step=1, label="Seed variation", elem_id="horde_seed_variation")
with gradio.Box():
with gradio.Row():
post_processing_1 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #1")
post_processing_1.style(container=False)
post_processing_2 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #2", interactive=False)
post_processing_2.style(container=False)
post_processing_3 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #3", interactive=False)
post_processing_3.style(container=False)
def update_click():
self.load_models()
return gradio.update(choices=self.models, value="Random")
def post_processing_1_change(value_1):
return (gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=value_1 != "None"), gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=False))
def post_processing_2_change(value_1, value_2):
return gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1, value_2}), value="None", interactive=value_2 != "None")
update.click(fn=update_click, outputs=model)
post_processing_1.change(fn=post_processing_1_change, inputs=post_processing_1, outputs=[post_processing_2, post_processing_3])
post_processing_2.change(fn=post_processing_2_change, inputs=[post_processing_1, post_processing_2], outputs=post_processing_3)
def model_infotext(d):
if "Model" in d and d["Model"] != "Random":
try:
return next(filter(lambda s: s.startswith("{} (".format(d["Model"])), self.models))
except StopIteration:
pass
return "Random"
def post_processing_n_infotext(n):
def post_processing_infotext(d):
if "Post processing {}".format(n) in d:
try:
return next(filter(lambda s: s.startswith("{} (".format(d["Post processing {}".format(n)])), self.POST_PROCESSINGS))
except StopIteration:
pass
return "None"
return post_processing_infotext
self.infotext_fields = [
(model, model_infotext),
(nsfw, "NSFW"),
(shared_laion, "Share with LAION"),
(seed_variation, "Seed variation"),
(post_processing_1, post_processing_n_infotext(1)),
(post_processing_2, post_processing_n_infotext(2)),
(post_processing_3, post_processing_n_infotext(3))
]
return [model, nsfw, shared_laion, seed_variation, post_processing_1, post_processing_2, post_processing_3]
def run(self, p, model, nsfw, shared_laion, seed_variation, post_processing_1, post_processing_2, post_processing_3):
if model != "Random":
model = model.split("(")[0].rstrip()
post_processing = []
if post_processing_1 != "None":
post_processing.append(post_processing_1.split("(")[0].rstrip())
if post_processing_2 != "None":
post_processing.append(post_processing_2.split("(")[0].rstrip())
if post_processing_3 != "None":
post_processing.append(post_processing_3.split("(")[0].rstrip())
return self.process_images(p, model, nsfw, shared_laion, int(seed_variation), post_processing)
def process_images(self, p, model, nsfw, shared_laion, seed_variation, post_processing):
# Copyright (C) 2022 AUTOMATIC1111
stored_opts = {k: shared.opts.data[k] for k in p.override_settings.keys()}
try:
for k, v in p.override_settings.items():
setattr(shared.opts, k, v)
p.extra_generation_params = {
"Model": model,
"NSFW": nsfw,
"Share with LAION": shared_laion,
"Seed variation": seed_variation,
"Post processing 1": (post_processing[0] if len(post_processing) >= 1 else None),
"Post processing 2": (post_processing[1] if len(post_processing) >= 2 else None),
"Post processing 3": (post_processing[2] if len(post_processing) >= 3 else None)
}
res = self.process_images_inner(p, model, nsfw, shared_laion, seed_variation, post_processing)
finally:
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
setattr(shared.opts, k, v)
return res
def process_images_inner(self, p, model, nsfw, shared_laion, seed_variation, post_processing):
# Copyright (C) 2022 AUTOMATIC1111
if type(p.prompt) == list:
assert(len(p.prompt) > 0)
else:
assert p.prompt is not None
devices.torch_gc()
seed = processing.get_fixed_seed(p.seed)
p.subseed = -1
p.subseed_strength = 0
p.seed_resize_from_h = 0
p.seed_resize_from_w = 0
if type(p.prompt) == list:
p.all_prompts = list(itertools.chain.from_iterable((p.batch_size * [shared.prompt_styles.apply_styles_to_prompt(p.prompt[x * p.batch_size], p.styles)] for x in range(p.n_iter))))
else:
p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
if type(p.negative_prompt) == list:
p.all_negative_prompts = list(itertools.chain.from_iterable((p.batch_size * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt[x * p.batch_size], p.styles)] for x in range(p.n_iter))))
else:
p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
if type(seed) == list:
p.all_seeds = list(itertools.chain.from_iterable(([seed[x * p.batch_size] + y * seed_variation for y in range(p.batch_size)] for x in range(p.n_iter))))
else:
p.all_seeds = [int(seed) + x * seed_variation for x in range(len(p.all_prompts))]
p.all_subseeds = [-1 for _ in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0):
old_model = shared.sd_model
shared.sd_model = FakeModel
ret = processing.create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, {}, iteration, position_in_batch)
shared.sd_model = old_model
return ret
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
old_model = shared.sd_model
shared.sd_model = FakeModel
processed = processing.Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
shared.sd_model = old_model
if p.scripts is not None:
p.scripts.process(p)
infotexts = []
output_images = []
with torch.no_grad():
if shared.state.job_count == -1:
shared.state.job_count = p.n_iter
for n in range(p.n_iter):
p.iteration = n
if shared.state.skipped:
shared.state.skipped = False
if shared.state.interrupted:
break
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if len(prompts) == 0:
break
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
x_samples_ddim, models = self.process_batch_horde(p, model, nsfw, shared_laion, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])
if x_samples_ddim is None:
break
if p.scripts is not None:
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * numpy.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(numpy.uint8)
image = PIL.Image.fromarray(x_sample)
p.extra_generation_params["Model"] = models[i]
if p.color_corrections is not None and i < len(p.color_corrections):
if shared.opts.save and not p.do_not_save_samples and shared.opts.save_images_before_color_correction:
image_without_cc = processing.apply_overlay(image, p.paste_to, i, p.overlay_images)
images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], shared.opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
image = processing.apply_color_correction(p.color_corrections[i], image)
image = processing.apply_overlay(image, p.paste_to, i, p.overlay_images)
if shared.opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], shared.opts.samples_format, info=infotext(n, i), p=p)
text = infotext(n, i)
infotexts.append(text)
if shared.opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)
del x_samples_ddim
devices.torch_gc()
p.extra_generation_params["Model"] = model
shared.state.job_no += 1
shared.state.sampling_step = 0
shared.state.current_image_sampling_step = 0
p.color_corrections = None
index_of_first_image = 0
unwanted_grid_because_of_img_count = len(output_images) < 2 and shared.opts.grid_only_if_multiple
if (shared.opts.return_grid or shared.opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
grid = images.image_grid(output_images, p.batch_size)
if shared.opts.return_grid:
text = infotext()
infotexts.insert(0, text)
if shared.opts.enable_pnginfo:
grid.info["parameters"] = text
output_images.insert(0, grid)
index_of_first_image = 1
if shared.opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], shared.opts.grid_format, info=infotext(), short_filename=not shared.opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
old_model = shared.sd_model
shared.sd_model = FakeModel
res = processing.Processed(p, output_images, p.all_seeds[0], infotext(), subseed=-1, index_of_first_image=index_of_first_image, infotexts=infotexts)
shared.sd_model = old_model
if p.scripts is not None:
p.scripts.postprocess(p, res)
return res
def process_batch_horde(self, p, model, nsfw, shared_laion, seed_variation, post_processing, prompt, negative_prompt, seed):
payload = {
"prompt": "{} ### {}".format(prompt, negative_prompt) if len(negative_prompt) > 0 else prompt,
"params": {
"sampler_name": self.SAMPLERS.get(p.sampler_name, "k_euler_a"),
"cfg_scale": p.cfg_scale,
"denoising_strength": p.denoising_strength if p.denoising_strength is not None else 0,
"seed": str(seed),
"height": p.height,
"width": p.width,
"seed_variation": seed_variation,
"karras": p.sampler_name in self.KARRAS,
"steps": p.steps,
"n": p.batch_size
}
}
self.load_settings()
if nsfw:
payload["nsfw"] = True
elif self.censor_nsfw:
payload["censor_nsfw"] = True
if not self.trusted_workers:
payload["trusted_workers"] = False
if len(self.workers) > 0:
payload["workers"] = self.workers
if model != "Random":
payload["models"] = [model]
if self.is_img2img:
buffer = io.BytesIO()
p.init_images[0].save(buffer, format="WEBP")
payload["source_image"] = base64.b64encode(buffer.getvalue()).decode()
if p.image_mask is None:
payload["source_processing"] = "img2img"
else:
payload["source_processing"] = "inpainting"
buffer = io.BytesIO()
p.image_mask.save(buffer, format="WEBP")
payload["source_mask"] = base64.b64encode(buffer.getvalue()).decode()
if not self.is_img2img and self.api_key != "0000000000" and shared_laion:
payload["shared"] = True
if len(post_processing) > 0:
payload["params"]["post_processing"] = post_processing
if shared.state.skipped or shared.state.interrupted:
return (None, None)
try:
id = requests.post("{}/v2/generate/async".format(self.api_endpoint), headers={"apikey": self.api_key}, json=payload)
assert id.status_code == 202, "Status Code: {} (expected {})".format(id.status_code, 202)
id = id.json()
id = id["id"]
shared.state.sampling_steps = p.batch_size
while True:
if shared.state.skipped or shared.state.interrupted:
return self.cancel_process_batch_horde(id)
try:
status = requests.get("{}/v2/generate/check/{}".format(self.api_endpoint, id), timeout=1)
status = status.json()
shared.state.sampling_step = status["finished"]
if status["done"]:
shared.state.sampling_step = shared.state.sampling_steps
images = requests.get("{}/v2/generate/status/{}".format(self.api_endpoint, id))
images = images.json()
images = images["generations"]
models = [image["model"] for image in images]
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
images = [torch.from_numpy(image) for image in images]
images = torch.stack(images).to(shared.device)
return (images, models)
elif status["faulted"]:
raise StableHordeError("This request caused an internal server error and could not be completed.")
elif not status["is_possible"]:
raise StableHordeError("This request will not be able to be completed with the pool of workers currently available.")
else:
time.sleep(1)
except requests.Timeout:
time.sleep(1)
except AssertionError as e:
id = id.json()
raise StableHordeError(id["message"])
def cancel_process_batch_horde(self, id):
images = requests.delete("{}/v2/generate/status/{}".format(self.api_endpoint, id), timeout=60)
images = images.json()
images = images["generations"]
models = [image["model"] for image in images]
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
images = [torch.from_numpy(image) for image in images]
if len(images) > 0:
images = torch.stack(images).to(shared.device)
return (images, models)
else:
return (None, None)