fix(ace15): handle missing lm_metadata in memory estimation during checkpoint export #12669 (#12686)
parent
80d49441e5
commit
95e1059661
|
|
@ -328,14 +328,14 @@ class ACE15TEModel(torch.nn.Module):
|
|||
return getattr(self, self.lm_model).load_sd(sd)
|
||||
|
||||
def memory_estimation_function(self, token_weight_pairs, device=None):
    """Estimate the memory (in bytes) needed to encode the given prompt.

    Args:
        token_weight_pairs: dict with an optional "lm_prompt" entry (a list of
            token sequences) and an optional "lm_metadata" entry (a dict that
            may carry a "min_tokens" floor). Schema inferred from the keys read
            here — confirm against callers.
        device: target device, forwarded to the bf16 capability check.

    Returns:
        Estimated memory requirement in bytes:
        num_tokens * constant * 1024 * 1024.
    """
    # Use .get() so a payload without "lm_metadata" (e.g. during checkpoint
    # export) falls back to an empty dict instead of raising KeyError.
    lm_metadata = token_weight_pairs.get("lm_metadata", {})
    constant = self.constant
    if comfy.model_management.should_use_bf16(device):
        # bf16 uses half the bytes per element, so halve the per-token cost.
        constant *= 0.5

    token_weight_pairs = token_weight_pairs.get("lm_prompt", [])
    # Total tokens across all prompt sequences.
    num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
    # Metadata may impose a minimum token count; default to 0 when absent.
    num_tokens += lm_metadata.get("min_tokens", 0)
    return num_tokens * constant * 1024 * 1024
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
|
||||
|
|
|
|||
Loading…
Reference in New Issue