ModelPatcherDynamic: force cast stray weights on comfy layers (#13487)

The mixed_precision ops can have input_scale parameters that are used
in tensor math but aren't a weight or bias, so they don't get proper VRAM
management. Treat these as force-castable parameters, as the non-comfy
weights, random params and buffers already are.
pull/13471/merge
rattus 2026-04-23 08:13:38 +10:00 committed by GitHub
parent cb388e2912
commit ec4b1659ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 11 additions and 6 deletions

View File

@ -685,9 +685,9 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False, force_cast=False):
weight, set_func, convert_func = get_key_weight(self.model, key)
if key not in self.patches:
if key not in self.patches and not force_cast:
return weight
inplace_update = self.weight_inplace_update or inplace_update
@ -695,7 +695,7 @@ class ModelPatcher:
if key not in self.backup and not return_weight:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if key in self.patches else None
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
else:
@ -703,9 +703,10 @@ class ModelPatcher:
if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True)
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if key in self.patches else temp_weight
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
if key in self.patches:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
if return_weight:
return out_weight
elif inplace_update:
@ -1584,7 +1585,7 @@ class ModelPatcherDynamic(ModelPatcher):
key = key_param_name_to_key(n, param_key)
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to)
self.patch_weight_to_device(key, device_to=device_to, force_cast=True)
weight, _, _ = get_key_weight(self.model, key)
if weight is not None:
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
@ -1609,6 +1610,10 @@ class ModelPatcherDynamic(ModelPatcher):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
for param in params:
if param not in ("weight", "bias"):
force_load_param(self, param, device_to)
else:
for param in params:
key = key_param_name_to_key(n, param)