Straight-Through • ReinMax • How To Use • Examples • Citation • License
ReinMax achieves second-order accuracy and is as fast as the original Straight-Through, which has first-order accuracy.
Straight-Through (shown below) bridges discrete variables (y_hard) and back-propagation.
y_soft = theta.softmax()
# one_hot_multinomial is a non-differentiable function
y_hard = one_hot_multinomial(y_soft)
# with straight-through, the derivative of y_hard will
# act as if you had `y_soft` in the forward
y_hard = y_soft - y_soft.detach() + y_hard
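A minimal runnable sketch of this trick in PyTorch; here one_hot_multinomial is a hypothetical helper (not part of any library) implemented with torch.multinomial:

```python
import torch

def one_hot_multinomial(y_soft):
    # sample a category index, then one-hot encode it (non-differentiable)
    idx = torch.multinomial(y_soft, num_samples=1).squeeze(-1)
    return torch.nn.functional.one_hot(idx, y_soft.size(-1)).float()

theta = torch.randn(1, 4, requires_grad=True)
y_soft = theta.softmax(dim=-1)
y_hard = one_hot_multinomial(y_soft.detach())
# forward value is exactly y_hard; gradients flow through y_soft
y = y_soft - y_soft.detach() + y_hard
loss = (y * torch.arange(4.0)).sum()
loss.backward()
```

In the forward pass `y_soft - y_soft.detach()` is exactly zero, so `y` equals the one-hot sample; in the backward pass the detached terms drop out, so gradients reach `theta` through `y_soft`.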
It has long been a mystery how Straight-Through works, leaving doubts on many questions, e.g., whether we should use y_soft - y_soft.detach(), or (theta/tau).softmax() - (theta/tau).softmax().detach(), or something else.
We reveal that Straight-Through works as a special case of the forward Euler method, a numerical method with first-order accuracy. Inspired by Heun's method, a numerical method that achieves second-order accuracy without requiring the Hessian or other second-order derivatives, we propose ReinMax, which approximates gradients with second-order accuracy at negligible computation overhead.
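To make the accuracy gap concrete, here is a small self-contained comparison of forward Euler and Heun's method on the ODE dy/dt = y (an illustration of the numerical-methods analogy, not code from the reinmax package):

```python
import math

def euler(f, y0, t1, n):
    # forward Euler: one slope evaluation per step (first-order accurate)
    h = t1 / n
    y = y0
    for _ in range(n):
        y += h * f(y)
    return y

def heun(f, y0, t1, n):
    # Heun's method: average the slopes at the start and at a trial
    # endpoint (second-order accurate, no second derivatives needed)
    h = t1 / n
    y = y0
    for _ in range(n):
        k1 = f(y)
        k2 = f(y + h * k1)
        y += h * (k1 + k2) / 2
    return y

f = lambda y: y                     # dy/dt = y, exact solution y(t) = e^t
exact = math.e                      # y(1)
err_euler = abs(euler(f, 1.0, 1.0, 10) - exact)
err_heun = abs(heun(f, 1.0, 1.0, 10) - exact)
```

With only 10 steps, Heun's extra slope evaluation shrinks the error by roughly an order of magnitude, which is the kind of accuracy gain ReinMax seeks for gradient estimation.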
reinmax can be installed via pip:
pip install reinmax
To replace Straight-Through Gumbel-Softmax with ReinMax:
from reinmax import reinmax
...
- y_hard = torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True)
+ y_hard, _ = reinmax(logits, tau) # note that reinmax prefers to set tau >= 1, while gumbel-softmax prefers to set tau < 1
...
To replace Straight-Through with ReinMax:
from reinmax import reinmax
...
- y_hard = one_hot_multinomial(logits.softmax())
- y_soft_tau = (logits/tau).softmax()
- y_hard = y_soft_tau - y_soft_tau.detach() + y_hard
+ y_hard, y_soft = reinmax(logits, tau)
...
Following previous studies (Tucker et al., 2017; Grathwohl et al., 2018; Pervez et al., 2020; Paulus et al., 2021), let us start with a simple and classic problem: polynomial programming.
The implementation for this problem is available at the poly folder.
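For a flavor of this setup, here is a toy sketch (an assumption for illustration, not the code in the poly folder) that trains categorical logits with Straight-Through to minimize E[(x - 0.45)^2] for a binary variable x:

```python
import torch

torch.manual_seed(0)
theta = torch.zeros(2, requires_grad=True)   # logits over x in {0, 1}
opt = torch.optim.SGD([theta], lr=0.03)
values = torch.tensor([0.0, 1.0])

for _ in range(500):
    y_soft = theta.softmax(dim=-1)
    # sample x, one-hot encode it (non-differentiable)
    idx = torch.multinomial(y_soft.detach(), num_samples=1).squeeze(-1)
    y_hard = torch.nn.functional.one_hot(idx, 2).float()
    y = y_soft - y_soft.detach() + y_hard    # straight-through surrogate
    x = (y * values).sum()
    loss = (x - 0.45) ** 2
    opt.zero_grad()
    loss.backward()
    opt.step()

p_one = theta.softmax(dim=-1)[1].item()      # learned p(x = 1)
```

A quick calculation shows the bias this toy exposes: the expected Straight-Through update vanishes at p(x=1) = 0.45, even though the true optimum is p(x=1) = 0 (since (0 - 0.45)^2 < (1 - 0.45)^2). Reducing this kind of estimator bias is exactly what a second-order method like ReinMax targets.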
We also benchmarked the performance by training variational auto-encoders (VAE) with categorical latent variables on MNIST.
The implementation for MNIST-VAE is available at the mnist_vae folder.
For unsupervised parsing on ListOps, the implementation is available at the listops folder.
Please cite the following paper if you find our model useful. Thanks!
Liyuan Liu, Chengyu Dong, Xiaodong Liu, Bin Yu, and Jianfeng Gao (2023). Bridging Discrete and Backpropagation: Straight-Through and Beyond. ArXiv, abs/2304.08612.
@inproceedings{liu2023bridging,
title={Bridging Discrete and Backpropagation: Straight-Through and Beyond},
author = {Liu, Liyuan and Dong, Chengyu and Liu, Xiaodong and Yu, Bin and Gao, Jianfeng},
booktitle = {arXiv:2304.08612 [cs]},
year={2023}
}