diff --git a/sd_dynamic_prompts/dynamic_prompting.py b/sd_dynamic_prompts/dynamic_prompting.py index 43281f3..901b529 100644 --- a/sd_dynamic_prompts/dynamic_prompting.py +++ b/sd_dynamic_prompts/dynamic_prompting.py @@ -479,12 +479,12 @@ class Script(scripts.Script): all_seeds = p.all_seeds all_prompts, all_negative_prompts = generate_prompts( - generator, - negative_generator, - original_prompt, - original_negative_prompt, - num_images, - all_seeds, + prompt_generator=generator, + negative_prompt_generator=negative_generator, + prompt=original_prompt, + negative_prompt=original_negative_prompt, + num_prompts=num_images, + seeds=all_seeds, ) except GeneratorException as e: diff --git a/sd_dynamic_prompts/generator_builder.py b/sd_dynamic_prompts/generator_builder.py index 62b81be..ebf3468 100644 --- a/sd_dynamic_prompts/generator_builder.py +++ b/sd_dynamic_prompts/generator_builder.py @@ -198,20 +198,17 @@ class GeneratorBuilder: parser_config=self._parser_config, ignore_whitespace=self._ignore_whitespace, ) - prompt_generator = BatchedCombinatorialPromptGenerator( + return BatchedCombinatorialPromptGenerator( prompt_generator, batches=self._combinatorial_batches, ) - else: - prompt_generator = RandomPromptGenerator( - self._wildcard_manager, - seed=self._seed, - parser_config=self._parser_config, - unlink_seed_from_prompt=self._unlink_seed_from_prompt, - ignore_whitespace=self._ignore_whitespace, - ) - - return prompt_generator + return RandomPromptGenerator( + self._wildcard_manager, + seed=self._seed, + parser_config=self._parser_config, + unlink_seed_from_prompt=self._unlink_seed_from_prompt, + ignore_whitespace=self._ignore_whitespace, + ) def create_jinja_generator(self, p) -> PromptGenerator: original_prompt = p.all_prompts[0] if len(p.all_prompts) > 0 else p.prompt diff --git a/sd_dynamic_prompts/helpers.py b/sd_dynamic_prompts/helpers.py index 910fafc..4aef51b 100644 --- a/sd_dynamic_prompts/helpers.py +++ b/sd_dynamic_prompts/helpers.py @@ -17,7 +17,7 @@ def get_seeds( use_fixed_seed, is_combinatorial=False, combinatorial_batches=1, -): +) -> tuple[list[int], list[int]]: if p.subseed_strength != 0: seed = int(p.all_seeds[0]) subseed = int(p.all_subseeds[0]) @@ -74,7 +74,7 @@ def generate_prompts( prompt: str, negative_prompt: str | None, num_prompts: int, - seeds: list[int], + seeds: list[int] | None, ) -> tuple[list[str], list[str]]: """ Generate positive and negative prompts. diff --git a/sd_dynamic_prompts/pnginfo_saver.py b/sd_dynamic_prompts/pnginfo_saver.py index a4b03c6..16a03c1 100644 --- a/sd_dynamic_prompts/pnginfo_saver.py +++ b/sd_dynamic_prompts/pnginfo_saver.py @@ -42,7 +42,7 @@ class PngInfoSaver: return parameters - def strip_template_info(self, parameters: dict[str, Any]) -> str: + def strip_template_info(self, parameters: dict[str, Any]) -> dict[str, Any]: if "Prompt" in parameters and f"{TEMPLATE_LABEL}:" in parameters["Prompt"]: parameters["Prompt"] = ( parameters["Prompt"].split(f"{TEMPLATE_LABEL}:")[0].strip()