Improve support of post processors, improve GUI
parent
2fe3eb8d65
commit
1102736caa
|
|
@ -14,7 +14,7 @@
|
|||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
from modules import scripts, processing, shared, images, devices
|
||||
from modules import scripts, processing, shared, images, devices, ui
|
||||
import gradio
|
||||
import requests
|
||||
import time
|
||||
|
|
@ -49,6 +49,7 @@ class Main(scripts.Script):
|
|||
"DPM++ 2M Karras": "k_dpmpp_2m"
|
||||
}
|
||||
KARRAS = {"LMS Karras", "DPM2 Karras", "DPM2 a Karras", "DPM++ 2S a Karras", "DPM++ 2M Karras"}
|
||||
POST_PROCESSINGS = {"CodeFormers (Restore faces)", "GFPGAN (Restore faces)", "RealESRGAN_x4plus (Upscale)"}
|
||||
#settings tab
|
||||
api_endpoint = "https://stablehorde.net/api"
|
||||
api_key = "0000000000"
|
||||
|
|
@ -62,8 +63,7 @@ class Main(scripts.Script):
|
|||
def show(self, is_img2img):
|
||||
return True
|
||||
|
||||
def ui(self, is_img2img):
|
||||
nsfw = gradio.Checkbox(False, label="NSFW")
|
||||
def load_models(self):
|
||||
models = requests.get("{}/v2/status/models".format(self.api_endpoint))
|
||||
|
||||
if models.status_code == 200:
|
||||
|
|
@ -74,17 +74,62 @@ class Main(scripts.Script):
|
|||
models = []
|
||||
|
||||
models.insert(0, "Random")
|
||||
model = gradio.Dropdown(models, value="Random", label="Model")
|
||||
seed_variation = gradio.Number(value=1, label="Seed variation", precision=0)
|
||||
return [nsfw, model, seed_variation]
|
||||
return models
|
||||
|
||||
def run(self, p, nsfw, model, seed_variation):
|
||||
def ui(self, is_img2img):
|
||||
nsfw = gradio.Checkbox(False, label="NSFW")
|
||||
|
||||
with gradio.Box():
|
||||
with gradio.Row(elem_id="model_row"):
|
||||
model = gradio.Dropdown(self.load_models(), value="Random", label="Model")
|
||||
model.style(container=False)
|
||||
update = gradio.Button(ui.refresh_symbol, elem_id="update_model")
|
||||
|
||||
with gradio.Box():
|
||||
seed_variation = gradio.Number(value=1, label="Seed variation", precision=0)
|
||||
seed_variation.style(container=False)
|
||||
|
||||
with gradio.Box():
|
||||
with gradio.Row(elem_id="post_processing_row"):
|
||||
post_processing_1 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #1")
|
||||
post_processing_1.style(container=False)
|
||||
post_processing_2 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #2", interactive=False)
|
||||
post_processing_2.style(container=False)
|
||||
post_processing_3 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #3", interactive=False)
|
||||
post_processing_3.style(container=False)
|
||||
|
||||
def update_click():
|
||||
return gradio.update(choices=self.load_models(), value="Random")
|
||||
|
||||
def post_processing_1_change(value_1):
|
||||
return (gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=value_1 != "None"), gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=value_1 != "None"))
|
||||
|
||||
def post_processing_2_change(value_1, value_2):
|
||||
return gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1, value_2}), value="None", interactive=value_2 != "None")
|
||||
|
||||
update.click(fn=update_click, outputs=model)
|
||||
post_processing_1.change(fn=post_processing_1_change, inputs=post_processing_1, outputs=[post_processing_2, post_processing_3])
|
||||
post_processing_2.change(fn=post_processing_2_change, inputs=[post_processing_1, post_processing_2], outputs=post_processing_3)
|
||||
return [nsfw, model, seed_variation, post_processing_1, post_processing_2, post_processing_3]
|
||||
|
||||
def run(self, p, nsfw, model, seed_variation, post_processing_1, post_processing_2, post_processing_3):
|
||||
if model != "Random":
|
||||
model = model.split("(")[0].rstrip()
|
||||
|
||||
return self.process_images(p, nsfw, model, seed_variation)
|
||||
post_processing = []
|
||||
|
||||
def process_images(self, p, nsfw, model, seed_variation):
|
||||
if post_processing_1 != "None":
|
||||
post_processing.append(post_processing_1.split("(")[0].rstrip())
|
||||
|
||||
if post_processing_2 != "None":
|
||||
post_processing.append(post_processing_2.split("(")[0].rstrip())
|
||||
|
||||
if post_processing_3 != "None":
|
||||
post_processing.append(post_processing_3.split("(")[0].rstrip())
|
||||
|
||||
return self.process_images(p, nsfw, model, seed_variation, post_processing)
|
||||
|
||||
def process_images(self, p, nsfw, model, seed_variation, post_processing):
|
||||
# Copyright (C) 2022 AUTOMATIC1111
|
||||
|
||||
stored_opts = {k: shared.opts.data[k] for k in p.override_settings.keys()}
|
||||
|
|
@ -93,7 +138,7 @@ class Main(scripts.Script):
|
|||
for k, v in p.override_settings.items():
|
||||
setattr(shared.opts, k, v)
|
||||
|
||||
res = self.process_images_inner(p, nsfw, model, seed_variation)
|
||||
res = self.process_images_inner(p, nsfw, model, seed_variation, post_processing)
|
||||
finally:
|
||||
if p.override_settings_restore_afterwards:
|
||||
for k, v in stored_opts.items():
|
||||
|
|
@ -101,7 +146,7 @@ class Main(scripts.Script):
|
|||
|
||||
return res
|
||||
|
||||
def process_images_inner(self, p, nsfw, model, seed_variation):
|
||||
def process_images_inner(self, p, nsfw, model, seed_variation, post_processing):
|
||||
# Copyright (C) 2022 AUTOMATIC1111
|
||||
|
||||
if type(p.prompt) == list:
|
||||
|
|
@ -178,7 +223,7 @@ class Main(scripts.Script):
|
|||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
x_samples_ddim = self.process_batch_horde(p, nsfw, model, seed_variation, prompts[0], negative_prompts[0], seeds[0])
|
||||
x_samples_ddim = self.process_batch_horde(p, nsfw, model, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])
|
||||
|
||||
if x_samples_ddim is None:
|
||||
del x_samples_ddim
|
||||
|
|
@ -250,7 +295,7 @@ class Main(scripts.Script):
|
|||
|
||||
return res
|
||||
|
||||
def process_batch_horde(self, p, nsfw, model, seed_variation, prompt, negative_prompt, seed):
|
||||
def process_batch_horde(self, p, nsfw, model, seed_variation, post_processing, prompt, negative_prompt, seed):
|
||||
payload = {
|
||||
"prompt": "{} ### {}".format(prompt, negative_prompt) if len(negative_prompt) > 0 else prompt,
|
||||
"params": {
|
||||
|
|
@ -283,14 +328,6 @@ class Main(scripts.Script):
|
|||
|
||||
#img2img/inpainting
|
||||
|
||||
post_processing = []
|
||||
|
||||
#upscale RealESRGAN_x4plus
|
||||
|
||||
if p.restore_faces:
|
||||
#CodeFormers
|
||||
post_processing.append("GFPGAN")
|
||||
|
||||
if len(post_processing) > 0:
|
||||
payload["params"]["post_processing"] = post_processing
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue