`torch.load` overwrites `pickle_module` parameter indiscriminately
lopho opened this issue
🐛 Describe the bug
The `pickle_module` parameter of `torch.load` is always overridden with `pickle`.
Minimal example:
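import pickle

import torch

class CustomUnpickler(pickle.Unpickler):
    def load(self):
        # Raise unconditionally so that any use of this unpickler is visible.
        raise RuntimeError("CustomUnpickler.load was called")

class CustomPickle:
    Unpickler = CustomUnpickler

torch.save({'nothing', None}, 'test_pickle.pt')
# Should raise if pickle_module is honored.
torch.load('test_pickle.pt', pickle_module=CustomPickle)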
This does not raise, but it would have to if the custom unpickler were actually being used.
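To make the expectation concrete without depending on PyTorch, here is a minimal sketch of a loader that honors its `pickle_module` argument. The `load_with` helper and the `RaisingPickle` class are hypothetical names used only for illustration; this is not how `torch.load` is implemented.

```python
import io
import pickle


class RaisingUnpickler(pickle.Unpickler):
    def load(self):
        # Raise unconditionally so that any use of this unpickler is visible.
        raise RuntimeError("custom unpickler was used")


class RaisingPickle:
    # Stand-in for a pickle-like module; only the attribute used below is defined.
    Unpickler = RaisingUnpickler


def load_with(pickle_module, data):
    """Hypothetical loader that honors pickle_module (not torch.load itself)."""
    return pickle_module.Unpickler(io.BytesIO(data)).load()


payload = pickle.dumps({"nothing": None})

# The standard module round-trips the payload.
roundtrip = load_with(pickle, payload)

# A loader that respects the argument surfaces the custom unpickler's error.
try:
    load_with(RaisingPickle, payload)
except RuntimeError:
    custom_unpickler_used = True
else:
    custom_unpickler_used = False
```

If `torch.load` passed the argument through in the same way, the minimal example would fail with the custom unpickler's error instead of loading silently.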
Specifically, commit 895557d in PR #86812 adds the following, which logically always either overwrites `pickle_module` with `pickle` or raises a `RuntimeError`.
pytorch/torch/serialization.py, lines 762 to 769 in 2bda2ba
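The difference between the reported behavior and the expected one can be sketched as two small resolver functions. Both `resolve_reported` and `resolve_expected` are hypothetical helpers and a simplified model, not the PyTorch source. The `RuntimeError` branch is left out because this report does not state the condition that triggers it.

```python
import pickle


class CustomPickle:
    # Placeholder for a user-supplied pickle-like module.
    Unpickler = pickle.Unpickler


def resolve_reported(pickle_module):
    """Model of the reported non-raising path: the caller's argument is discarded."""
    return pickle


def resolve_expected(pickle_module):
    """Assumed intended behavior: fall back to pickle only when nothing was passed."""
    return pickle if pickle_module is None else pickle_module


reported = resolve_reported(CustomPickle)   # the standard pickle module
expected = resolve_expected(CustomPickle)   # the caller's CustomPickle
default = resolve_expected(None)            # the standard pickle module
```

Defaulting only when the argument is `None` keeps the standard behavior for callers who pass nothing, while still letting an explicit `pickle_module` take effect.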
Versions
Collecting environment information...
PyTorch version: 1.13.0+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.31
Python version: 3.9.2 (default, Feb 28 2021, 17:03:44) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform: Linux-5.10.0-18-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A6000
Nvidia driver version: 520.61.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.4
[pip3] torch==1.13.0+cu116
[pip3] torchaudio==0.13.0+cu116
[pip3] torchtext==0.14.0
[pip3] torchvision==0.14.0+cu116
[conda] Could not collect
cc @mruberry
@malfet Tagging you as you authored the linked commit!