Save and load generation parameters as infotext fields

parent cacbe7a731
commit f4291962b0

@@ -90,12 +90,13 @@ class Main(scripts.Script):
         models = []

         models.insert(0, "Random")
-        return models
+        self.models = models

     def ui(self, is_img2img):
         with gradio.Box():
             with gradio.Row(elem_id="horde_model_row"):
-                model = gradio.Dropdown(self.load_models(), value="Random", label="Model")
+                self.load_models()
+                model = gradio.Dropdown(self.models, value="Random", label="Model")
                 model.style(container=False)
                 update = gradio.Button(ui.refresh_symbol, elem_id="horde_update_model")
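
Note: load_models() now stores the fetched list on the instance instead of returning it, so the dropdown, the refresh handler, and the infotext lambdas added below all read the same cached list. A minimal sketch of that fetch-once, cache-on-instance pattern (hypothetical names, not part of the commit):

    # Minimal sketch (hypothetical names) of the caching pattern used above.
    class ModelCache:
        def __init__(self, fetch):
            self.fetch = fetch            # e.g. a callable that queries the Horde API
            self.models = ["Random"]

        def load_models(self):
            models = sorted(self.fetch())
            models.insert(0, "Random")
            self.models = models          # cached on the instance instead of returned

    cache = ModelCache(lambda: ["Deliberate", "stable_diffusion"])
    cache.load_models()
    print(cache.models)                   # ['Random', 'Deliberate', 'stable_diffusion']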

@@ -114,8 +115,19 @@ class Main(scripts.Script):
                 post_processing_3 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #3", interactive=False)
                 post_processing_3.style(container=False)

+        self.infotext_fields = [
+            (model, lambda d: next(filter(lambda s: s == d["Model"] or s.startswith("{} (".format(d["Model"])), self.models))),
+            (nsfw, "NSFW"),
+            (shared_laion, "Share with LAION"),
+            (seed_variation, "Seed variation"),
+            (post_processing_1, lambda d: next(filter(lambda s: s.startswith("{} (".format(d["Post processing #1"])), self.models)) if "Post processing #1" in d else "None"),
+            (post_processing_2, lambda d: next(filter(lambda s: s.startswith("{} (".format(d["Post processing #2"])), self.models)) if "Post processing #2" in d else "None"),
+            (post_processing_3, lambda d: next(filter(lambda s: s.startswith("{} (".format(d["Post processing #3"])), self.models)) if "Post processing #3" in d else "None")
+        ]
+
         def update_click():
-            return gradio.update(choices=self.load_models(), value="Random")
+            self.load_models()
+            return gradio.update(choices=self.models, value="Random")

         def post_processing_1_change(value_1):
             return (gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=value_1 != "None"), gradio.update(choices=["None"] + sorted(self.POST_PROCESSINGS - {value_1}), value="None", interactive=False))
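
Note: self.infotext_fields is the hook the webui's parameter-paste feature reads to restore UI state from an infotext; each entry pairs a component with either a plain parameter key or a callable that receives the parsed parameter dict. A standalone sketch of how the "Model" lambda above maps a pasted value back onto a dropdown choice, assuming (as the prefix match suggests) that choices are rendered as "name (details)"; the choice strings here are hypothetical:

    # Hypothetical dropdown choices; the extension appears to append details in parentheses.
    models = ["Random", "Deliberate (12 workers)", "stable_diffusion (48 workers)"]

    def resolve_model(d, models):
        # d is the parsed infotext dict, e.g. {"Model": "Deliberate"}
        return next(filter(lambda s: s == d["Model"] or s.startswith("{} (".format(d["Model"])), models))

    print(resolve_model({"Model": "Deliberate"}, models))   # -> "Deliberate (12 workers)"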

@@ -154,6 +166,16 @@ class Main(scripts.Script):
             for k, v in p.override_settings.items():
                 setattr(shared.opts, k, v)

+            p.extra_generation_params = {
+                "Model": model,
+                "NSFW": nsfw,
+                "Share with LAION": shared_laion,
+                "Seed variation": seed_variation,
+                "Post processing #1": (post_processing[0] if len(post_processing) >= 1 else None),
+                "Post processing #2": (post_processing[1] if len(post_processing) >= 2 else None),
+                "Post processing #3": (post_processing[2] if len(post_processing) >= 3 else None)
+            }
+
             res = self.process_images_inner(p, model, nsfw, shared_laion, seed_variation, post_processing)
         finally:
             if p.override_settings_restore_afterwards:
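
Note: entries in p.extra_generation_params are appended to the generation-parameters text ("infotext") as comma-separated "Key: value" pairs, which is what the infotext_fields mapping above parses back at paste time. A simplified sketch of that rendering with hypothetical values (the real infotext also carries core fields such as Steps, Sampler, and Seed, and omits empty entries):

    # Simplified rendering of how the extra fields appear in the infotext string.
    extra = {
        "Model": "Deliberate",
        "NSFW": False,
        "Share with LAION": True,
        "Seed variation": 1,
        "Post processing #1": "GFPGAN",    # hypothetical value
        "Post processing #2": None,
        "Post processing #3": None,
    }
    print(", ".join("{}: {}".format(k, v) for k, v in extra.items() if v is not None))
    # Model: Deliberate, NSFW: False, Share with LAION: True, Seed variation: 1, Post processing #1: GFPGAN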

@@ -241,7 +263,7 @@ class Main(scripts.Script):
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"

-            x_samples_ddim = self.process_batch_horde(p, model, nsfw, shared_laion, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])
+            x_samples_ddim, models = self.process_batch_horde(p, model, nsfw, shared_laion, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])

             if x_samples_ddim is None:
                 del x_samples_ddim

@@ -255,6 +277,7 @@ class Main(scripts.Script):
                 x_sample = 255. * numpy.moveaxis(x_sample.cpu().numpy(), 0, 2)
                 x_sample = x_sample.astype(numpy.uint8)
                 image = PIL.Image.fromarray(x_sample)
+                p.extra_generation_params["Model"] = models[i]

                 if p.color_corrections is not None and i < len(p.color_corrections):
                     if shared.opts.save and not p.do_not_save_samples and shared.opts.save_images_before_color_correction:

@@ -278,6 +301,7 @@ class Main(scripts.Script):

             del x_samples_ddim
             devices.torch_gc()
+            p.extra_generation_params["Model"] = model
             shared.state.job_no += 1
             shared.state.sampling_step = 0
             shared.state.current_image_sampling_step = 0
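
Note: together with the previous hunk, this means each saved image's infotext records the model that actually rendered it (models[i], as reported by the Horde), while the batch-level value is reset to the user's selection once the batch finishes. A tiny illustration of that ordering with hypothetical values:

    selected_model = "Random"                       # the dropdown selection sent to the Horde
    models = ["Deliberate", "stable_diffusion"]     # per-image models reported back

    extra_generation_params = {}
    for i in range(len(models)):
        extra_generation_params["Model"] = models[i]    # written into this image's infotext
        # ... image i is post-processed and saved here ...
    extra_generation_params["Model"] = selected_model   # restored after the batch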

@@ -397,11 +421,12 @@ class Main(scripts.Script):
             assert images.status_code == 200, "Status Code: {} (expected {})".format(images.status_code, 200)
             images = images.json()
             images = images["generations"]
+            models = [image["model"] for image in images]
             images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
             images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
             images = [torch.from_numpy(image) for image in images]
             images = torch.stack(images).to(shared.device)
-            return images
+            return (images, models)
         except (requests.ConnectionError, AssertionError) as e:
             print(e)
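
Note: a self-contained sketch of the decode path used above; each generation in the Horde response carries a base64-encoded image plus the model that produced it, and the image is converted to a CHW float tensor in [0, 1] before being stacked into a batch. The payload below is faked locally so the snippet runs without an API call:

    import base64, io

    import numpy
    import PIL.Image
    import torch

    # Fake one response entry: a 64x64 red PNG stands in for image["img"].
    buffer = io.BytesIO()
    PIL.Image.new("RGB", (64, 64), (255, 0, 0)).save(buffer, format="PNG")
    generation = {"img": base64.b64encode(buffer.getvalue()).decode(), "model": "Deliberate"}

    image = PIL.Image.open(io.BytesIO(base64.b64decode(generation["img"])))
    array = numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0)   # HWC -> CHW
    batch = torch.stack([torch.from_numpy(array)])                                   # shape (1, 3, 64, 64)
    print(batch.shape, generation["model"])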

@@ -444,13 +469,14 @@ class Main(scripts.Script):
             assert images.status_code == 200, "Status Code: {} (expected {})".format(images.status_code, 200)
             images = images.json()
             images = images["generations"]
+            models = [image["model"] for image in images]
             images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
             images = [numpy.moveaxis(numpy.array(image).astype(numpy.float32) / 255.0, 2, 0) for image in images]
             images = [torch.from_numpy(image) for image in images]

             if len(images) > 0:
                 images = torch.stack(images).to(shared.device)
-                return images
+                return (images, models)
         except requests.Timeout:
             return
         except (requests.ConnectionError, AssertionError) as e: