feat: parse env var strings to expected config value types (#3107)

* fix: add try_parse_bool for env var strings to enable config overrides of boolean values

* fix: fallback to given value if not parseable

* feat: extend eval to all valid types

* fix: remove return type

* fix: prevent strange type conversions by providing expected type

* feat: add tests
pull/3109/head
Manuel Schmid 2024-06-06 19:29:08 +02:00 committed by GitHub
parent 04d764820e
commit 5abae220c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 168 additions and 39 deletions

View File

@ -2,13 +2,14 @@ import os
import json import json
import math import math
import numbers import numbers
import args_manager import args_manager
import tempfile import tempfile
import modules.flags import modules.flags
import modules.sdxl_styles import modules.sdxl_styles
from modules.model_loader import load_file_from_url from modules.model_loader import load_file_from_url
from modules.extra_utils import makedirs_with_log, get_files_from_folder from modules.extra_utils import makedirs_with_log, get_files_from_folder, try_eval_env_var
from modules.flags import OutputFormat, Performance, MetadataScheme from modules.flags import OutputFormat, Performance, MetadataScheme
@ -200,7 +201,7 @@ path_safety_checker = get_dir_or_set_default('path_safety_checker', '../models/s
path_outputs = get_path_output() path_outputs = get_path_output()
def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False): def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False, expected_type=None):
global config_dict, visited_keys global config_dict, visited_keys
if key not in visited_keys: if key not in visited_keys:
@ -208,6 +209,7 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_
v = os.getenv(key) v = os.getenv(key)
if v is not None: if v is not None:
v = try_eval_env_var(v, expected_type)
print(f"Environment: {key} = {v}") print(f"Environment: {key} = {v}")
config_dict[key] = v config_dict[key] = v
@ -252,41 +254,49 @@ temp_path = init_temp_path(get_config_item_or_set_default(
key='temp_path', key='temp_path',
default_value=default_temp_path, default_value=default_temp_path,
validator=lambda x: isinstance(x, str), validator=lambda x: isinstance(x, str),
expected_type=str
), default_temp_path) ), default_temp_path)
temp_path_cleanup_on_launch = get_config_item_or_set_default( temp_path_cleanup_on_launch = get_config_item_or_set_default(
key='temp_path_cleanup_on_launch', key='temp_path_cleanup_on_launch',
default_value=True, default_value=True,
validator=lambda x: isinstance(x, bool) validator=lambda x: isinstance(x, bool),
expected_type=bool
) )
default_base_model_name = default_model = get_config_item_or_set_default( default_base_model_name = default_model = get_config_item_or_set_default(
key='default_model', key='default_model',
default_value='model.safetensors', default_value='model.safetensors',
validator=lambda x: isinstance(x, str) validator=lambda x: isinstance(x, str),
expected_type=str
) )
previous_default_models = get_config_item_or_set_default( previous_default_models = get_config_item_or_set_default(
key='previous_default_models', key='previous_default_models',
default_value=[], default_value=[],
validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x) validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x),
expected_type=list
) )
default_refiner_model_name = default_refiner = get_config_item_or_set_default( default_refiner_model_name = default_refiner = get_config_item_or_set_default(
key='default_refiner', key='default_refiner',
default_value='None', default_value='None',
validator=lambda x: isinstance(x, str) validator=lambda x: isinstance(x, str),
expected_type=str
) )
default_refiner_switch = get_config_item_or_set_default( default_refiner_switch = get_config_item_or_set_default(
key='default_refiner_switch', key='default_refiner_switch',
default_value=0.8, default_value=0.8,
validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1 validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1,
expected_type=numbers.Number
) )
default_loras_min_weight = get_config_item_or_set_default( default_loras_min_weight = get_config_item_or_set_default(
key='default_loras_min_weight', key='default_loras_min_weight',
default_value=-2, default_value=-2,
validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10 validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10,
expected_type=numbers.Number
) )
default_loras_max_weight = get_config_item_or_set_default( default_loras_max_weight = get_config_item_or_set_default(
key='default_loras_max_weight', key='default_loras_max_weight',
default_value=2, default_value=2,
validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10 validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10,
expected_type=numbers.Number
) )
default_loras = get_config_item_or_set_default( default_loras = get_config_item_or_set_default(
key='default_loras', key='default_loras',
@ -320,38 +330,45 @@ default_loras = get_config_item_or_set_default(
validator=lambda x: isinstance(x, list) and all( validator=lambda x: isinstance(x, list) and all(
len(y) == 3 and isinstance(y[0], bool) and isinstance(y[1], str) and isinstance(y[2], numbers.Number) len(y) == 3 and isinstance(y[0], bool) and isinstance(y[1], str) and isinstance(y[2], numbers.Number)
or len(y) == 2 and isinstance(y[0], str) and isinstance(y[1], numbers.Number) or len(y) == 2 and isinstance(y[0], str) and isinstance(y[1], numbers.Number)
for y in x) for y in x),
expected_type=list
) )
default_loras = [(y[0], y[1], y[2]) if len(y) == 3 else (True, y[0], y[1]) for y in default_loras] default_loras = [(y[0], y[1], y[2]) if len(y) == 3 else (True, y[0], y[1]) for y in default_loras]
default_max_lora_number = get_config_item_or_set_default( default_max_lora_number = get_config_item_or_set_default(
key='default_max_lora_number', key='default_max_lora_number',
default_value=len(default_loras) if isinstance(default_loras, list) and len(default_loras) > 0 else 5, default_value=len(default_loras) if isinstance(default_loras, list) and len(default_loras) > 0 else 5,
validator=lambda x: isinstance(x, int) and x >= 1 validator=lambda x: isinstance(x, int) and x >= 1,
expected_type=int
) )
default_cfg_scale = get_config_item_or_set_default( default_cfg_scale = get_config_item_or_set_default(
key='default_cfg_scale', key='default_cfg_scale',
default_value=7.0, default_value=7.0,
validator=lambda x: isinstance(x, numbers.Number) validator=lambda x: isinstance(x, numbers.Number),
expected_type=numbers.Number
) )
default_sample_sharpness = get_config_item_or_set_default( default_sample_sharpness = get_config_item_or_set_default(
key='default_sample_sharpness', key='default_sample_sharpness',
default_value=2.0, default_value=2.0,
validator=lambda x: isinstance(x, numbers.Number) validator=lambda x: isinstance(x, numbers.Number),
expected_type=numbers.Number
) )
default_sampler = get_config_item_or_set_default( default_sampler = get_config_item_or_set_default(
key='default_sampler', key='default_sampler',
default_value='dpmpp_2m_sde_gpu', default_value='dpmpp_2m_sde_gpu',
validator=lambda x: x in modules.flags.sampler_list validator=lambda x: x in modules.flags.sampler_list,
expected_type=str
) )
default_scheduler = get_config_item_or_set_default( default_scheduler = get_config_item_or_set_default(
key='default_scheduler', key='default_scheduler',
default_value='karras', default_value='karras',
validator=lambda x: x in modules.flags.scheduler_list validator=lambda x: x in modules.flags.scheduler_list,
expected_type=str
) )
default_vae = get_config_item_or_set_default( default_vae = get_config_item_or_set_default(
key='default_vae', key='default_vae',
default_value=modules.flags.default_vae, default_value=modules.flags.default_vae,
validator=lambda x: isinstance(x, str) validator=lambda x: isinstance(x, str),
expected_type=str
) )
default_styles = get_config_item_or_set_default( default_styles = get_config_item_or_set_default(
key='default_styles', key='default_styles',
@ -360,121 +377,144 @@ default_styles = get_config_item_or_set_default(
"Fooocus Enhance", "Fooocus Enhance",
"Fooocus Sharp" "Fooocus Sharp"
], ],
validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x) validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x),
expected_type=list
) )
default_prompt_negative = get_config_item_or_set_default( default_prompt_negative = get_config_item_or_set_default(
key='default_prompt_negative', key='default_prompt_negative',
default_value='', default_value='',
validator=lambda x: isinstance(x, str), validator=lambda x: isinstance(x, str),
disable_empty_as_none=True disable_empty_as_none=True,
expected_type=str
) )
default_prompt = get_config_item_or_set_default( default_prompt = get_config_item_or_set_default(
key='default_prompt', key='default_prompt',
default_value='', default_value='',
validator=lambda x: isinstance(x, str), validator=lambda x: isinstance(x, str),
disable_empty_as_none=True disable_empty_as_none=True,
expected_type=str
) )
default_performance = get_config_item_or_set_default( default_performance = get_config_item_or_set_default(
key='default_performance', key='default_performance',
default_value=Performance.SPEED.value, default_value=Performance.SPEED.value,
validator=lambda x: x in Performance.list() validator=lambda x: x in Performance.list(),
expected_type=str
) )
default_advanced_checkbox = get_config_item_or_set_default( default_advanced_checkbox = get_config_item_or_set_default(
key='default_advanced_checkbox', key='default_advanced_checkbox',
default_value=False, default_value=False,
validator=lambda x: isinstance(x, bool) validator=lambda x: isinstance(x, bool),
expected_type=bool
) )
default_max_image_number = get_config_item_or_set_default( default_max_image_number = get_config_item_or_set_default(
key='default_max_image_number', key='default_max_image_number',
default_value=32, default_value=32,
validator=lambda x: isinstance(x, int) and x >= 1 validator=lambda x: isinstance(x, int) and x >= 1,
expected_type=int
) )
default_output_format = get_config_item_or_set_default( default_output_format = get_config_item_or_set_default(
key='default_output_format', key='default_output_format',
default_value='png', default_value='png',
validator=lambda x: x in OutputFormat.list() validator=lambda x: x in OutputFormat.list(),
expected_type=str
) )
default_image_number = get_config_item_or_set_default( default_image_number = get_config_item_or_set_default(
key='default_image_number', key='default_image_number',
default_value=2, default_value=2,
validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number,
expected_type=int
) )
checkpoint_downloads = get_config_item_or_set_default( checkpoint_downloads = get_config_item_or_set_default(
key='checkpoint_downloads', key='checkpoint_downloads',
default_value={}, default_value={},
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()),
expected_type=dict
) )
lora_downloads = get_config_item_or_set_default( lora_downloads = get_config_item_or_set_default(
key='lora_downloads', key='lora_downloads',
default_value={}, default_value={},
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()),
expected_type=dict
) )
embeddings_downloads = get_config_item_or_set_default( embeddings_downloads = get_config_item_or_set_default(
key='embeddings_downloads', key='embeddings_downloads',
default_value={}, default_value={},
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()),
expected_type=dict
) )
available_aspect_ratios = get_config_item_or_set_default( available_aspect_ratios = get_config_item_or_set_default(
key='available_aspect_ratios', key='available_aspect_ratios',
default_value=modules.flags.sdxl_aspect_ratios, default_value=modules.flags.sdxl_aspect_ratios,
validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1 validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1,
expected_type=list
) )
default_aspect_ratio = get_config_item_or_set_default( default_aspect_ratio = get_config_item_or_set_default(
key='default_aspect_ratio', key='default_aspect_ratio',
default_value='1152*896' if '1152*896' in available_aspect_ratios else available_aspect_ratios[0], default_value='1152*896' if '1152*896' in available_aspect_ratios else available_aspect_ratios[0],
validator=lambda x: x in available_aspect_ratios validator=lambda x: x in available_aspect_ratios,
expected_type=str
) )
default_inpaint_engine_version = get_config_item_or_set_default( default_inpaint_engine_version = get_config_item_or_set_default(
key='default_inpaint_engine_version', key='default_inpaint_engine_version',
default_value='v2.6', default_value='v2.6',
validator=lambda x: x in modules.flags.inpaint_engine_versions validator=lambda x: x in modules.flags.inpaint_engine_versions,
expected_type=str
) )
default_cfg_tsnr = get_config_item_or_set_default( default_cfg_tsnr = get_config_item_or_set_default(
key='default_cfg_tsnr', key='default_cfg_tsnr',
default_value=7.0, default_value=7.0,
validator=lambda x: isinstance(x, numbers.Number) validator=lambda x: isinstance(x, numbers.Number),
expected_type=numbers.Number
) )
default_clip_skip = get_config_item_or_set_default( default_clip_skip = get_config_item_or_set_default(
key='default_clip_skip', key='default_clip_skip',
default_value=2, default_value=2,
validator=lambda x: isinstance(x, int) and 1 <= x <= modules.flags.clip_skip_max validator=lambda x: isinstance(x, int) and 1 <= x <= modules.flags.clip_skip_max,
expected_type=int
) )
default_overwrite_step = get_config_item_or_set_default( default_overwrite_step = get_config_item_or_set_default(
key='default_overwrite_step', key='default_overwrite_step',
default_value=-1, default_value=-1,
validator=lambda x: isinstance(x, int) validator=lambda x: isinstance(x, int),
expected_type=int
) )
default_overwrite_switch = get_config_item_or_set_default( default_overwrite_switch = get_config_item_or_set_default(
key='default_overwrite_switch', key='default_overwrite_switch',
default_value=-1, default_value=-1,
validator=lambda x: isinstance(x, int) validator=lambda x: isinstance(x, int),
expected_type=int
) )
example_inpaint_prompts = get_config_item_or_set_default( example_inpaint_prompts = get_config_item_or_set_default(
key='example_inpaint_prompts', key='example_inpaint_prompts',
default_value=[ default_value=[
'highly detailed face', 'detailed girl face', 'detailed man face', 'detailed hand', 'beautiful eyes' 'highly detailed face', 'detailed girl face', 'detailed man face', 'detailed hand', 'beautiful eyes'
], ],
validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x) validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x),
expected_type=list
) )
default_black_out_nsfw = get_config_item_or_set_default( default_black_out_nsfw = get_config_item_or_set_default(
key='default_black_out_nsfw', key='default_black_out_nsfw',
default_value=False, default_value=False,
validator=lambda x: isinstance(x, bool) validator=lambda x: isinstance(x, bool),
expected_type=bool
) )
default_save_metadata_to_images = get_config_item_or_set_default( default_save_metadata_to_images = get_config_item_or_set_default(
key='default_save_metadata_to_images', key='default_save_metadata_to_images',
default_value=False, default_value=False,
validator=lambda x: isinstance(x, bool) validator=lambda x: isinstance(x, bool),
expected_type=bool
) )
default_metadata_scheme = get_config_item_or_set_default( default_metadata_scheme = get_config_item_or_set_default(
key='default_metadata_scheme', key='default_metadata_scheme',
default_value=MetadataScheme.FOOOCUS.value, default_value=MetadataScheme.FOOOCUS.value,
validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x] validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x],
expected_type=str
) )
metadata_created_by = get_config_item_or_set_default( metadata_created_by = get_config_item_or_set_default(
key='metadata_created_by', key='metadata_created_by',
default_value='', default_value='',
validator=lambda x: isinstance(x, str) validator=lambda x: isinstance(x, str),
expected_type=str
) )
example_inpaint_prompts = [[x] for x in example_inpaint_prompts] example_inpaint_prompts = [[x] for x in example_inpaint_prompts]

View File

@ -1,4 +1,6 @@
import os import os
from ast import literal_eval
def makedirs_with_log(path): def makedirs_with_log(path):
try: try:
@ -24,3 +26,16 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None):
filenames.append(path) filenames.append(path)
return filenames return filenames
def try_eval_env_var(value: str, expected_type=None):
try:
value_eval = value
if expected_type is bool:
value_eval = value.title()
value_eval = literal_eval(value_eval)
if expected_type is not None and not isinstance(value_eval, expected_type):
return value
return value_eval
except:
return value

74
tests/test_extra_utils.py Normal file
View File

@ -0,0 +1,74 @@
import numbers
import os
import unittest
import modules.flags
from modules import extra_utils
class TestUtils(unittest.TestCase):
def test_try_eval_env_var(self):
test_cases = [
{
"input": ("foo", str),
"output": "foo"
},
{
"input": ("1", int),
"output": 1
},
{
"input": ("1.0", float),
"output": 1.0
},
{
"input": ("1", numbers.Number),
"output": 1
},
{
"input": ("1.0", numbers.Number),
"output": 1.0
},
{
"input": ("true", bool),
"output": True
},
{
"input": ("True", bool),
"output": True
},
{
"input": ("false", bool),
"output": False
},
{
"input": ("False", bool),
"output": False
},
{
"input": ("True", str),
"output": "True"
},
{
"input": ("False", str),
"output": "False"
},
{
"input": ("['a', 'b', 'c']", list),
"output": ['a', 'b', 'c']
},
{
"input": ("{'a':1}", dict),
"output": {'a': 1}
},
{
"input": ("('foo', 1)", tuple),
"output": ('foo', 1)
}
]
for test in test_cases:
value, expected_type = test["input"]
expected = test["output"]
actual = extra_utils.try_eval_env_var(value, expected_type)
self.assertEqual(expected, actual)