From b7e105a6a9d46b715a134eef083282bcc90b0a3c Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Sat, 8 Jun 2024 19:43:10 +0900 Subject: [PATCH] Set scheduler in args only if img2img class has it --- ia_webui_controlnet.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/ia_webui_controlnet.py b/ia_webui_controlnet.py index c8d39e2..2d8ef3e 100644 --- a/ia_webui_controlnet.py +++ b/ia_webui_controlnet.py @@ -1,6 +1,8 @@ import copy import importlib +import inspect import os +from dataclasses import fields, is_dataclass import modules.scripts as scripts from modules import paths, shared @@ -149,6 +151,33 @@ def clear_controlnet_cache(cnet, input_scripts): script.clear_control_model_cache() +def get_init_params(cls): + """ + Returns a list of parameter names for the __init__ method of a given class. + Works for both dataclasses and regular classes. + + Args: + cls (type): The class to inspect. + + Returns: + list: A list of parameter names for the __init__ method. + """ + if is_dataclass(cls): + # For dataclasses + return [field.name for field in fields(cls)] + elif inspect.isclass(cls): + # For regular classes + try: + signature = inspect.signature(cls.__init__) + except (ValueError, TypeError): + # Handle classes with no __init__ method + return [] + # Exclude "self" and return the rest of the parameters + return [param.name for param in signature.parameters.values() if param.name != "self"] + else: + raise TypeError(f"Expected a class, got {type(cls).__name__}") + + def get_sd_img2img_processing(init_image, mask_image, prompt, n_prompt, sampler_id, ddim_steps, cfg_scale, strength, seed, mask_blur=4, fill_mode=1): """Get StableDiffusionProcessingImg2Img instance @@ -187,7 +216,6 @@ def get_sd_img2img_processing(init_image, mask_image, prompt, n_prompt, sampler_ negative_prompt=n_prompt, seed=seed, sampler_name=sampler_id, - scheduler="Automatic", batch_size=1, n_iter=1, steps=ddim_steps, @@ -199,6 +227,8 @@ def get_sd_img2img_processing(init_image, mask_image, prompt, n_prompt, sampler_ do_not_save_samples=True, do_not_save_grid=True, ) + if "scheduler" in get_init_params(StableDiffusionProcessingImg2Img): + sd_img2img_args["scheduler"] = "Automatic" p = StableDiffusionProcessingImg2Img(**sd_img2img_args)