fix(ace15): handle missing lm_metadata in memory estimation during checkpoint export #12669 (#12686)

pull/12464/head^2
fappaz 2026-02-28 19:18:40 +13:00 committed by GitHub
parent 80d49441e5
commit 95e1059661
1 changed file with 2 additions and 2 deletions

@@ -328,14 +328,14 @@ class ACE15TEModel(torch.nn.Module):
         return getattr(self, self.lm_model).load_sd(sd)

     def memory_estimation_function(self, token_weight_pairs, device=None):
-        lm_metadata = token_weight_pairs["lm_metadata"]
+        lm_metadata = token_weight_pairs.get("lm_metadata", {})
         constant = self.constant
         if comfy.model_management.should_use_bf16(device):
             constant *= 0.5
         token_weight_pairs = token_weight_pairs.get("lm_prompt", [])
         num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
-        num_tokens += lm_metadata['min_tokens']
+        num_tokens += lm_metadata.get("min_tokens", 0)
         return num_tokens * constant * 1024 * 1024


 def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
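
Why the change matters: during checkpoint export the prompt dictionary passed to memory_estimation_function can be built without an "lm_metadata" entry, so the old direct indexing raised a KeyError before any size estimate was returned. The sketch below is a minimal standalone reproduction of the failure mode and the fix, not the actual ComfyUI code path; estimate_memory, the constant value, and the sample inputs are all hypothetical stand-ins.

# Minimal sketch of the fix (hypothetical stand-in for
# ACE15TEModel.memory_estimation_function; names and constant are assumptions).
def estimate_memory(token_weight_pairs, constant=4.0):
    # Fix: fall back to an empty dict when "lm_metadata" is absent,
    # instead of indexing directly (which raised KeyError during export).
    lm_metadata = token_weight_pairs.get("lm_metadata", {})  # was: token_weight_pairs["lm_metadata"]
    prompt = token_weight_pairs.get("lm_prompt", [])
    num_tokens = sum(len(a) for a in prompt)
    # Fix: default min_tokens to 0 when the metadata dict is empty.
    num_tokens += lm_metadata.get("min_tokens", 0)  # was: lm_metadata['min_tokens']
    return num_tokens * constant * 1024 * 1024

# Normal inference: both keys present, estimate as before.
print(estimate_memory({"lm_prompt": [[1] * 77], "lm_metadata": {"min_tokens": 32}}))

# Checkpoint export: metadata absent; with the fix this returns 0 instead of raising.
print(estimate_memory({"lm_prompt": []}))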