LaurentMazare / tch-rs

Rust bindings for the C++ API of PyTorch.

concat doesn't support gradient

michael8090 opened this issue

In PyTorch, we can get the gradient back out of a concatenated tensor:

>>> a = torch.tensor([1.0], requires_grad=True)
>>> b = torch.tensor([1.0], requires_grad=True)
>>> c = torch.cat((a, b))
>>> a.grad
>>> c.sum().backward()
>>> a.grad
tensor([1.])

But with tch-rs, it reports that a is not a leaf tensor and has no gradient:

let mut a = Tensor::from_slice(&[1.0]).reshape([1, 1]).requires_grad_(true);
let mut a = Tensor::concat(
    &[a, Tensor::from_slice(&[1.0]).reshape([1, 1]).requires_grad_(true)],
    1,
)
.requires_grad_(true);
let b: Tensor = a.sum(Kind::Float);
a.zero_grad();
b.backward();
println!("========={}=========", a.grad());

Output:

[W TensorBody.h:494] Warning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (function grad)
[W TensorBody.h:494] Warning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (function grad)
=========Tensor[Undefined]=========

Oh, I made a mistake. Please ignore the issue.
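
For anyone who finds this issue later: gradients do flow through Tensor::concat. The snippet above reads .grad() from the concatenated tensor because the second let mut a shadows the original leaf tensor, so a ends up referring to a non-leaf result. A minimal sketch of a working equivalent of the Python example, assuming a reasonably recent tch release (the exact sum/concat signatures may differ slightly between versions):

use tch::{Kind, Tensor};

fn main() {
    // Leaf tensors, mirroring the Python example.
    let a = Tensor::from_slice(&[1.0f32]).requires_grad_(true);
    let b = Tensor::from_slice(&[1.0f32]).requires_grad_(true);

    // Keep the concatenation in its own variable; c is a non-leaf tensor.
    let c = Tensor::concat(&[&a, &b], 0);

    // Backpropagate from the sum, then read the gradient of the leaf a,
    // not of the non-leaf c.
    c.sum(Kind::Float).backward();
    println!("{}", a.grad()); // expected to print a tensor containing 1
}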