Improve support of post processors, improve GUI

main
natanjunges 2022-12-30 00:19:35 -03:00
parent 2fe3eb8d65
commit 1102736caa
2 changed files with 68 additions and 21 deletions

View File

@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from modules import scripts, processing, shared, images, devices
from modules import scripts, processing, shared, images, devices, ui
import gradio
import requests
import time
@ -49,6 +49,7 @@ class Main(scripts.Script):
"DPM++ 2M Karras": "k_dpmpp_2m"
}
KARRAS = {"LMS Karras", "DPM2 Karras", "DPM2 a Karras", "DPM++ 2S a Karras", "DPM++ 2M Karras"}
POST_PROCESSINGS = {"CodeFormers (Restore faces)", "GFPGAN (Restore faces)", "RealESRGAN_x4plus (Upscale)"}
#settings tab
api_endpoint = "https://stablehorde.net/api"
api_key = "0000000000"
@ -62,8 +63,7 @@ class Main(scripts.Script):
def show(self, is_img2img):
return True
def ui(self, is_img2img):
nsfw = gradio.Checkbox(False, label="NSFW")
def load_models(self):
models = requests.get("{}/v2/status/models".format(self.api_endpoint))
if models.status_code == 200:
@ -74,17 +74,62 @@ class Main(scripts.Script):
models = []
models.insert(0, "Random")
model = gradio.Dropdown(models, value="Random", label="Model")
seed_variation = gradio.Number(value=1, label="Seed variation", precision=0)
return [nsfw, model, seed_variation]
return models
def run(self, p, nsfw, model, seed_variation):
def ui(self, is_img2img):
nsfw = gradio.Checkbox(False, label="NSFW")
with gradio.Box():
with gradio.Row(elem_id="model_row"):
model = gradio.Dropdown(self.load_models(), value="Random", label="Model")
model.style(container=False)
update = gradio.Button(ui.refresh_symbol, elem_id="update_model")
with gradio.Box():
seed_variation = gradio.Number(value=1, label="Seed variation", precision=0)
seed_variation.style(container=False)
with gradio.Box():
with gradio.Row(elem_id="post_processing_row"):
post_processing_1 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #1")
post_processing_1.style(container=False)
post_processing_2 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #2", interactive=False)
post_processing_2.style(container=False)
post_processing_3 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #3", interactive=False)
post_processing_3.style(container=False)
def update_click():
return gradio.update(choices=self.load_models(), value="Random")
def post_processing_1_change(value_1):
return (gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=value_1 != "None"), gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=value_1 != "None"))
def post_processing_2_change(value_1, value_2):
return gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1, value_2}), value="None", interactive=value_2 != "None")
update.click(fn=update_click, outputs=model)
post_processing_1.change(fn=post_processing_1_change, inputs=post_processing_1, outputs=[post_processing_2, post_processing_3])
post_processing_2.change(fn=post_processing_2_change, inputs=[post_processing_1, post_processing_2], outputs=post_processing_3)
return [nsfw, model, seed_variation, post_processing_1, post_processing_2, post_processing_3]
def run(self, p, nsfw, model, seed_variation, post_processing_1, post_processing_2, post_processing_3):
if model != "Random":
model = model.split("(")[0].rstrip()
return self.process_images(p, nsfw, model, seed_variation)
post_processing = []
def process_images(self, p, nsfw, model, seed_variation):
if post_processing_1 != "None":
post_processing.append(post_processing_1.split("(")[0].rstrip())
if post_processing_2 != "None":
post_processing.append(post_processing_2.split("(")[0].rstrip())
if post_processing_3 != "None":
post_processing.append(post_processing_3.split("(")[0].rstrip())
return self.process_images(p, nsfw, model, seed_variation, post_processing)
def process_images(self, p, nsfw, model, seed_variation, post_processing):
# Copyright (C) 2022 AUTOMATIC1111
stored_opts = {k: shared.opts.data[k] for k in p.override_settings.keys()}
@ -93,7 +138,7 @@ class Main(scripts.Script):
for k, v in p.override_settings.items():
setattr(shared.opts, k, v)
res = self.process_images_inner(p, nsfw, model, seed_variation)
res = self.process_images_inner(p, nsfw, model, seed_variation, post_processing)
finally:
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
@ -101,7 +146,7 @@ class Main(scripts.Script):
return res
def process_images_inner(self, p, nsfw, model, seed_variation):
def process_images_inner(self, p, nsfw, model, seed_variation, post_processing):
# Copyright (C) 2022 AUTOMATIC1111
if type(p.prompt) == list:
@ -178,7 +223,7 @@ class Main(scripts.Script):
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
x_samples_ddim = self.process_batch_horde(p, nsfw, model, seed_variation, prompts[0], negative_prompts[0], seeds[0])
x_samples_ddim = self.process_batch_horde(p, nsfw, model, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])
if x_samples_ddim is None:
del x_samples_ddim
@ -250,7 +295,7 @@ class Main(scripts.Script):
return res
def process_batch_horde(self, p, nsfw, model, seed_variation, prompt, negative_prompt, seed):
def process_batch_horde(self, p, nsfw, model, seed_variation, post_processing, prompt, negative_prompt, seed):
payload = {
"prompt": "{} ### {}".format(prompt, negative_prompt) if len(negative_prompt) > 0 else prompt,
"params": {
@ -283,14 +328,6 @@ class Main(scripts.Script):
#img2img/inpainting
post_processing = []
#upscale RealESRGAN_x4plus
if p.restore_faces:
#CodeFormers
post_processing.append("GFPGAN")
if len(post_processing) > 0:
payload["params"]["post_processing"] = post_processing

10
style.css Normal file
View File

@ -0,0 +1,10 @@
#update_model {
min-width: auto;
flex-grow: 0;
padding-left: 0.25em;
padding-right: 0.25em;
}
#model_row {
gap: 0.5rem;
}