danijar / dreamerv2

Mastering Atari with Discrete World Models

Home Page: https://danijar.com/dreamerv2


Straight-thru gradients vs Gumbel Softmax

zplizzi opened this issue

I'm curious whether you considered the Gumbel-Softmax as an alternative to the way you implemented straight-through gradients in this paper/code. It seems like it might be a less biased way of backpropagating through the operation of sampling from a categorical distribution. The "hard" variant retains a purely discrete one-hot output in the forward pass, as you did here.

As I understand it, you implemented:

  • forward: one_hot(draw(logits))
  • backward: softmax(logits, temp=1)
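
For concreteness, here's a minimal sketch of that straight-through estimator in PyTorch (my own framework-agnostic rendering of the idea, not the actual TensorFlow code in this repo):

```python
import torch
import torch.nn.functional as F

def straight_through_sample(logits):
    # Forward: draw a discrete one-hot sample from the categorical distribution.
    probs = F.softmax(logits, dim=-1)
    one_hot = torch.distributions.OneHotCategorical(probs=probs).sample()
    # Backward: (one_hot - probs) is detached, so the forward value stays
    # one-hot while gradients flow as if we had returned softmax(logits).
    return probs + (one_hot - probs).detach()
```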

And the (hard version of the) gumbel softmax is:

  • forward: one_hot(arg_max(log_softmax(logits) + sample_from_gumbel_dist))
  • backward: softmax(log_softmax(logits) + sample_from_gumbel_dist, temp=temp_hyperparam)
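
A matching sketch of the hard Gumbel-Softmax (same caveat as above; `temp` is the temperature hyperparameter):

```python
def gumbel_softmax_hard(logits, temp=1.0):
    # Sample standard Gumbel noise via inverse transform sampling.
    u = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)
    # Gumbel-max trick: the argmax of the perturbed log-probs is an exact
    # categorical sample, so the forward pass matches the version above.
    y = F.log_softmax(logits, dim=-1) + gumbel
    soft = F.softmax(y / temp, dim=-1)  # relaxed sample, used for gradients
    hard = F.one_hot(soft.argmax(dim=-1), logits.shape[-1]).float()
    return soft + (hard - soft).detach()
```

(PyTorch ships this as `torch.nn.functional.gumbel_softmax(logits, tau=temp, hard=True)`.)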

The forward passes in both versions are equivalent in distribution: the second is just a reparameterization of the first (the Gumbel-max trick). By adjusting the temperature hyperparameter, you can trade off bias against variance in the gradients.

Hey, I tried the Gumbel-Softmax back then but found that it required careful temperature tuning, and even temperature annealing, to match the performance of straight-through gradients, which are simpler and free of hyperparameters. It's possible that Gumbel-Softmax can be made to work better, though; I'm not sure.
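
For anyone curious what such annealing looks like, a typical schedule (hypothetical values, not the ones tried here) decays the temperature exponentially over training steps:

```python
def annealed_temp(step, start=1.0, end=0.1, decay_steps=100_000):
    # Hypothetical exponential decay from `start` to `end` over `decay_steps`.
    frac = min(step / decay_steps, 1.0)
    return start * (end / start) ** frac
```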