Annotated MNIST: The implementation of `cross_entropy_loss` is incorrect
ayaka14732 opened this issue
Currently, `cross_entropy_loss` in the Annotated MNIST example is implemented as follows:
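```python
import jax
import jax.numpy as jnp

def cross_entropy_loss(*, logits, labels):
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    # Negative mean of the input value at each true label (no softmax applied here).
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))
```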
The correct implementation should be:
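```python
from operator import getitem

import jax
from jax import nn

@jax.jit
def cross_entropy_loss(logits, labels):
    logits = nn.log_softmax(logits)           # raw logits -> log-probabilities
    loss = jax.vmap(getitem)(logits, labels)  # log-probability of each true label
    loss = -loss.mean()
    return loss
```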
Test:
```python
import jax
from jax import nn
import jax.numpy as jnp
import jax.random as rand
import numpy as onp
from operator import getitem
import torch
import torch.nn.functional as F

jax.config.update('jax_platform_name', 'cpu')

key = rand.PRNGKey(42)
x = rand.normal(key, (16, 10))
y = rand.randint(key, (16,), 0, 10)

# Reference: PyTorch's F.cross_entropy takes unnormalized logits and integer labels.
x_ = torch.from_numpy(onp.asarray(x))
y_ = torch.from_numpy(onp.asarray(y)).long()
print(F.cross_entropy(x_, y_).numpy())  # PyTorch implementation: 2.888901

def cross_entropy_loss(*, logits, labels):
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

print(cross_entropy_loss(logits=x, labels=y))  # Current implementation: -0.0056607425

@jax.jit
def cross_entropy_loss(logits, labels):
    logits = nn.log_softmax(logits)
    loss = jax.vmap(getitem)(logits, labels)
    loss = -loss.mean()
    return loss

print(cross_entropy_loss(x, y))  # Correct implementation: 2.8889012
```
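In the proposed implementation, `jax.vmap(getitem)` maps `operator.getitem` over the batch axis of both arguments, so row `i` yields `log_probs[i][labels[i]]`. A minimal sketch (the helper names are illustrative, not part of Flax or JAX) comparing it with plain advanced indexing and `jnp.take_along_axis`, which select the same entries:

```python
from operator import getitem

import jax
import jax.numpy as jnp


def pick_with_vmap(log_probs, labels):
    # vmap maps getitem over axis 0 of both inputs: row i -> log_probs[i][labels[i]]
    return jax.vmap(getitem)(log_probs, labels)


def pick_with_indexing(log_probs, labels):
    # Advanced indexing: pair each row index with that row's label.
    return log_probs[jnp.arange(log_probs.shape[0]), labels]


def pick_with_take_along_axis(log_probs, labels):
    # take_along_axis needs an index array with the same rank as log_probs.
    return jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]


log_probs = jax.nn.log_softmax(jnp.arange(12.0).reshape(3, 4))
labels = jnp.array([0, 3, 1])
a = pick_with_vmap(log_probs, labels)
b = pick_with_indexing(log_probs, labels)
c = pick_with_take_along_axis(log_probs, labels)
print(a, b, c)
```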
Thanks for opening this inquiry!
Note that the Annotated MNIST example already returns `nn.log_softmax()` from the model:
```python
from flax import linen as nn


class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)  # the model outputs log-probabilities, not raw logits
        return x
```
Does this match your expectations?
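To make this concrete: the original loss returns the negative mean of the value at each true label, which equals cross entropy only when its input already holds log-probabilities. A minimal sketch with random inputs (the `reference_cross_entropy` helper is illustrative) comparing it against cross entropy computed directly from raw logits as `logsumexp(z) - z[label]`:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def cross_entropy_loss(*, logits, labels):
    # Original Annotated MNIST loss: -mean of the entry at each true label.
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))


def reference_cross_entropy(raw_logits, labels):
    # Cross entropy on unnormalized logits: logsumexp(z) - z[label], averaged.
    picked = jnp.take_along_axis(raw_logits, labels[:, None], axis=-1)[:, 0]
    return jnp.mean(logsumexp(raw_logits, axis=-1) - picked)


raw_logits = jax.random.normal(jax.random.PRNGKey(0), (8, 10))
labels = jnp.arange(8)

# Fed log-probabilities (what the CNN returns): this is cross entropy.
on_log_probs = cross_entropy_loss(logits=jax.nn.log_softmax(raw_logits), labels=labels)
# Fed raw logits (as in the test in the issue): a different quantity.
on_raw = cross_entropy_loss(logits=raw_logits, labels=labels)
expected = reference_cross_entropy(raw_logits, labels)
print(on_log_probs, expected, on_raw)
```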
@andsteing Actually, I spotted this issue when I was investigating the quality of cross entropy loss implementations in JAX. Since Flax is one of the most popular NN libraries for JAX, and Annotated MNIST is the default example in the Flax documentation, most users will just copy the code from this example into their projects, which will lead to wrong results.
For the implementation, I believe it is more common to normalise the logits in the loss function. This may not matter for the MNIST example, but for large models we usually adapt the model to various tasks by adding more layers on top of its output. In such cases, the model output is not supposed to be normalised. Moreover, in popular NN libraries such as PyTorch, the cross entropy function expects unnormalised logits.
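One property that makes normalising in the loss safe even for a model that already ends in `log_softmax`: `log_softmax` is idempotent, because the log-sum-exp of a vector of log-probabilities is log(1) = 0. A small check with random logits (the scale factor is an arbitrary choice):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

logits = jax.random.normal(jax.random.PRNGKey(1), (4, 10)) * 5.0
once = jax.nn.log_softmax(logits)
twice = jax.nn.log_softmax(once)

# logsumexp of log-probabilities is log(sum(probs)) = log(1) = 0, up to rounding.
row_lse = logsumexp(once, axis=-1)
print(jnp.max(jnp.abs(row_lse)), jnp.max(jnp.abs(twice - once)))
```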
Besides, the implementation has a performance issue because it converts the labels into a one-hot encoding. The correct implementation (using `jax.vmap`) should be like the one I wrote above.

I am going to write a blog article commenting on the implementations of cross entropy loss in JAX.
> Besides, the implementation has a performance issue because it converts the labels into a one-hot encoding. The correct implementation (using jax.vmap) should be like the one I wrote above.

Whether one-hot or vmap-of-index is faster depends on dimensions and platform. Typically, the one-hot solution is very XLA-friendly and easy to optimize. It also makes it easier to support soft labels.
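To illustrate the soft-label point: in the one-hot form the target is simply a probability vector per example, so any distribution can be plugged in, for example one produced by label smoothing. An index-based `vmap(getitem)` formulation can only select one class per example. A minimal sketch (the helper names and the `smoothing=0.1` value are illustrative assumptions):

```python
import jax
import jax.numpy as jnp


def cross_entropy_from_targets(logits, targets):
    # targets: (batch, num_classes) probability vectors, hard one-hot or soft.
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.sum(targets * log_probs, axis=-1))


def smooth_labels(labels, num_classes, smoothing):
    # Label smoothing: move `smoothing` mass from the true class to a uniform spread.
    one_hot = jax.nn.one_hot(labels, num_classes)
    return one_hot * (1.0 - smoothing) + smoothing / num_classes


logits = jax.random.normal(jax.random.PRNGKey(2), (8, 10))
labels = jnp.array([3, 1, 4, 1, 5, 9, 2, 6])
soft_targets = smooth_labels(labels, 10, smoothing=0.1)
hard = cross_entropy_from_targets(logits, jax.nn.one_hot(labels, 10))
soft = cross_entropy_from_targets(logits, soft_targets)
print(hard, soft)
```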
Thanks for the additional details, @ayaka14732!
I didn't write the original example, so I'm not sure what the motivation was at that point. Before changing anything, though, I'd like to discuss these topics in more detail. Maybe @levskaya @jheek @mtthss want to chime in.
The open questions are:
1. Should we rewrite Annotated MNIST to return non-normalized logits?
2. Should the loss implementation use `vmap` instead of converting to one-hot encoding?

(Feel free to add more points with additional numbers.)
As for 1, I'm not sure that adding the normalization step introduces any restrictions with respect to what models can be built on top of the output of such a model. But I see the point about people copying the incorrect loss by accident. We could rename the loss to `cross_entropy_loss_from_normalized_logits()`, or remove the `nn.log_softmax()` from the model output and use a standard cross entropy loss (probably best to use `optax.softmax_cross_entropy` in that case).
> > Besides, the implementation has a performance issue because it converts the labels into a one-hot encoding. The correct implementation (using jax.vmap) should be like the one I wrote above.
>
> Whether one-hot or vmap-of-index is faster depends on dimensions and platform. Typically, the one-hot solution is very XLA-friendly and easy to optimize. It also makes it easier to support soft labels.

@jheek Wow! I hadn't noticed this.
I can confirm that the performance is almost identical on CPU, GPU and TPU.
Test script:
```python
import jax
from jax import nn
import jax.numpy as jnp
from operator import getitem

# jax.config.update('jax_platform_name', 'cpu')

x = jnp.array([[-1.18020773e-02, 2.12021172e-04, 1.54857263e-02, -2.69062817e-02, 2.48291939e-02, -1.57771826e-01, 1.92937106e-02, 8.00435692e-02, -3.22007202e-02, -9.29394960e-02],
[-7.89603218e-03, 1.77750289e-02, -4.46040742e-02, -6.05713278e-02, -2.72103678e-02, -8.59001800e-02, 9.83414203e-02, 7.55275413e-02, 2.83549689e-02, -1.53297465e-02],
[-1.39507279e-02, -2.87675336e-02, -6.21193871e-02, -1.39774792e-02, 5.88746369e-02, -1.25838742e-01, 5.08899838e-02, 3.97215132e-03, 5.62930740e-02, 1.32357161e-02],
[-2.65202336e-02, 4.41155769e-02, -3.11193094e-02, -2.00038608e-02, 2.85324380e-02, -6.48770258e-02, 8.85064006e-02, 1.30649462e-01, 7.65995309e-02, 1.04247741e-02],
[-4.47686166e-02, 1.18546493e-01, 2.52062641e-03, 5.05243726e-02, 3.30978036e-02, -1.43625394e-01, 6.45868406e-02, 2.03764439e-01, 3.71213034e-02, -1.07339084e-01],
[9.10869241e-03, 4.17448469e-02, -7.88132250e-02, -8.66309926e-03, 7.24829659e-02, -2.05243602e-01, 8.08372796e-02, 1.47233456e-01, 2.36027334e-02, -5.98792732e-03],
[-1.08229872e-02, -1.97286233e-02, -5.83424866e-02, -1.43472329e-02, 6.09962158e-02, -1.09460413e-01, 8.39179903e-02, -1.33977272e-02, 5.15223742e-02, -2.19145101e-02],
[-4.08720411e-02, 4.43317667e-02, 5.42337447e-03, -4.50615957e-03, -1.77606083e-02, -1.12251990e-01, 5.72303906e-02, 4.27244902e-02, 1.29846334e-02, -6.36750609e-02],
[-2.33511683e-02, -4.19441313e-02, -2.53099240e-02, -3.40913311e-02, 2.91062817e-02, -1.17673308e-01, 4.52934727e-02, 6.69038519e-02, 1.93193313e-02, 9.32972878e-04],
[-8.95771384e-03, -1.89850703e-02, -4.84827980e-02, 6.48929551e-02, 3.58220935e-02, -2.13164091e-01, 1.48798272e-01, 2.58073777e-01, 1.91853680e-02, -3.26515697e-02],
[4.63725962e-02, 2.34173127e-02, -1.20111704e-01, 3.28311101e-02, 3.72643247e-02, -3.24501187e-01, 1.72371522e-01, 1.61464572e-01, 3.20537128e-02, -6.84871152e-02],
[-5.78307211e-02, 9.58937407e-02, -4.97079864e-02, 2.08359621e-02, -1.53612345e-02, -1.99896753e-01, 1.58067212e-01, 1.99744061e-01, 2.05481090e-02, -8.61398652e-02],
[-3.60106044e-02, -2.96902433e-02, -1.36097893e-01, -8.01160187e-03, -3.87010686e-02, -1.14177659e-01, 1.08693138e-01, 7.54408836e-02, 1.20038494e-01, -4.08867598e-02],
[-4.42547724e-02, 2.65968330e-02, -1.16404049e-01, 3.99116948e-02, 3.70342433e-02, -2.52804637e-01, 1.20865166e-01, 1.91994041e-01, 1.31512567e-01, -5.34878895e-02],
[-1.04445070e-01, -1.25995502e-02, -3.07911113e-02, -2.90571153e-03, -6.47646636e-02, -7.31705278e-02, 5.22068962e-02, 8.08887705e-02, -5.31015918e-03, -1.43367410e-01],
[5.92986569e-02, 1.65142119e-05, -1.42436139e-02, -3.14359590e-02, 8.34358111e-02, -1.19490743e-01, 7.41921589e-02, 9.27548930e-02, -2.43497770e-02, -1.97705049e-02],
[5.49561530e-03, -1.56130977e-02, -4.27389368e-02, 4.68986072e-02, 6.23839833e-02, -2.25198716e-01, 9.64858308e-02, 1.23882063e-01, -7.69507885e-03, 5.38528115e-02],
[-2.27768160e-02, -5.26964292e-03, -2.93163955e-03, 6.27545714e-02, -1.27458237e-02, -2.02232912e-01, 1.19752392e-01, 1.84115887e-01, 4.13061678e-02, -1.24206990e-01],
[-1.88333243e-02, 6.86625615e-02, -2.11398304e-03, 8.18483531e-03, 8.05698037e-02, -1.55029133e-01, 8.26403499e-02, 2.03767478e-01, -2.83125415e-03, -3.98058593e-02],
[-1.26937091e-01, 7.40949661e-02, 4.41245735e-03, 1.11442685e-01, 1.80361643e-02, -2.52694398e-01, 1.81960687e-01, 2.68796027e-01, 2.17940658e-02, -1.80268660e-01],
[-1.80708244e-04, -2.12510191e-02, -7.65096098e-02, 7.51659274e-03, -9.01457481e-03, -1.21640250e-01, 9.88654792e-02, 8.79775360e-02, 7.35239983e-02, -7.77667202e-03],
[-4.61105630e-02, -4.42445502e-02, -6.50017634e-02, -8.33700225e-03, -3.89694385e-02, -9.54467058e-02, 7.01751038e-02, 9.09299105e-02, 4.63261157e-02, -7.00642243e-02],
[-9.73308459e-02, 8.19174573e-03, -7.40732402e-02, 2.13738866e-02, 8.34189821e-03, -1.76222131e-01, 1.13145962e-01, 1.78131312e-01, 6.00471720e-02, -1.10414103e-01],
[4.41795588e-02, 3.11944820e-02, -4.57068756e-02, 9.16974433e-03, 7.96715915e-02, -1.37639463e-01, 1.50509447e-01, 1.48718774e-01, 5.20974174e-02, 3.41124311e-02],
[3.03848479e-02, -2.04634238e-02, -5.18144295e-03, -7.55520612e-02, 9.20162909e-03, -1.17662966e-01, 7.82268792e-02, 1.07894555e-01, 2.12373957e-03, 1.11250095e-02],
[-2.81126983e-02, -8.13692510e-02, -1.24993384e-01, 1.56849958e-02, 8.87469202e-02, -1.63082436e-01, 1.06209114e-01, 1.34258643e-01, 6.51819557e-02, 5.11457995e-02],
[-1.58369094e-02, 2.42988374e-02, -8.84073824e-02, -5.85790724e-03, 9.41534638e-02, -2.78038144e-01, 9.62805152e-02, 1.56973749e-01, 3.26635092e-02, 1.13205947e-02],
[9.15303826e-03, 1.70634557e-02, -5.74565530e-02, -2.42397413e-02, 6.22252002e-02, -2.31550604e-01, 7.48545304e-02, 2.16238573e-01, -5.76255880e-02, -5.96193373e-02],
[1.40953138e-02, 2.27493793e-03, -3.40058431e-02, -2.38920618e-02, 6.47072494e-02, -1.27346247e-01, 1.30044132e-01, 4.72338349e-02, 4.18094844e-02, 1.18074603e-02],
[-8.56964849e-03, -2.67159306e-02, -5.96427359e-02, -2.79492103e-02, 4.75161448e-02, -1.16145663e-01, 9.01559144e-02, -2.28071418e-02, 5.35263345e-02, -1.84544232e-02],
[2.33726948e-03, 1.16003156e-01, -4.96490225e-02, 3.59052308e-02, 2.15967186e-02, -2.02133447e-01, 9.41444337e-02, 1.96441934e-01, 1.34492517e-01, -1.13289990e-01],
[-5.45443073e-02, -2.34426185e-03, -1.12764567e-01, -1.28814131e-02, -3.77303138e-02, -9.22404602e-02, 1.06957100e-01, 1.24134108e-01, 9.78331864e-02, -1.30443811e-01]])
y = jnp.array([9, 5, 1, 5, 0, 6, 1, 4, 8, 6, 2, 0, 7, 8, 7, 0, 2, 4, 0, 4, 7, 7, 4, 3, 6, 1, 6, 6, 8, 1, 0, 7])

@jax.jit
def f1(logits, labels):
    logits = nn.log_softmax(logits)
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

@jax.jit
def f2(logits, labels):
    logits = nn.log_softmax(logits)
    loss = jax.vmap(getitem)(logits, labels)
    loss = -loss.mean()
    return loss

a = f1(x, y)
b = f2(x, y)
assert jnp.allclose(a, b)

import timeit
print(timeit.timeit('f1(x, y).block_until_ready()', globals=globals(), number=100000))
print(timeit.timeit('f2(x, y).block_until_ready()', globals=globals(), number=100000))
```
Since @jheek pointed out that the one-hot solution does not affect performance, I now think that `optax.softmax_cross_entropy` is the best approach.
The `jax.vmap(getitem)` is a nice trick! However, IMO something like:
```python
n_classes = logits.shape[-1]
loss = optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, n_classes)).mean()
```
is easier to understand for beginners and experts alike.
I am sorry, my previous benchmark was wrong. The arrays I used were not large enough, so the measured time was dominated by the communication between the host and the device rather than by the actual computation.
Corrected test script:
```python
import jax
from jax import nn
import jax.numpy as jnp
from operator import getitem

# jax.config.update('jax_platform_name', 'cpu')

x = [[-1.18020773e-02, 2.12021172e-04, 1.54857263e-02, -2.69062817e-02, 2.48291939e-02, -1.57771826e-01, 1.92937106e-02, 8.00435692e-02, -3.22007202e-02, -9.29394960e-02],
[-7.89603218e-03, 1.77750289e-02, -4.46040742e-02, -6.05713278e-02, -2.72103678e-02, -8.59001800e-02, 9.83414203e-02, 7.55275413e-02, 2.83549689e-02, -1.53297465e-02],
[-1.39507279e-02, -2.87675336e-02, -6.21193871e-02, -1.39774792e-02, 5.88746369e-02, -1.25838742e-01, 5.08899838e-02, 3.97215132e-03, 5.62930740e-02, 1.32357161e-02],
[-2.65202336e-02, 4.41155769e-02, -3.11193094e-02, -2.00038608e-02, 2.85324380e-02, -6.48770258e-02, 8.85064006e-02, 1.30649462e-01, 7.65995309e-02, 1.04247741e-02],
[-4.47686166e-02, 1.18546493e-01, 2.52062641e-03, 5.05243726e-02, 3.30978036e-02, -1.43625394e-01, 6.45868406e-02, 2.03764439e-01, 3.71213034e-02, -1.07339084e-01],
[9.10869241e-03, 4.17448469e-02, -7.88132250e-02, -8.66309926e-03, 7.24829659e-02, -2.05243602e-01, 8.08372796e-02, 1.47233456e-01, 2.36027334e-02, -5.98792732e-03],
[-1.08229872e-02, -1.97286233e-02, -5.83424866e-02, -1.43472329e-02, 6.09962158e-02, -1.09460413e-01, 8.39179903e-02, -1.33977272e-02, 5.15223742e-02, -2.19145101e-02],
[-4.08720411e-02, 4.43317667e-02, 5.42337447e-03, -4.50615957e-03, -1.77606083e-02, -1.12251990e-01, 5.72303906e-02, 4.27244902e-02, 1.29846334e-02, -6.36750609e-02],
[-2.33511683e-02, -4.19441313e-02, -2.53099240e-02, -3.40913311e-02, 2.91062817e-02, -1.17673308e-01, 4.52934727e-02, 6.69038519e-02, 1.93193313e-02, 9.32972878e-04],
[-8.95771384e-03, -1.89850703e-02, -4.84827980e-02, 6.48929551e-02, 3.58220935e-02, -2.13164091e-01, 1.48798272e-01, 2.58073777e-01, 1.91853680e-02, -3.26515697e-02],
[4.63725962e-02, 2.34173127e-02, -1.20111704e-01, 3.28311101e-02, 3.72643247e-02, -3.24501187e-01, 1.72371522e-01, 1.61464572e-01, 3.20537128e-02, -6.84871152e-02],
[-5.78307211e-02, 9.58937407e-02, -4.97079864e-02, 2.08359621e-02, -1.53612345e-02, -1.99896753e-01, 1.58067212e-01, 1.99744061e-01, 2.05481090e-02, -8.61398652e-02],
[-3.60106044e-02, -2.96902433e-02, -1.36097893e-01, -8.01160187e-03, -3.87010686e-02, -1.14177659e-01, 1.08693138e-01, 7.54408836e-02, 1.20038494e-01, -4.08867598e-02],
[-4.42547724e-02, 2.65968330e-02, -1.16404049e-01, 3.99116948e-02, 3.70342433e-02, -2.52804637e-01, 1.20865166e-01, 1.91994041e-01, 1.31512567e-01, -5.34878895e-02],
[-1.04445070e-01, -1.25995502e-02, -3.07911113e-02, -2.90571153e-03, -6.47646636e-02, -7.31705278e-02, 5.22068962e-02, 8.08887705e-02, -5.31015918e-03, -1.43367410e-01],
[5.92986569e-02, 1.65142119e-05, -1.42436139e-02, -3.14359590e-02, 8.34358111e-02, -1.19490743e-01, 7.41921589e-02, 9.27548930e-02, -2.43497770e-02, -1.97705049e-02],
[5.49561530e-03, -1.56130977e-02, -4.27389368e-02, 4.68986072e-02, 6.23839833e-02, -2.25198716e-01, 9.64858308e-02, 1.23882063e-01, -7.69507885e-03, 5.38528115e-02],
[-2.27768160e-02, -5.26964292e-03, -2.93163955e-03, 6.27545714e-02, -1.27458237e-02, -2.02232912e-01, 1.19752392e-01, 1.84115887e-01, 4.13061678e-02, -1.24206990e-01],
[-1.88333243e-02, 6.86625615e-02, -2.11398304e-03, 8.18483531e-03, 8.05698037e-02, -1.55029133e-01, 8.26403499e-02, 2.03767478e-01, -2.83125415e-03, -3.98058593e-02],
[-1.26937091e-01, 7.40949661e-02, 4.41245735e-03, 1.11442685e-01, 1.80361643e-02, -2.52694398e-01, 1.81960687e-01, 2.68796027e-01, 2.17940658e-02, -1.80268660e-01],
[-1.80708244e-04, -2.12510191e-02, -7.65096098e-02, 7.51659274e-03, -9.01457481e-03, -1.21640250e-01, 9.88654792e-02, 8.79775360e-02, 7.35239983e-02, -7.77667202e-03],
[-4.61105630e-02, -4.42445502e-02, -6.50017634e-02, -8.33700225e-03, -3.89694385e-02, -9.54467058e-02, 7.01751038e-02, 9.09299105e-02, 4.63261157e-02, -7.00642243e-02],
[-9.73308459e-02, 8.19174573e-03, -7.40732402e-02, 2.13738866e-02, 8.34189821e-03, -1.76222131e-01, 1.13145962e-01, 1.78131312e-01, 6.00471720e-02, -1.10414103e-01],
[4.41795588e-02, 3.11944820e-02, -4.57068756e-02, 9.16974433e-03, 7.96715915e-02, -1.37639463e-01, 1.50509447e-01, 1.48718774e-01, 5.20974174e-02, 3.41124311e-02],
[3.03848479e-02, -2.04634238e-02, -5.18144295e-03, -7.55520612e-02, 9.20162909e-03, -1.17662966e-01, 7.82268792e-02, 1.07894555e-01, 2.12373957e-03, 1.11250095e-02],
[-2.81126983e-02, -8.13692510e-02, -1.24993384e-01, 1.56849958e-02, 8.87469202e-02, -1.63082436e-01, 1.06209114e-01, 1.34258643e-01, 6.51819557e-02, 5.11457995e-02],
[-1.58369094e-02, 2.42988374e-02, -8.84073824e-02, -5.85790724e-03, 9.41534638e-02, -2.78038144e-01, 9.62805152e-02, 1.56973749e-01, 3.26635092e-02, 1.13205947e-02],
[9.15303826e-03, 1.70634557e-02, -5.74565530e-02, -2.42397413e-02, 6.22252002e-02, -2.31550604e-01, 7.48545304e-02, 2.16238573e-01, -5.76255880e-02, -5.96193373e-02],
[1.40953138e-02, 2.27493793e-03, -3.40058431e-02, -2.38920618e-02, 6.47072494e-02, -1.27346247e-01, 1.30044132e-01, 4.72338349e-02, 4.18094844e-02, 1.18074603e-02],
[-8.56964849e-03, -2.67159306e-02, -5.96427359e-02, -2.79492103e-02, 4.75161448e-02, -1.16145663e-01, 9.01559144e-02, -2.28071418e-02, 5.35263345e-02, -1.84544232e-02],
[2.33726948e-03, 1.16003156e-01, -4.96490225e-02, 3.59052308e-02, 2.15967186e-02, -2.02133447e-01, 9.41444337e-02, 1.96441934e-01, 1.34492517e-01, -1.13289990e-01],
[-5.45443073e-02, -2.34426185e-03, -1.12764567e-01, -1.28814131e-02, -3.77303138e-02, -9.22404602e-02, 1.06957100e-01, 1.24134108e-01, 9.78331864e-02, -1.30443811e-01]]
y = [9, 5, 1, 5, 0, 6, 1, 4, 8, 6, 2, 0, 7, 8, 7, 0, 2, 4, 0, 4, 7, 7, 4, 3, 6, 1, 6, 6, 8, 1, 0, 7]

# List repetition (not scaling): repeat the 32-example batch 10,000 times,
# giving 320,000 rows of logits and 320,000 labels.
x *= 10000
y *= 10000
x = jnp.array(x)
y = jnp.array(y)

@jax.jit
def f1(logits, labels):
    logits = nn.log_softmax(logits)
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

@jax.jit
def f2(logits, labels):
    logits = nn.log_softmax(logits)
    loss = jax.vmap(getitem)(logits, labels)
    loss = -loss.mean()
    return loss

a = f1(x, y)
b = f2(x, y)
assert jnp.allclose(a, b)

import timeit
# Total seconds for 10,000 calls each; block_until_ready() waits for the device.
print(timeit.timeit('f1(x, y).block_until_ready()', globals=globals(), number=10000))
print(timeit.timeit('f2(x, y).block_until_ready()', globals=globals(), number=10000))
```
Results (total seconds for 10,000 calls, as returned by `timeit`):

| Platform | one-hot | vmap |
|---|---|---|
| CPU | 78.54 | 46.45 |
| GPU | 7.02 | 6.29 |
| TPU | 3.96 | 124.31 |
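The corrected benchmark addresses the problem that tiny arrays mostly measure host/device overhead. A minimal sketch of a timing harness along the same lines, assuming the goal is to compare device compute: it triggers compilation with a warm-up call before timing, calls `block_until_ready()` so JAX's asynchronous dispatch does not hide the work, and reports time per call. The function names, batch size and repeat count are illustrative choices, and timings will differ by platform:

```python
import timeit
from operator import getitem

import jax
import jax.numpy as jnp


@jax.jit
def loss_one_hot(logits, labels):
    log_probs = jax.nn.log_softmax(logits)
    one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=-1))


@jax.jit
def loss_vmap(logits, labels):
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jax.vmap(getitem)(log_probs, labels))


def time_per_call(fn, *args, number=20):
    # Warm-up call compiles the function so compilation is excluded from timing.
    fn(*args).block_until_ready()
    total = timeit.timeit(lambda: fn(*args).block_until_ready(), number=number)
    return total / number  # seconds per call


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
batch = 100_000  # illustrative; large enough that compute outweighs dispatch
logits = jax.random.normal(key_x, (batch, 10))
labels = jax.random.randint(key_y, (batch,), 0, 10)

assert jnp.allclose(loss_one_hot(logits, labels), loss_vmap(logits, labels), atol=1e-5)
for name, fn in [("one-hot", loss_one_hot), ("vmap", loss_vmap)]:
    print(name, f"{time_per_call(fn, logits, labels) * 1e3:.3f} ms/call")
```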
@ayaka14732 thanks for the update. So it seems that on GPU/TPU the one-hot solution is preferred, and we can keep using the current implementation?
Is there anything actionable left in this issue, or do you think we can close it? Sorry if I missed anything!
> It seems that on GPU/TPU the one-hot solution is preferred

On GPU, the one-hot solution is slower than vmap. However, on TPU, vmap is much slower than one-hot. So we will keep the one-hot solution.
> Is there anything actionable left in this issue

There are still two things unresolved:
- Do not normalize logits inside the model. Instead, normalize them when calculating the loss.
- Remove the custom cross entropy loss function and use the one in optax directly.
Resolved in #2071.
@ayaka14732 thanks for the discussion, please feel free to continue the discussion on the PR...