fix(security): Update the secure loading process of ModelCheckpoint

pull/17125/head
hujiayucc 2025-09-12 19:17:23 +08:00
parent 82a973c043
commit fd0893a166
1 changed files with 16 additions and 0 deletions

View File

@ -57,6 +57,9 @@ class RestrictedUnpickler(pickle.Unpickler):
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
import pytorch_lightning.callbacks.model_checkpoint
return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
if module == "pytorch_lightning.callbacks" and name == 'ModelCheckpoint':
import pytorch_lightning.callbacks
return pytorch_lightning.callbacks.ModelCheckpoint
if module == "__builtin__" and name == 'set':
return set
@ -153,6 +156,19 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
)
return None
# Add security global variable handling
try:
# Try importing the ModelCheckpoint from PyTorch Lightning
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
# Add security global variables
torch.serialization.add_safe_globals([ModelCheckpoint])
except ImportError:
# If the import fails, use a string representation
torch.serialization.add_safe_globals(['pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'])
except AttributeError:
# If the PyTorch version does not support add_safe_globals, ignore the error
pass
return unsafe_torch_load(filename, *args, **kwargs)