LiyuanLucasLiu / ReinMax

Beyond Straight-Through

Home Page: https://arxiv.org/abs/2304.08612

ReinMax

ReinMax achieves second-order accuracy and is as fast as the original Straight-Through, which has first-order accuracy.

What is Straight-Through?

Straight-Through (shown below) bridges discrete variables (y_hard) and back-propagation.

y_soft = theta.softmax(dim=-1)

# one_hot_multinomial is a non-differentiable function
y_hard = one_hot_multinomial(y_soft)

# with straight-through, the derivative of y_hard will
# act as if you had `y_soft` in the forward
y_hard = y_soft - y_soft.detach() + y_hard
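To make the backward pass of the snippet above concrete, here is a minimal NumPy sketch (not the reinmax implementation). It assumes a single categorical variable with logits theta and writes out the gradient Straight-Through hands to theta: the upstream gradient at y_hard multiplied by the softmax Jacobian. The losses, target, and c are made up for illustration. For a loss that is linear in y, this equals the exact gradient of the expected loss; for nonlinear losses it is generally biased.

```python
import numpy as np

def softmax(theta):
    z = theta - theta.max(axis=-1, keepdims=True)  # stabilize exp
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def softmax_jacobian(theta):
    # J[i, j] = d softmax(theta)_i / d theta_j = p_i * (delta_ij - p_j)
    p = softmax(theta)
    return np.diag(p) - np.outer(p, p)

def straight_through_grad(theta, grad_at_y_hard):
    # Backward of y_soft - y_soft.detach() + y_hard: the gradient arriving at
    # y_hard is passed through the softmax Jacobian, as if y_soft were the output.
    return softmax_jacobian(theta).T @ grad_at_y_hard

rng = np.random.default_rng(0)
theta = np.array([0.5, -1.0, 2.0])
p = softmax(theta)
y_hard = np.eye(len(p))[rng.choice(len(p), p=p)]  # one-hot sample, no gradient

# Nonlinear loss f(y) = ||y - target||^2: its gradient is evaluated at the sample.
target = np.array([0.2, 0.3, 0.5])
g_st_sample = straight_through_grad(theta, 2.0 * (y_hard - target))

# Linear loss f(y) = c . y: df/dy = c everywhere, so the estimate is J^T c.
c = np.array([1.0, 2.0, 3.0])
g_st = straight_through_grad(theta, c)

# Exact gradient of E[f(D)] = p . c via central finite differences.
eps = 1e-6
g_exact = np.array([
    (softmax(theta + eps * e) @ c - softmax(theta - eps * e) @ c) / (2 * eps)
    for e in np.eye(len(theta))
])
```

Comparing g_st with g_exact confirms the linear case; g_st_sample changes with the drawn sample, and its average over samples need not match the exact gradient once f is nonlinear.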

How straight-through works has long been a mystery, leaving open questions such as whether we should use:

  • y_soft - y_soft.detach(),
  • (theta/tau).softmax() - (theta/tau).softmax().detach(),
  • or what?
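Whichever candidate is chosen, the term s - s.detach() is exactly zero in the forward pass, so the output is always y_hard; the choice only changes the backward pass. The NumPy sketch below (theta is made up and tau = 2.0 is an arbitrary illustrative value) computes the Jacobian each candidate routes gradients through, (diag(p) - p p^T) / tau with p = softmax(theta / tau).

```python
import numpy as np

def softmax(x):
    z = x - x.max()  # stabilize exp
    e = np.exp(z)
    return e / e.sum()

def surrogate_jacobian(theta, tau=1.0):
    # Jacobian of softmax(theta / tau) w.r.t. theta: (diag(p) - p p^T) / tau
    p = softmax(theta / tau)
    return (np.diag(p) - np.outer(p, p)) / tau

theta = np.array([0.5, -1.0, 2.0])
J_plain = surrogate_jacobian(theta)         # from y_soft - y_soft.detach()
J_tau = surrogate_jacobian(theta, tau=2.0)  # from (theta/tau).softmax() - (...).detach()
print(np.round(J_plain, 3))
print(np.round(J_tau, 3))
```

The two matrices differ, so the two candidates turn the same forward sample into different gradient estimates.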

Understand Straight-Through and Go Beyond

We reveal that Straight-Through works as a special case of the forward Euler method, a numerical method with first-order accuracy. Inspired by Heun's method, a numerical method that achieves second-order accuracy without requiring the Hessian or other second-order derivatives, we propose ReinMax, which approximates the gradient with second-order accuracy at negligible computational overhead.
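For readers unfamiliar with the two numerical methods, here is a generic illustration of first- versus second-order accuracy on the ODE y' = y (not code from the paper or the reinmax package). Forward Euler uses only the slope at the start of each step; Heun's method also evaluates the slope at the Euler-predicted end point and averages the two, still using only first derivatives.

```python
import math

def forward_euler(f, y0, t1, n):
    """Integrate y' = f(t, y) from t = 0 to t1 with n forward Euler steps (first order)."""
    h, t, y = t1 / n, 0.0, y0
    for _ in range(n):
        y = y + h * f(t, y)  # slope taken at the start of the step only
        t += h
    return y

def heun(f, y0, t1, n):
    """Same integration with Heun's method (second order, first derivatives only)."""
    h, t, y = t1 / n, 0.0, y0
    for _ in range(n):
        k1 = f(t, y)                 # slope at the start of the step
        k2 = f(t + h, y + h * k1)    # slope at the Euler-predicted end point
        y = y + 0.5 * h * (k1 + k2)  # average the two slopes
        t += h
    return y

def f(t, y):
    return y  # y' = y with y(0) = 1, so y(1) = e

for n in (10, 20, 40):
    err_euler = abs(forward_euler(f, 1.0, 1.0, n) - math.e)
    err_heun = abs(heun(f, 1.0, 1.0, n) - math.e)
    print(f"n={n:3d}  Euler error={err_euler:.2e}  Heun error={err_heun:.2e}")
```

Doubling n roughly halves the Euler error but roughly quarters the Heun error, which is what first- and second-order accuracy mean in practice.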

How to use?

The reinmax package can be installed via pip:

pip install reinmax

To replace Straight-Through Gumbel-Softmax with ReinMax:

from reinmax import reinmax
...
- y_hard = torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True)
+ y_hard, _ = reinmax(logits, tau) # note: reinmax prefers tau >= 1, while gumbel-softmax prefers tau < 1
...
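Some context for the tau comment: in PyTorch's gumbel_softmax with hard=True, the forward value is the one-hot argmax of (logits + g) / tau with Gumbel noise g, and dividing by a positive tau never changes the argmax. So the sampled categories follow softmax(logits) whatever tau is, and tau only shapes the gradient. The NumPy sketch below re-implements just that forward computation; it is not PyTorch's code, and the helper name and logits are illustrative.

```python
import numpy as np

def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)  # stabilize exp
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def hard_gumbel_softmax_forward(logits, tau, rng):
    # Hypothetical helper: forward value of Gumbel-Softmax with hard=True,
    # i.e. one_hot(argmax((logits + g) / tau)) with g ~ Gumbel(0, 1).
    g = rng.gumbel(size=logits.shape)
    index = np.argmax((logits + g) / tau, axis=-1)  # tau > 0 never changes the argmax
    return np.eye(logits.shape[-1])[index]

rng = np.random.default_rng(0)
logits = np.array([1.0, 0.0, -1.0])
samples = hard_gumbel_softmax_forward(np.tile(logits, (20000, 1)), tau=0.5, rng=rng)
freq = samples.mean(axis=0)  # empirical category frequencies, close to softmax(logits)
print(np.round(freq, 3), np.round(softmax(logits), 3))
```

With the same random stream, tau = 0.5 and tau = 5.0 give identical samples; only the soft relaxation used for the backward pass changes.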

To replace Straight-Through with ReinMax:

from reinmax import reinmax
...
- y_hard = one_hot_multinomial(logits.softmax(dim=-1))
- y_soft_tau = (logits/tau).softmax(dim=-1)
- y_hard = y_soft_tau - y_soft_tau.detach() + y_hard
+ y_hard, y_soft = reinmax(logits, tau) 
...

Examples

Polynomial Programming

Following previous studies (Tucker et al., 2017; Grathwohl et al., 2018; Pervez et al., 2020; Paulus et al., 2021), we start with a simple, classic problem: polynomial programming.

The implementation for this problem is available in the poly folder.
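This README does not spell out the exact objective used in the poly folder, so the sketch below uses an assumed one-variable toy in the same spirit: minimize E[(b - t)^2] over b ~ Bernoulli(sigmoid(theta)), with t = 0.45 as an illustrative target. It compares the exact gradient with the expected Straight-Through estimate, both in closed form.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def loss(b, t):
    # toy objective f(b) = (b - t)^2 for a binary variable b
    return (b - t) ** 2

def exact_grad(theta, t):
    # E[f(b)] = s * f(1) + (1 - s) * f(0) with s = sigmoid(theta),
    # so dE/dtheta = s * (1 - s) * (f(1) - f(0))
    s = sigmoid(theta)
    return s * (1 - s) * (loss(1.0, t) - loss(0.0, t))

def expected_st_grad(theta, t):
    # Straight-through uses f'(b) = 2 * (b - t) at the sampled b and treats
    # db/dtheta as ds/dtheta = s * (1 - s); average over b ~ Bernoulli(s)
    s = sigmoid(theta)
    mean_slope = s * 2.0 * (1.0 - t) + (1 - s) * 2.0 * (0.0 - t)
    return s * (1 - s) * mean_slope

t = 0.45  # illustrative target
for theta in (-1.386, 0.0, 1.386):  # sigmoid(theta) is about 0.2, 0.5, 0.8
    print(f"theta={theta:+.3f}  exact={exact_grad(theta, t):+.4f}  ST={expected_st_grad(theta, t):+.4f}")
```

At sigmoid(theta) = 0.5 the two agree, but near 0.2 the Straight-Through estimate even has the wrong sign: the exact gradient pushes toward b = 0 (the better outcome), while Straight-Through points the other way.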

MNIST-VAE

We also benchmarked the performance by training variational auto-encoders (VAE) with categorical latent variables on MNIST.

The implementation for MNIST-VAE is available in the mnist_vae folder.
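In a VAE with categorical latents, the ELBO combines a reconstruction term E_q[log p(x|z)], which needs a gradient estimator such as Straight-Through or ReinMax to differentiate through the sampled z, and a KL term that can be computed in closed form. The sketch below shows only that KL term, assuming a uniform prior over K categories; the prior, shapes, and helper name are assumptions and may differ from the mnist_vae setup.

```python
import numpy as np

def categorical_kl_to_uniform(logits):
    """KL(q || Uniform(K)) for q = softmax(logits) over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)           # stabilize exp
    log_q = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))  # log-softmax
    q = np.exp(log_q)
    num_categories = logits.shape[-1]
    # sum_k q_k * (log q_k - log(1/K)) = sum_k q_k log q_k + log K
    return (q * log_q).sum(axis=-1) + np.log(num_categories)

# hypothetical sizes: a batch of 2 examples, each with 3 latent variables over 10 categories
logits = np.random.default_rng(0).normal(size=(2, 3, 10))
kl = categorical_kl_to_uniform(logits).sum(axis=-1)  # total KL per example
print(kl)
```

Because this term is differentiable in the logits, only the reconstruction term depends on the choice between Straight-Through and ReinMax.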

ListOps

For unsupervised parsing on ListOps, the implementation is available in the listops folder.

Citation

Please cite the following paper if you find our work useful. Thanks!

Liyuan Liu, Chengyu Dong, Xiaodong Liu, Bin Yu, and Jianfeng Gao (2023). Bridging Discrete and Backpropagation: Straight-Through and Beyond. ArXiv, abs/2304.08612.

@inproceedings{liu2023bridging,
  title={Bridging Discrete and Backpropagation: Straight-Through and Beyond},
  author = {Liu, Liyuan and Dong, Chengyu and Liu, Xiaodong and Yu, Bin and Gao, Jianfeng},
  booktitle = {arXiv:2304.08612 [cs]},
  year={2023}
}
