fix(security): Update the secure loading process of ModelCheckpoint
parent
82a973c043
commit
fd0893a166
|
|
@ -57,6 +57,9 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|||
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
|
||||
import pytorch_lightning.callbacks.model_checkpoint
|
||||
return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
|
||||
if module == "pytorch_lightning.callbacks" and name == 'ModelCheckpoint':
|
||||
import pytorch_lightning.callbacks
|
||||
return pytorch_lightning.callbacks.ModelCheckpoint
|
||||
if module == "__builtin__" and name == 'set':
|
||||
return set
|
||||
|
||||
|
|
@ -153,6 +156,19 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
|||
)
|
||||
return None
|
||||
|
||||
# Add security global variable handling
|
||||
try:
|
||||
# Try importing the ModelCheckpoint from PyTorch Lightning
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
# Add security global variables
|
||||
torch.serialization.add_safe_globals([ModelCheckpoint])
|
||||
except ImportError:
|
||||
# If the import fails, use a string representation
|
||||
torch.serialization.add_safe_globals(['pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'])
|
||||
except AttributeError:
|
||||
# If the PyTorch version does not support add_safe_globals, ignore the error
|
||||
pass
|
||||
|
||||
return unsafe_torch_load(filename, *args, **kwargs)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue