Save and load generation parameters as infotext fields

main
natanjunges 2023-01-09 18:19:40 -03:00
parent cacbe7a731
commit f4291962b0
1 changed files with 32 additions and 6 deletions

View File

@ -90,12 +90,13 @@ class Main(scripts.Script):
models = []
models.insert(0, "Random")
return models
self.models = models
def ui(self, is_img2img):
with gradio.Box():
with gradio.Row(elem_id="horde_model_row"):
model = gradio.Dropdown(self.load_models(), value="Random", label="Model")
self.load_models()
model = gradio.Dropdown(self.models, value="Random", label="Model")
model.style(container=False)
update = gradio.Button(ui.refresh_symbol, elem_id="horde_update_model")
@ -114,8 +115,19 @@ class Main(scripts.Script):
post_processing_3 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #3", interactive=False)
post_processing_3.style(container=False)
self.infotext_fields = [
(model, lambda d: next(filter(lambda s: s == d["Model"] or s.startswith("{} (".format(d["Model"])), self.models))),
(nsfw, "NSFW"),
(shared_laion, "Share with LAION"),
(seed_variation, "Seed variation"),
(post_processing_1, lambda d: next(filter(lambda s: s.startswith("{} (".format(d["Post processing #1"])), self.models)) if "Post processing #1" in d else "None"),
(post_processing_2, lambda d: next(filter(lambda s: s.startswith("{} (".format(d["Post processing #2"])), self.models)) if "Post processing #2" in d else "None"),
(post_processing_3, lambda d: next(filter(lambda s: s.startswith("{} (".format(d["Post processing #3"])), self.models)) if "Post processing #3" in d else "None")
]
def update_click():
return gradio.update(choices=self.load_models(), value="Random")
self.load_models()
return gradio.update(choices=self.models, value="Random")
def post_processing_1_change(value_1):
return (gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=value_1 != "None"), gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=False))
@ -154,6 +166,16 @@ class Main(scripts.Script):
for k, v in p.override_settings.items():
setattr(shared.opts, k, v)
p.extra_generation_params = {
"Model": model,
"NSFW": nsfw,
"Share with LAION": shared_laion,
"Seed variation": seed_variation,
"Post processing #1": (post_processing[0] if len(post_processing) >= 1 else None),
"Post processing #2": (post_processing[1] if len(post_processing) >= 2 else None),
"Post processing #3": (post_processing[2] if len(post_processing) >= 3 else None)
}
res = self.process_images_inner(p, model, nsfw, shared_laion, seed_variation, post_processing)
finally:
if p.override_settings_restore_afterwards:
@ -241,7 +263,7 @@ class Main(scripts.Script):
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
x_samples_ddim = self.process_batch_horde(p, model, nsfw, shared_laion, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])
x_samples_ddim, models = self.process_batch_horde(p, model, nsfw, shared_laion, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])
if x_samples_ddim is None:
del x_samples_ddim
@ -255,6 +277,7 @@ class Main(scripts.Script):
x_sample = 255. * numpy.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(numpy.uint8)
image = PIL.Image.fromarray(x_sample)
p.extra_generation_params["Model"] = models[i]
if p.color_corrections is not None and i < len(p.color_corrections):
if shared.opts.save and not p.do_not_save_samples and shared.opts.save_images_before_color_correction:
@ -278,6 +301,7 @@ class Main(scripts.Script):
del x_samples_ddim
devices.torch_gc()
p.extra_generation_params["Model"] = model
shared.state.job_no += 1
shared.state.sampling_step = 0
shared.state.current_image_sampling_step = 0
@ -397,11 +421,12 @@ class Main(scripts.Script):
assert images.status_code == 200, "Status Code: {} (expected {})".format(images.status_code, 200)
images = images.json()
images = images["generations"]
models = [image["model"] for image in images]
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
images = [torch.from_numpy(image) for image in images]
images = torch.stack(images).to(shared.device)
return images
return (images, models)
except (requests.ConnectionError, AssertionError) as e:
print(e)
@ -444,13 +469,14 @@ class Main(scripts.Script):
assert images.status_code == 200, "Status Code: {} (expected {})".format(images.status_code, 200)
images = images.json()
images = images["generations"]
models = [image["model"] for image in images]
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
images = [torch.from_numpy(image) for image in images]
if len(images) > 0:
images = torch.stack(images).to(shared.device)
return images
return (images, models)
except requests.Timeout:
return
except (requests.ConnectionError, AssertionError) as e: