diff --git a/scripts/main.py b/scripts/main.py index d1a8239..69b9ba8 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -115,16 +115,6 @@ class Main(scripts.Script): post_processing_3 = gradio.Dropdown(["None"] + sorted(self.POST_PROCESSINGS), value="None", label="Post processing #3", interactive=False) post_processing_3.style(container=False) - self.infotext_fields = [ - (model, lambda d: next(filter(lambda s: s == d["Model"] or s.startswith("{} (".format(d["Model"])), self.models))), - (nsfw, "NSFW"), - (shared_laion, "Share with LAION"), - (seed_variation, "Seed variation"), - (post_processing_1, lambda d: next(filter(lambda s: s.startswith("{} (".format(d["Post processing #1"])), self.models)) if "Post processing #1" in d else "None"), - (post_processing_2, lambda d: next(filter(lambda s: s.startswith("{} (".format(d["Post processing #2"])), self.models)) if "Post processing #2" in d else "None"), - (post_processing_3, lambda d: next(filter(lambda s: s.startswith("{} (".format(d["Post processing #3"])), self.models)) if "Post processing #3" in d else "None") - ] - def update_click(): self.load_models() return gradio.update(choices=self.models, value="Random") @@ -138,6 +128,37 @@ class Main(scripts.Script): update.click(fn=update_click, outputs=model) post_processing_1.change(fn=post_processing_1_change, inputs=post_processing_1, outputs=[post_processing_2, post_processing_3]) post_processing_2.change(fn=post_processing_2_change, inputs=[post_processing_1, post_processing_2], outputs=post_processing_3) + + def model_infotext(d): + if "Model" in d and d["Model"] != "Random": + try: + return next(filter(lambda s: s.startswith("{} (".format(d["Model"])), self.models)) + except StopIteration: + pass + + return "Random" + + def post_processing_n_infotext(n): + def post_processing_infotext(d): + if "Post processing {}".format(n) in d: + try: + return next(filter(lambda s: s.startswith("{} (".format(d["Post processing {}".format(n)])), self.POST_PROCESSINGS)) + except StopIteration: + pass + + return "None" + + return post_processing_infotext + + self.infotext_fields = [ + (model, model_infotext), + (nsfw, "NSFW"), + (shared_laion, "Share with LAION"), + (seed_variation, "Seed variation"), + (post_processing_1, post_processing_n_infotext(1)), + (post_processing_2, post_processing_n_infotext(2)), + (post_processing_3, post_processing_n_infotext(3)) + ] return [model, nsfw, shared_laion, seed_variation, post_processing_1, post_processing_2, post_processing_3] def run(self, p, model, nsfw, shared_laion, seed_variation, post_processing_1, post_processing_2, post_processing_3): @@ -171,9 +192,9 @@ class Main(scripts.Script): "NSFW": nsfw, "Share with LAION": shared_laion, "Seed variation": seed_variation, - "Post processing #1": (post_processing[0] if len(post_processing) >= 1 else None), - "Post processing #2": (post_processing[1] if len(post_processing) >= 2 else None), - "Post processing #3": (post_processing[2] if len(post_processing) >= 3 else None) + "Post processing 1": (post_processing[0] if len(post_processing) >= 1 else None), + "Post processing 2": (post_processing[1] if len(post_processing) >= 2 else None), + "Post processing 3": (post_processing[2] if len(post_processing) >= 3 else None) } res = self.process_images_inner(p, model, nsfw, shared_laion, seed_variation, post_processing)