Fix name conflict of "shared", improve error messages
parent
dfffa1f6ca
commit
d67e122f72
|
|
@ -82,7 +82,7 @@ class Main(scripts.Script):
|
|||
|
||||
try:
|
||||
models = requests.get("{}/v2/status/models".format(self.api_endpoint))
|
||||
assert models.status_code == 200
|
||||
assert models.status_code == 200, "Status Code: {}".format(models.status_code)
|
||||
models = models.json()
|
||||
models.sort(key=lambda m: (-m["count"], m["name"]))
|
||||
models = ["{} ({})".format(m["name"], m["count"]) for m in models]
|
||||
|
|
@ -102,7 +102,7 @@ class Main(scripts.Script):
|
|||
with gradio.Box():
|
||||
with gradio.Row():
|
||||
nsfw = gradio.Checkbox(False, label="NSFW")
|
||||
shared = gradio.Checkbox(False, label="Share with LAION")
|
||||
shared_laion = gradio.Checkbox(False, label="Share with LAION")
|
||||
seed_variation = gradio.Slider(minimum=1, maximum=1000, value=1, step=1, label="Seed variation", elem_id="horde_seed_variation")
|
||||
|
||||
with gradio.Box():
|
||||
|
|
@ -126,9 +126,9 @@ class Main(scripts.Script):
|
|||
update.click(fn=update_click, outputs=model)
|
||||
post_processing_1.change(fn=post_processing_1_change, inputs=post_processing_1, outputs=[post_processing_2, post_processing_3])
|
||||
post_processing_2.change(fn=post_processing_2_change, inputs=[post_processing_1, post_processing_2], outputs=post_processing_3)
|
||||
return [model, nsfw, shared, seed_variation, post_processing_1, post_processing_2, post_processing_3]
|
||||
return [model, nsfw, shared_laion, seed_variation, post_processing_1, post_processing_2, post_processing_3]
|
||||
|
||||
def run(self, p, model, nsfw, shared, seed_variation, post_processing_1, post_processing_2, post_processing_3):
|
||||
def run(self, p, model, nsfw, shared_laion, seed_variation, post_processing_1, post_processing_2, post_processing_3):
|
||||
if model != "Random":
|
||||
model = model.split("(")[0].rstrip()
|
||||
|
||||
|
|
@ -143,9 +143,9 @@ class Main(scripts.Script):
|
|||
if post_processing_3 != "None":
|
||||
post_processing.append(post_processing_3.split("(")[0].rstrip())
|
||||
|
||||
return self.process_images(p, model, nsfw, shared, int(seed_variation), post_processing)
|
||||
return self.process_images(p, model, nsfw, shared_laion, int(seed_variation), post_processing)
|
||||
|
||||
def process_images(self, p, model, nsfw, shared, seed_variation, post_processing):
|
||||
def process_images(self, p, model, nsfw, shared_laion, seed_variation, post_processing):
|
||||
# Copyright (C) 2022 AUTOMATIC1111
|
||||
|
||||
stored_opts = {k: shared.opts.data[k] for k in p.override_settings.keys()}
|
||||
|
|
@ -154,7 +154,7 @@ class Main(scripts.Script):
|
|||
for k, v in p.override_settings.items():
|
||||
setattr(shared.opts, k, v)
|
||||
|
||||
res = self.process_images_inner(p, model, nsfw, shared, seed_variation, post_processing)
|
||||
res = self.process_images_inner(p, model, nsfw, shared_laion, seed_variation, post_processing)
|
||||
finally:
|
||||
if p.override_settings_restore_afterwards:
|
||||
for k, v in stored_opts.items():
|
||||
|
|
@ -162,7 +162,7 @@ class Main(scripts.Script):
|
|||
|
||||
return res
|
||||
|
||||
def process_images_inner(self, p, model, nsfw, shared, seed_variation, post_processing):
|
||||
def process_images_inner(self, p, model, nsfw, shared_laion, seed_variation, post_processing):
|
||||
# Copyright (C) 2022 AUTOMATIC1111
|
||||
|
||||
if type(p.prompt) == list:
|
||||
|
|
@ -219,6 +219,8 @@ class Main(scripts.Script):
|
|||
shared.state.job_count = p.n_iter
|
||||
|
||||
for n in range(p.n_iter):
|
||||
p.iteration = n
|
||||
|
||||
if shared.state.skipped:
|
||||
shared.state.skipped = False
|
||||
|
||||
|
|
@ -239,7 +241,7 @@ class Main(scripts.Script):
|
|||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
x_samples_ddim = self.process_batch_horde(p, model, nsfw, shared, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])
|
||||
x_samples_ddim = self.process_batch_horde(p, model, nsfw, shared_laion, seed_variation, post_processing, prompts[0], negative_prompts[0], seeds[0])
|
||||
|
||||
if x_samples_ddim is None:
|
||||
del x_samples_ddim
|
||||
|
|
@ -311,7 +313,7 @@ class Main(scripts.Script):
|
|||
|
||||
return res
|
||||
|
||||
def process_batch_horde(self, p, model, nsfw, shared, seed_variation, post_processing, prompt, negative_prompt, seed):
|
||||
def process_batch_horde(self, p, model, nsfw, shared_laion, seed_variation, post_processing, prompt, negative_prompt, seed):
|
||||
payload = {
|
||||
"prompt": "{} ### {}".format(prompt, negative_prompt) if len(negative_prompt) > 0 else prompt,
|
||||
"params": {
|
||||
|
|
@ -356,7 +358,7 @@ class Main(scripts.Script):
|
|||
p.image_mask.save(buffer, format="WEBP")
|
||||
payload["source_mask"] = base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
if shared:
|
||||
if shared_laion:
|
||||
payload["shared"] = True
|
||||
|
||||
if len(post_processing) > 0:
|
||||
|
|
@ -369,7 +371,7 @@ class Main(scripts.Script):
|
|||
|
||||
try:
|
||||
id = requests.post("{}/v2/generate/async".format(self.api_endpoint), headers={"apikey": self.api_key}, json=payload)
|
||||
assert id.status_code == 202
|
||||
assert id.status_code == 202, "Status Code: {}".format(id.status_code)
|
||||
id = id.json()
|
||||
id = id["id"]
|
||||
shared.state.sampling_steps = p.batch_size
|
||||
|
|
@ -382,7 +384,7 @@ class Main(scripts.Script):
|
|||
|
||||
try:
|
||||
status = requests.get("{}/v2/generate/check/{}".format(self.api_endpoint, id), timeout=1)
|
||||
assert status.status_code == 200
|
||||
assert status.status_code == 200, "Status Code: {}".format(status.status_code)
|
||||
status = status.json()
|
||||
shared.state.sampling_step = status["finished"]
|
||||
|
||||
|
|
@ -392,7 +394,7 @@ class Main(scripts.Script):
|
|||
|
||||
try:
|
||||
images = requests.get("{}/v2/generate/status/{}".format(self.api_endpoint, id))
|
||||
assert images.status_code == 200
|
||||
assert images.status_code == 200, "Status Code: {}".format(images.status_code)
|
||||
images = images.json()
|
||||
images = images["generations"]
|
||||
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
|
||||
|
|
@ -439,7 +441,7 @@ class Main(scripts.Script):
|
|||
|
||||
try:
|
||||
images = requests.delete("{}/v2/generate/status/{}".format(self.api_endpoint, id), timeout=60)
|
||||
assert images.status_code == 200
|
||||
assert images.status_code == 200, "Status Code: {}".format(images.status_code)
|
||||
images = images.json()
|
||||
images = images["generations"]
|
||||
images = [PIL.Image.open(io.BytesIO(base64.b64decode(image["img"]))) for image in images]
|
||||
|
|
|
|||
Loading…
Reference in New Issue