pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration

Home Page: https://pytorch.org

torch.tensordot(-,-,0) no longer works

fritzo opened this issue

🐛 Bug

The check at
https://github.com/pytorch/pytorch/pull/53672/files#diff-5f3d4caa0693a716fc46fd7f6339312f1b5f0bf89e3a3ff58e9dc13a9486b17aR1038-R1039
introduced in #53672 breaks the simple case tensordot(torch.tensor(0.), torch.tensor(0.), 0), which worked in PyTorch 1.8 and still works in NumPy:

>>> np.tensordot(np.zeros(()), np.zeros(()), 0)
array(0.)

This important edge case dims=0 is needed by generic machinery as in Pyro and opt_einsum. Can we revert this check?
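For readers less familiar with the dims=0 case, here is a minimal NumPy sketch of what it computes: with no axes contracted, tensordot reduces to the outer (tensor) product, and the result has shape a.shape + b.shape. The array shapes and variable names are illustrative only and do not come from Pyro or opt_einsum.

```python
import numpy as np

a = np.arange(6.0).reshape(2, 3)
b = np.arange(4.0)

# axes=0 contracts no dimensions, so the result is the outer product.
out = np.tensordot(a, b, 0)

# Equivalent broadcasting form: give `a` one trailing singleton axis
# per axis of `b`, then multiply elementwise.
expected = a.reshape(a.shape + (1,) * b.ndim) * b

# Scalars are the degenerate case from the report: a 0-dimensional result.
scalar = np.tensordot(np.zeros(()), np.zeros(()), 0)
```

Generic contraction code tends to hit this case whenever two operands share no indices, which is presumably why it matters for einsum-style machinery.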

cc @neerajprad

To Reproduce

>>> import torch
>>> torch.tensordot(torch.zeros(()), torch.zeros(()), 0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/fobermey/opt/miniconda3/envs/pyro/lib/python3.7/site-packages/torch/functional.py", line 929, in tensordot
    raise RuntimeError(f"unsupported input to tensordot, got dims={dims}")
RuntimeError: unsupported input to tensordot, got dims=0

Expected behavior

Behave as in NumPy and PyTorch 1.8

Environment

PyTorch version: 1.9.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 10.15.7 (x86_64)
GCC version: Could not collect
Clang version: 12.0.0 (clang-1200.0.32.2)
CMake version: version 3.18.4
Libc version: N/A

Python version: 3.7.0 (default, Jun 28 2018, 07:39:16)  [Clang 4.0.1 (tags/RELEASE_401/final)] (64-bit runtime)
Python platform: Darwin-19.6.0-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip] gpytorch==1.5.0
[pip] numpy==1.19.4
[pip] numpyro==0.6.0
[pip] torch==1.9.0
[pip] torchfile==0.1.0
[pip] torchvision==0.10.0
[conda] gpytorch                  1.5.0                    pypi_0    pypi
[conda] numpy                     1.19.4                   pypi_0    pypi
[conda] numpyro                   0.6.0                     dev_0    <develop>
[conda] torch                     1.9.0                    pypi_0    pypi
[conda] torchfile                 0.1.0                    pypi_0    pypi
[conda] torchvision               0.10.0                   pypi_0    pypi

cc @gmagogsfm @eellison @neerajprad

@gmagogsfm I don't know much about TorchScript compatibility, but can you simply handle this case as follows?

diff --git a/torch/functional.py b/torch/functional.py
index acb32990d9..767caf7f76 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -948,6 +948,8 @@ def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None):  # noqa: F811
     if isinstance(dims, int):
         if dims < 0:
             raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
+        if dims == 0:
+            return a.reshape(a.size() + (1,) * b.dim()) * b
         dims_a = list(range(-dims, 0))
         dims_b = list(range(dims))
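
The special case proposed in this diff can also be tried without patching PyTorch. The sketch below wraps torch.tensordot in a user-side helper; the name tensordot_compat is hypothetical and is not part of PyTorch. It handles an integer dims of 0 itself and defers everything else to the library.

```python
import torch


def tensordot_compat(a, b, dims=2):
    """Hypothetical user-side wrapper, not a PyTorch API."""
    if isinstance(dims, int) and dims == 0:
        # No axes are contracted: append one singleton axis to `a` per
        # axis of `b` and rely on broadcasting to form the outer product.
        return a.reshape(tuple(a.shape) + (1,) * b.dim()) * b
    # All other cases go to the library implementation unchanged.
    return torch.tensordot(a, b, dims)


x = torch.arange(6.0).reshape(2, 3)
y = torch.arange(4.0)
outer = tensordot_compat(x, y, 0)
scalar = tensordot_compat(torch.zeros(()), torch.zeros(()), 0)
```

Whether the same branch is acceptable inside torch/functional.py depends on the scripting constraints discussed in #53672, which this sketch does not address.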

BC breaking

Thanks for reporting this issue, @fritzo; we should fix this ASAP.

cc @suo -- seems like this was due to some jit-related changes? Maybe we should follow up offline on how to fix this?

Also cc @heitorschueroff as an fyi for how we might want to improve test coverage for operations like tensordot
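
One possible shape for such coverage, offered only as a sketch and not as the actual PyTorch test suite, is to compare torch.tensordot against np.tensordot over a small grid that includes dims=0 and 0-dimensional inputs. The helper name and the chosen cases are assumptions. On an affected build such as 1.9.0, the dims=0 cases are expected to raise the RuntimeError from the report.

```python
import numpy as np
import torch


def tensordot_matches_numpy(shape_a, shape_b, dims):
    """Hypothetical check: does torch.tensordot agree with np.tensordot?"""
    a = torch.randn(shape_a, dtype=torch.float64)
    b = torch.randn(shape_b, dtype=torch.float64)
    actual = torch.tensordot(a, b, dims)
    expected = np.tensordot(a.numpy(), b.numpy(), dims)
    same_shape = tuple(actual.shape) == expected.shape
    return same_shape and np.allclose(actual.numpy(), expected)


# Illustrative cases: the scalar edge case from the report, a general
# outer product, and two ordinary contractions.
CASES = [
    ((), (), 0),
    ((2, 3), (4,), 0),
    ((2, 3), (3, 4), 1),
    ((2, 3, 4), (3, 4, 5), 2),
]

results = [tensordot_matches_numpy(sa, sb, d) for sa, sb, d in CASES]
```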

Thanks for addressing this quickly! 🎉