fix: allow model config override (#331)
parent
ecad77ed65
commit
f94b453d10
|
|
@ -430,10 +430,15 @@ class Script(scripts.Script):
|
||||||
state_dict = load_state_dict(model_path)
|
state_dict = load_state_dict(model_path)
|
||||||
network_module = PlugableControlModel
|
network_module = PlugableControlModel
|
||||||
network_config = shared.opts.data.get("control_net_model_config", default_conf)
|
network_config = shared.opts.data.get("control_net_model_config", default_conf)
|
||||||
|
|
||||||
if any([k.startswith("body.") for k, v in state_dict.items()]):
|
if any([k.startswith("body.") for k, v in state_dict.items()]):
|
||||||
# adapter model
|
# adapter model
|
||||||
network_module = PlugableAdapter
|
network_module = PlugableAdapter
|
||||||
network_config = shared.opts.data.get("control_net_model_adapter_config", default_conf_adapter)
|
network_config = shared.opts.data.get("control_net_model_adapter_config", default_conf_adapter)
|
||||||
|
|
||||||
|
override_config = os.path.splitext(model_path)[0] + ".yaml"
|
||||||
|
if os.path.exists(override_config):
|
||||||
|
network_config = override_config
|
||||||
|
|
||||||
network = network_module(
|
network = network_module(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue