LaurentMazare / tch-rs

Rust bindings for the C++ api of PyTorch.

grads become zeros after a short period of training on metal backend

michael8090 opened this issue

code: https://github.com/michael8090/test-candle-metal/tree/tch-rs-grads

At line 123:

        // ISSUE HERE: use the line below with metal to see zero gradients...
        let mut pred_loss = Tensor::zeros(&[1], (Kind::Float, device));
        // With the line below, the issue goes away
        // let mut pred_loss;

When the variable is initialized with a Tensor, the grads become zeros after about 5 iterations. But when I leave it uninitialized in the first place, the issue goes away.
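To make the comparison concrete, here is a minimal sketch of the two declaration patterns in plain Rust. It rests on assumptions: `f32` stands in for `tch::Tensor` so it compiles without libtorch, `step_loss` is a hypothetical per-step loss, and the loop is assumed to overwrite `pred_loss` on every step (the linked code may differ). Under those assumptions both patterns yield the same value, which suggests the difference seen on metal would come from the extra placeholder tensor rather than from the loop logic.

```rust
// Stand-in sketch: f32 replaces tch::Tensor so this builds without libtorch.
// `step_loss` is a hypothetical per-step loss, not taken from the linked repo.
fn step_loss(step: u32) -> f32 {
    1.0 / (step as f32 + 1.0)
}

/// Pattern A: the variable starts with a placeholder value, mirroring
/// `Tensor::zeros(&[1], (Kind::Float, device))` from the report.
fn last_loss_preinitialized(steps: u32) -> f32 {
    let mut pred_loss = 0.0_f32; // placeholder, overwritten inside the loop
    for step in 0..steps {
        pred_loss = step_loss(step);
    }
    pred_loss
}

/// Pattern B: deferred initialization, mirroring `let mut pred_loss;`.
/// Rust requires an assignment before the first read, so at least one
/// step must run.
fn last_loss_deferred(steps: u32) -> f32 {
    assert!(steps > 0, "deferred initialization needs at least one step");
    let mut pred_loss;
    pred_loss = step_loss(0);
    for step in 1..steps {
        pred_loss = step_loss(step);
    }
    pred_loss
}

fn main() {
    let steps = 5;
    println!("pre-initialized: {}", last_loss_preinitialized(steps));
    println!("deferred:        {}", last_loss_deferred(steps));
}
```

If the real loop accumulates into `pred_loss` instead of overwriting it, the placeholder tensor becomes part of the computation and this equivalence no longer applies.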

When you run the code, it prints the grads about 40 times, so you can see how they change.