torch.tensordot(-,-,0) no longer works
fritzo opened this issue
🐛 Bug
This check (https://github.com/pytorch/pytorch/pull/53672/files#diff-5f3d4caa0693a716fc46fd7f6339312f1b5f0bf89e3a3ff58e9dc13a9486b17aR1038-R1039), introduced in #53672, breaks the simple case tensordot(torch.tensor(0.), torch.tensor(0.), 0), which worked fine in PyTorch 1.8 and works fine in NumPy:
>>> np.tensordot(np.zeros(()), np.zeros(()), 0)
array(0.)
This edge case, dims=0, is important: generic machinery such as Pyro and opt_einsum relies on it. Can we revert this check?
cc @neerajprad
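For context on why dims=0 matters to generic code: contracting zero axes turns tensordot into an outer product, so the result has shape a.shape + b.shape, and the 0-d case is just scalar multiplication. A small NumPy check of that reading (the arrays a and b are arbitrary illustrative values, not taken from the report):

```python
import numpy as np

# With dims=0, tensordot contracts no axes, so the result is the outer
# product: result.shape == a.shape + b.shape and
# result[i..., j...] == a[i...] * b[j...].
a = np.arange(6.0).reshape(2, 3)
b = np.array([10.0, 20.0])

out = np.tensordot(a, b, 0)
print(out.shape)  # (2, 3, 2)
assert np.array_equal(out, np.multiply.outer(a, b))

# The scalar case from the report: two 0-d arrays give a 0-d result.
s = np.tensordot(np.zeros(()), np.zeros(()), 0)
print(s.shape, s)  # () 0.0
```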
To Reproduce
>>> import torch
>>> torch.tensordot(torch.zeros(()), torch.zeros(()), 0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/fobermey/opt/miniconda3/envs/pyro/lib/python3.7/site-packages/torch/functional.py", line 929, in tensordot
    raise RuntimeError(f"unsupported input to tensordot, got dims={dims}")
RuntimeError: unsupported input to tensordot, got dims=0
Expected behavior
Behave as in NumPy and PyTorch 1.8
Environment
PyTorch version: 1.9.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 10.15.7 (x86_64)
GCC version: Could not collect
Clang version: 12.0.0 (clang-1200.0.32.2)
CMake version: version 3.18.4
Libc version: N/A
Python version: 3.7.0 (default, Jun 28 2018, 07:39:16) [Clang 4.0.1 (tags/RELEASE_401/final)] (64-bit runtime)
Python platform: Darwin-19.6.0-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip] gpytorch==1.5.0
[pip] numpy==1.19.4
[pip] numpyro==0.6.0
[pip] torch==1.9.0
[pip] torchfile==0.1.0
[pip] torchvision==0.10.0
[conda] gpytorch 1.5.0 pypi_0 pypi
[conda] numpy 1.19.4 pypi_0 pypi
[conda] numpyro 0.6.0 dev_0 <develop>
[conda] torch 1.9.0 pypi_0 pypi
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchvision 0.10.0 pypi_0 pypi
@gmagogsfm I don't know much about TorchScript compatibility, but could you simply handle this case as follows?
diff --git a/torch/functional.py b/torch/functional.py
index acb32990d9..767caf7f76 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -948,6 +948,8 @@ def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None): # noqa: F811
     if isinstance(dims, int):
         if dims < 0:
             raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
+        if dims == 0:
+            return a.reshape(a.size() + (1,) * b.dim()) * b
         dims_a = list(range(-dims, 0))
         dims_b = list(range(dims))
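The reshape-and-broadcast idea in this patch can also be used as a user-side workaround. A minimal sketch, assuming only that the inputs expose .shape, .ndim, .reshape and elementwise *, which NumPy arrays and torch.Tensor both provide; tensordot_dims0 is a hypothetical helper name, not a PyTorch API, and it is checked here against np.tensordot rather than torch:

```python
import numpy as np


def tensordot_dims0(a, b):
    """Hypothetical helper: outer product of a and b, i.e. tensordot(a, b, 0).

    Reshape ``a`` to ``a.shape + (1,) * b.ndim`` so that broadcasting
    against ``b`` (aligned on trailing axes) pairs every element of ``a``
    with every element of ``b``, giving shape ``a.shape + b.shape``.
    """
    return a.reshape(tuple(a.shape) + (1,) * b.ndim) * b


# Compare against NumPy's tensordot for several shapes, including 0-d inputs.
rng = np.random.default_rng(0)
for shape_a, shape_b in [((), ()), ((2, 3), (4,)), ((), (5,)), ((2,), ())]:
    a = np.asarray(rng.standard_normal(shape_a))
    b = np.asarray(rng.standard_normal(shape_b))
    expected = np.tensordot(a, b, 0)
    result = tensordot_dims0(a, b)
    assert np.shape(result) == expected.shape
    assert np.allclose(result, expected)
```

A caller on PyTorch 1.9 could route dims=0 to a helper like this and everything else to torch.tensordot; that dispatch is an assumption about how one might wire it in, not something proposed in the thread.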
BC breaking
Thanks for reporting this issue, @fritzo; we should fix this ASAP.
cc @suo -- it seems this was due to some JIT-related changes? Maybe we should follow up offline on how to fix this?
Also cc @heitorschueroff as an FYI on how we might improve test coverage for operations like tensordot.
Thanks for addressing this quickly!