hako-mikan 2025-01-19 18:49:52 +09:00
parent 5e9f00b6d1
commit dca1b990c0
1 changed files with 27 additions and 6 deletions

View File

@ -51,8 +51,8 @@ try:
except:
forge = False
revert_target = ""
orig_cache = 0
modelcache = collections.OrderedDict()
from inspect import currentframe
@ -170,6 +170,7 @@ def fake_checkpoint_info(checkpoint_info,metadata={},currentmodel=""):
sha256 = hashlib.sha256(json.dumps(metadata).encode("utf-8")).hexdigest()
checkpoint_info.sha256 = sha256
checkpoint_info.name_for_extra = currentmodel
checkpoint_info.isfake = True
checkpoint_info.name = checkpoint_info.name_for_extra + ".safetensors"
checkpoint_info.model_name = checkpoint_info.name_for_extra.replace("/", "_").replace("\\", "_")
@ -211,9 +212,9 @@ statistics = {"sum":{},"mean":{},"max":{},"min":{}}
def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,calcmode,
useblocks,custom_name,save_sets,id_sets,wpresets,deep,fine,bake_in_vae,opt_value,inex,ex_blocks,ex_elems,deepprint,lucks,main = [False,False,False]):
caster("merge start",hearm)
global hear,mergedmodel,stopmerge,statistics
global hear,mergedmodel,stopmerge,statistics, revert_target
stopmerge = False
debug = "debug" in save_sets
@ -221,6 +222,8 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
device = "cuda" if "use cuda" in save_sets else "cpu"
if forge:
fcinfo = sd_models.get_closet_checkpoint_match(shared.opts.sd_model_checkpoint)
revert_target = revert_target if hasattr(fcinfo, "isfake") else fcinfo
unload_forge()
else:
unload_model_weights(sd_models.model_data.sd_model)
@ -307,8 +310,8 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
caster("model load start",hearm)
printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks['ceed'],fine,inex,ex_blocks,ex_elems,device)
theta_1=load_model_weights_m(model_b,2,cachetarget,device).copy()
theta_1 = load_model_weights_m(model_b,2,cachetarget,device).copy()
isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in theta_1.keys()
#adjust
@ -1534,10 +1537,26 @@ def cachedealer(start):
else:
shared.opts.sd_checkpoint_cache = orig_cache
def clearcache():
def clearcache(model_c):
global modelcache
del modelcache
modelcache = {}
if forge:
unload_forge()
unload_forge()
from modules.sd_models import forge_model_reload, model_data
from modules_forge.main_entry import forge_unet_storage_dtype_options
unet_storage_dtype, _ = forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, (None, False))
forge_model_params = dict(
checkpoint_info=revert_target,
additional_modules=shared.opts.forge_additional_modules,
unet_storage_dtype=unet_storage_dtype
)
model_data.forge_hash = None
model_data.forge_loading_parameters = forge_model_params
forge_model_reload()
gc.collect()
devices.torch_gc()
@ -1668,6 +1687,8 @@ COMP_NAME_AND_PREFIX = {"transformer":PREFIX_M, "text_encoder": "clip_l" , "text
def forge_loader(state_dict, additional_state_dicts):
state_dicts, estimated_config = split_state_dict(state_dict, additional_state_dicts)
state_dict = None
del state_dict
repo_name = estimated_config.huggingface_repo