⚗️ pytorch.grad
Yonv1943 opened this issue · comments
YonV1943 曾伊言 commented
Get the gradient of network parameters:
import torch
class Net(torch.nn.Module):
    """Minimal network wrapping a single fully connected layer.

    Args:
        inp_dim: Size of the last dimension of the input tensor.
        out_dim: Size of the last dimension of the output tensor.
    """

    def __init__(self, inp_dim: int = 4, out_dim: int = 2):
        super().__init__()
        # The layer's weight and bias are the only trainable parameters.
        self.net = torch.nn.Linear(inp_dim, out_dim)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``inp`` and return the result."""
        return self.net(inp)
def run():
    """Print the gradient of every ``Net`` parameter after one backward pass.

    Builds a ``Net``, feeds it an all-ones batch, back-propagates
    ``out.sum()`` and prints each parameter's shape next to its ``.grad``.

    The input is all ones and the back-propagated scalar is the plain sum of
    the outputs, so every gradient entry equals ``batch_size``. Expected
    output (tensor line wrapping may differ):

        torch.Size([2, 4]) tensor([[3., 3., 3., 3.], [3., 3., 3., 3.]])
        torch.Size([2]) tensor([3., 3.])
    """
    batch_size = 3
    inp_dim = 4
    out_dim = 2

    net = Net(inp_dim, out_dim)

    inp = torch.ones((batch_size, inp_dim))
    out = net(inp)
    assert out.shape == (batch_size, out_dim)

    # Per-sample L1 objective against an all-ones label. It is only
    # shape-checked here; the backward pass below is driven by `out.sum()`,
    # not by `obj`, which is what yields the gradients shown in the docstring.
    lab = torch.ones_like(out)
    obj = torch.abs(out - lab).mean(dim=1)
    assert obj.shape == (batch_size,)

    # In a training loop, `optimizer.zero_grad()` would go before this call
    # and `optimizer.step()` after it; no optimizer is used in this demo.
    out.sum().backward()

    for param in net.parameters():
        print(param.shape, param.grad)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    run()