Fix name conflict of "shared", improve error messages

main
natanjunges 2023-01-09 15:13:26 -03:00
parent dfffa1f6ca
commit d67e122f72
1 changed files with 17 additions and 15 deletions

View File

@ -82,7 +82,7 @@ class Main(scripts.Script):
try:
models = requests.get("{}/v2/status/models".format(self.api_endpoint))
assert models.status_code == 200
assert models.status_code == 200, "Status Code: {}".format(models.status_code)
models = models.json()
models.sort(key=lambda m: (-m["count"], m["name"]))
models = ["{} ({})".format(m["name"], m["count"]) for m in models]
@ -102,7 +102,7 @@ class Main(scripts.Script):
with gradio.Box():
with gradio.Row():
nsfw = gradio.Checkbox(False, label="NSFW")
shared = gradio.Checkbox(False, label="Share with LAION")
shared_laion = gradio.Checkbox(False, label="Share with LAION")
seed_variation = gradio.Slider(minimum=1, maximum=1000, value=1, step=1, label="Seed variation", elem_id="horde_seed_variation")
with gradio.Box():
@ -126,9 +126,9 @@ class Main(scripts.Script):
update.click(fn=update_click, outputs=model)
post_processing_1.change(fn=post_processing_1_change, inputs=post_processing_1, outputs=[post_processing_2, post_processing_3])
post_processing_2.change(fn=post_processing_2_change, inputs=[post_processing_1, post_processing_2], outputs=post_processing_3)
return [model, nsfw, shared, seed_variation, post_processing_1, post_processing_2, post_processing_3]
return [model, nsfw, shared_laion, seed_variation, post_processing_1, post_processing_2, post_processing_3]
def run(self, p, model, nsfw, shared, seed_variation, post_processing_1, post_processing_2, post_processing_3):
def run(self, p, model, nsfw, shared_laion, seed_variation, post_processing_1, post_processing_2, post_processing_3):
if model != "Random":
model = model.split("(")[0].rstrip()
@ -143,9 +143,9 @@ class Main(scripts.Script):
if post_processing_3 != "None":
post_processing.append(post_processing_3.split("(")[0].rstrip())
return self.process_images(p, model, nsfw, shared, int(seed_variation), post_processing)
return self.process_images(p, model, nsfw, shared_laion, int(seed_variation), post_processing)
def process_images(self, p, model, nsfw, shared, seed_variation, post_processing):
def process_images(self, p, model, nsfw, shared_laion, seed_variation, post_processing):
# Copyright (C) 2022 AUTOMATIC1111
stored_opts = {k: shared.opts.data[k] for k in p.override_settings.keys()}
@ -154,7 +154,7 @@ class Main(scripts.Script):
for k, v in p.override_settings.items():
setattr(shared.opts, k, v)
res = self.process_images_inner(p, model, nsfw, shared, seed_variation, post_processing)
res = self.process_images_inner(p, model, nsfw, shared_laion, seed_variation, post_processing)
finally:
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
@ -162,7 +162,7 @@ class Main(scripts.Script):
return res
def process_images_inner(self, p, model, nsfw, shared, seed_variation, post_processing):
def process_images_inner(self, p, model, nsfw, shared_laion, seed_variation, post_processing):
# Copyright (C) 2022 AUTOMATIC1111
if type(p.prompt) == list:
@ -219,6 +219,8 @@ class Main(scripts.Script):
shared.state.job_count = p.n_iter
for n in range(p.n_iter):
p.iteration = n
if shared.state.skipped:
shared.state.skipped = False
@ -239,7 +241,7 @@ class Main(scripts.Script):
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
x_samples_ddim = self.process_batch_horde(p, model, nsfw, shared, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])
x_samples_ddim = self.process_batch_horde(p, model, nsfw, shared_laion, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])
if x_samples_ddim is None:
del x_samples_ddim
@ -311,7 +313,7 @@ class Main(scripts.Script):
return res
def process_batch_horde(self, p, model, nsfw, shared, seed_variation, post_processing, prompt, negative_prompt, seed):
def process_batch_horde(self, p, model, nsfw, shared_laion, seed_variation, post_processing, prompt, negative_prompt, seed):
payload = {
"prompt": "{} ### {}".format(prompt, negative_prompt) if len(negative_prompt) > 0 else prompt,
"params": {
@ -356,7 +358,7 @@ class Main(scripts.Script):
p.image_mask.save(buffer, format="WEBP")
payload["source_mask"] = base64.b64encode(buffer.getvalue()).decode()
if shared:
if shared_laion:
payload["shared"] = True
if len(post_processing) > 0:
@ -369,7 +371,7 @@ class Main(scripts.Script):
try:
id = requests.post("{}/v2/generate/async".format(self.api_endpoint), headers={"apikey": self.api_key}, json=payload)
assert id.status_code == 202
assert id.status_code == 202, "Status Code: {}".format(id.status_code)
id = id.json()
id = id["id"]
shared.state.sampling_steps = p.batch_size
@ -382,7 +384,7 @@ class Main(scripts.Script):
try:
status = requests.get("{}/v2/generate/check/{}".format(self.api_endpoint, id), timeout=1)
assert status.status_code == 200
assert status.status_code == 200, "Status Code: {}".format(status.status_code)
status = status.json()
shared.state.sampling_step = status["finished"]
@ -392,7 +394,7 @@ class Main(scripts.Script):
try:
images = requests.get("{}/v2/generate/status/{}".format(self.api_endpoint, id))
assert images.status_code == 200
assert images.status_code == 200, "Status Code: {}".format(images.status_code)
images = images.json()
images = images["generations"]
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
@ -439,7 +441,7 @@ class Main(scripts.Script):
try:
images = requests.delete("{}/v2/generate/status/{}".format(self.api_endpoint, id), timeout=60)
assert images.status_code == 200
assert images.status_code == 200, "Status Code: {}".format(images.status_code)
images = images.json()
images = images["generations"]
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]