Make Mypy happier

pull/597/head
Aarni Koskela 2023-08-16 13:31:16 +03:00
parent 27aae01a3b
commit 52cf08f55d
4 changed files with 17 additions and 20 deletions

View File

@ -479,12 +479,12 @@ class Script(scripts.Script):
all_seeds = p.all_seeds
all_prompts, all_negative_prompts = generate_prompts(
generator,
negative_generator,
original_prompt,
original_negative_prompt,
num_images,
all_seeds,
prompt_generator=generator,
negative_prompt_generator=negative_generator,
prompt=original_prompt,
negative_prompt=original_negative_prompt,
num_prompts=num_images,
seeds=all_seeds,
)
except GeneratorException as e:

View File

@ -198,20 +198,17 @@ class GeneratorBuilder:
parser_config=self._parser_config,
ignore_whitespace=self._ignore_whitespace,
)
prompt_generator = BatchedCombinatorialPromptGenerator(
return BatchedCombinatorialPromptGenerator(
prompt_generator,
batches=self._combinatorial_batches,
)
else:
prompt_generator = RandomPromptGenerator(
self._wildcard_manager,
seed=self._seed,
parser_config=self._parser_config,
unlink_seed_from_prompt=self._unlink_seed_from_prompt,
ignore_whitespace=self._ignore_whitespace,
)
return prompt_generator
return RandomPromptGenerator(
self._wildcard_manager,
seed=self._seed,
parser_config=self._parser_config,
unlink_seed_from_prompt=self._unlink_seed_from_prompt,
ignore_whitespace=self._ignore_whitespace,
)
def create_jinja_generator(self, p) -> PromptGenerator:
original_prompt = p.all_prompts[0] if len(p.all_prompts) > 0 else p.prompt

View File

@ -17,7 +17,7 @@ def get_seeds(
use_fixed_seed,
is_combinatorial=False,
combinatorial_batches=1,
):
) -> tuple[list[int], list[int]]:
if p.subseed_strength != 0:
seed = int(p.all_seeds[0])
subseed = int(p.all_subseeds[0])
@ -74,7 +74,7 @@ def generate_prompts(
prompt: str,
negative_prompt: str | None,
num_prompts: int,
seeds: list[int],
seeds: list[int] | None,
) -> tuple[list[str], list[str]]:
"""
Generate positive and negative prompts.

View File

@ -42,7 +42,7 @@ class PngInfoSaver:
return parameters
def strip_template_info(self, parameters: dict[str, Any]) -> str:
def strip_template_info(self, parameters: dict[str, Any]) -> dict[str, Any]:
if "Prompt" in parameters and f"{TEMPLATE_LABEL}:" in parameters["Prompt"]:
parameters["Prompt"] = (
parameters["Prompt"].split(f"{TEMPLATE_LABEL}:")[0].strip()