Set scheduler in args only if img2img class has it

main
Uminosachi 2024-06-08 19:43:10 +09:00
parent cac134a5fe
commit b7e105a6a9
1 changed files with 31 additions and 1 deletions

View File

@ -1,6 +1,8 @@
import copy
import importlib
import inspect
import os
from dataclasses import fields, is_dataclass
import modules.scripts as scripts
from modules import paths, shared
@ -149,6 +151,33 @@ def clear_controlnet_cache(cnet, input_scripts):
script.clear_control_model_cache()
def get_init_params(cls):
"""
Returns a list of parameter names for the __init__ method of a given class.
Works for both dataclasses and regular classes.
Args:
cls (type): The class to inspect.
Returns:
list: A list of parameter names for the __init__ method.
"""
if is_dataclass(cls):
# For dataclasses
return [field.name for field in fields(cls)]
elif inspect.isclass(cls):
# For regular classes
try:
signature = inspect.signature(cls.__init__)
except (ValueError, TypeError):
# Handle classes with no __init__ method
return []
# Exclude "self" and return the rest of the parameters
return [param.name for param in signature.parameters.values() if param.name != "self"]
else:
raise TypeError(f"Expected a class, got {type(cls).__name__}")
def get_sd_img2img_processing(init_image, mask_image, prompt, n_prompt, sampler_id, ddim_steps, cfg_scale, strength, seed,
mask_blur=4, fill_mode=1):
"""Get StableDiffusionProcessingImg2Img instance
@ -187,7 +216,6 @@ def get_sd_img2img_processing(init_image, mask_image, prompt, n_prompt, sampler_
negative_prompt=n_prompt,
seed=seed,
sampler_name=sampler_id,
scheduler="Automatic",
batch_size=1,
n_iter=1,
steps=ddim_steps,
@ -199,6 +227,8 @@ def get_sd_img2img_processing(init_image, mask_image, prompt, n_prompt, sampler_
do_not_save_samples=True,
do_not_save_grid=True,
)
if "scheduler" in get_init_params(StableDiffusionProcessingImg2Img):
sd_img2img_args["scheduler"] = "Automatic"
p = StableDiffusionProcessingImg2Img(**sd_img2img_args)