pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration

Home Page: https://pytorch.org

`torch.load` overwrites `pickle_module` parameter indiscriminately

lopho opened this issue

πŸ› Describe the bug

When weights_only is False (the default), the pickle_module argument of torch.load is always overridden with pickle, so a caller-supplied module is silently ignored.
Minimal example:

import torch
import pickle
class CustomUnpickler(pickle.Unpickler):
    def load(self):
        raise RuntimeError("custom unpickler was used")
class CustomPickle:
    Unpickler = CustomUnpickler
torch.save({'nothing': None}, 'test_pickle.pt')
torch.load('test_pickle.pt', pickle_module=CustomPickle)

This does not raise, but it should if the custom unpickler were actually being used.
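For comparison, here is a minimal, PyTorch-free sketch of the behavior the example expects. It assumes, as a simplification, that a loader builds its unpickler from `pickle_module.Unpickler`; `load_with` is a hypothetical stand-in for such a loader, not part of any real API. When the custom module is honored, the overridden `load` runs and raises.

```python
import io
import pickle


class CustomUnpickler(pickle.Unpickler):
    """Unpickler that refuses to load, so we can tell whether it was used."""

    def load(self):
        raise RuntimeError("CustomUnpickler.load was called")


class CustomPickle:
    """Minimal stand-in for a pickle module: it only exposes Unpickler."""

    Unpickler = CustomUnpickler


def load_with(data: bytes, pickle_module=pickle):
    # Hypothetical loader: honors whatever module the caller passes in,
    # which is what the report expects torch.load to do.
    unpickler = pickle_module.Unpickler(io.BytesIO(data))
    return unpickler.load()


payload = pickle.dumps({"nothing": None})

# Default module: loads normally and prints {'nothing': None}.
print(load_with(payload))

# Custom module: the overridden load() runs and raises.
try:
    load_with(payload, pickle_module=CustomPickle)
except RuntimeError as exc:
    print(exc)  # CustomUnpickler.load was called
```

If torch.load behaved like `load_with`, the reproduction above would raise instead of returning normally.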
Specifically, commit 895557d in PR #86812 adds the following code, which, whenever pickle_module is passed explicitly, either overwrites it with pickle or raises a RuntimeError:

if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
    weights_only = True
if weights_only:
    if pickle_module is not None:
        raise RuntimeError("Can not safely load weights when expiclit picke_module is specified")
else:
    pickle_module = pickle
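The problem is the unconditional assignment in the `else` branch. The standalone sketch below mirrors that logic as `resolve_buggy` and contrasts it with `resolve_fixed`, which falls back to pickle only when no module was passed (assuming, as the `is not None` check implies, that the parameter defaults to None). Both function names are hypothetical, and this is one straightforward way to restore the argument; the change actually merged in #88570 may differ in detail.

```python
import os
import pickle


def _force_weights_only() -> bool:
    # Same environment check as the quoted snippet.
    value = os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower()
    return value in ["1", "y", "yes", "true"]


def resolve_buggy(pickle_module=None, weights_only=False):
    """Mirrors the quoted logic: the else branch discards the caller's module."""
    if _force_weights_only():
        weights_only = True
    if weights_only:
        if pickle_module is not None:
            raise RuntimeError("cannot use an explicit pickle_module with weights_only")
    else:
        pickle_module = pickle  # overwrites a caller-supplied module
    return pickle_module, weights_only


def resolve_fixed(pickle_module=None, weights_only=False):
    """Falls back to pickle only when the caller did not pass a module."""
    if _force_weights_only():
        weights_only = True
    if weights_only:
        if pickle_module is not None:
            raise RuntimeError("cannot use an explicit pickle_module with weights_only")
    elif pickle_module is None:
        pickle_module = pickle
    return pickle_module, weights_only


if __name__ == "__main__":
    # Make sure the environment override is off for this demo.
    os.environ.pop("TORCH_FORCE_WEIGHTS_ONLY_LOAD", None)

    class CustomPickle:  # hypothetical stand-in for a custom pickle module
        pass

    print(resolve_buggy(CustomPickle)[0] is pickle)        # True: custom module lost
    print(resolve_fixed(CustomPickle)[0] is CustomPickle)  # True: custom module kept
```

Keeping the weights_only branch unchanged preserves the existing safety check: an explicit module is still rejected when weights-only loading is requested.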

Versions

Collecting environment information...
PyTorch version: 1.13.0+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.31

Python version: 3.9.2 (default, Feb 28 2021, 17:03:44) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform: Linux-5.10.0-18-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A6000
Nvidia driver version: 520.61.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.4
[pip3] torch==1.13.0+cu116
[pip3] torchaudio==0.13.0+cu116
[pip3] torchtext==0.14.0
[pip3] torchvision==0.14.0+cu116
[conda] Could not collect

cc @mruberry

@malfet Tagging you as you authored the linked commit!

commented

fixed by #88570