MorvanZhou / pytorch-A3C

Simple A3C implementation with pytorch + multiprocessing

Home Page:https://mofanpy.com

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

关于loss function的疑问

lizhihao6 opened this issue · comments

loss_function中的

exp_v = m.log_prob(a) * td.detach()

log_prob是[prob1, prob2, prob3]
td 是 [[value1, value2, value3]]
这样直接相乘得到的是一个二维矩阵,但是A2C里面不应该是对应步骤的A与对应的actor_loss相乘吗?
是否该改为

exp_v = m.log_prob(a) * td.detach()[0]

这是我更改loss函数后的reward,似乎具有更好的稳定性?

谢谢你指出问题,是的,td的确有维度问题,td的维度需要和a一样, 我根据这个修改了代码。

看到您修改的代码是

exp_v = m.log_prob(a) * td.detach().squeeze() 

想请教一下,使用squeeze和[0]有什么区别吗?

exp_v = m.log_prob(a) * td.detach()[0]

其实在这个例子中是没差别的。squeeze主要是将所有为1的维度去掉。比如(5,1,2) 变成(5,2)