Fix OOM regression in _apply() for quantized models during inference (#13372)

Skip unnecessary clone of inference-mode tensors when already inside
torch.inference_mode(), matching the existing guard in set_attr_param.
The unconditional clone introduced in 20561aa9 caused transient VRAM
doubling during model movement for FP8/quantized models.
Jun Yamog 2026-04-15 21:10:36 +12:00 committed by GitHub
parent 8f374716ee
commit 1de83f91c3
1 changed file with 1 addition and 1 deletion


@@ -1151,7 +1151,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             if param is None:
                 continue
             p = fn(param)
-            if p.is_inference():
+            if (not torch.is_inference_mode_enabled()) and p.is_inference():
                 p = p.clone()
             self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
         for key, buf in self._buffers.items():
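The rationale for the guard can be seen directly in PyTorch's inference-tensor semantics (a minimal sketch, independent of the repo's `_apply()` code): inside `torch.inference_mode()`, cloning an inference tensor just produces another inference tensor, so the unconditional clone doubled memory without accomplishing anything; outside inference mode, the clone is the standard escape hatch that turns an inference tensor into a normal one so it can be registered as a `Parameter`.

```python
import torch

# Inside inference_mode: clone() still yields an inference tensor,
# so the extra copy only costs memory.
with torch.inference_mode():
    t = torch.ones(4)
    assert torch.is_inference_mode_enabled()
    assert t.is_inference()
    c = t.clone()
    assert c.is_inference()  # clone did not escape inference mode

# Outside inference_mode: clone() produces a normal tensor, which is
# why the clone is still needed in that case before register_parameter.
assert not torch.is_inference_mode_enabled()
p = t.clone()
assert not p.is_inference()
param = torch.nn.Parameter(p, requires_grad=False)  # now legal
```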