google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Annotated MNIST: The implementation of `cross_entropy_loss` is incorrect

ayaka14732 opened this issue · comments

commented

Currently the cross_entropy_loss in the Annotated MNIST example is implemented as follows:

def cross_entropy_loss(*, logits, labels):
  one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
  return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

The correct implementation should be:

from operator import getitem

@jax.jit
def cross_entropy_loss(logits, labels):
    logits = nn.log_softmax(logits)
    loss = jax.vmap(getitem)(logits, labels)
    loss = -loss.mean()
    return loss
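
For context, jax.vmap(getitem)(logits, labels) simply picks out logits[i, labels[i]] for every row i. A minimal equivalent sketch without vmap, using jnp.take_along_axis (the helper name is only for illustration):

import jax.numpy as jnp

def nll_from_log_probs(log_probs, labels):
    # log_probs: (batch, num_classes) log-probabilities, labels: (batch,) integer class ids
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]  # log_probs[i, labels[i]]
    return -picked.mean()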

Test:

import jax
from jax import nn
import jax.numpy as jnp
import jax.random as rand
import numpy as onp
from operator import getitem
import torch
import torch.nn.functional as F

jax.config.update('jax_platform_name', 'cpu')

key = rand.PRNGKey(42)

x = rand.normal(key, (16, 10))
y = rand.randint(key, (16,), 0, 10)

x_ = torch.from_numpy(onp.asarray(x))
y_ = torch.from_numpy(onp.asarray(y)).long()

print(F.cross_entropy(x_, y_).numpy())  # PyTorch implementation: 2.888901

def cross_entropy_loss(*, logits, labels):
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

print(cross_entropy_loss(logits=x, labels=y))  # Current implementation: -0.0056607425

@jax.jit
def cross_entropy_loss(logits, labels):
    logits = nn.log_softmax(logits)
    loss = jax.vmap(getitem)(logits, labels)
    loss = -loss.mean()
    return loss

print(cross_entropy_loss(x, y))  # Correct implementation: 2.8889012

commented

Thanks for opening this inquiry!

Note that in the Annotated MNIST example, the model already applies nn.log_softmax() to its output:

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x
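
In other words, since the model output is already log-probabilities, the existing loss computes exactly the negative log-likelihood. A quick sketch to check this against the PyTorch reference value from the test above (reusing the same PRNG key, so the numbers should line up):

import jax
from jax import nn
import jax.numpy as jnp
import jax.random as rand

key = rand.PRNGKey(42)
x = rand.normal(key, (16, 10))       # raw logits, as in the test above
y = rand.randint(key, (16,), 0, 10)

def cross_entropy_loss(*, logits, labels):
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

# Feeding log-probabilities (what the CNN actually outputs) instead of raw logits
# should reproduce the PyTorch reference of ~2.8889.
print(cross_entropy_loss(logits=nn.log_softmax(x), labels=y))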

Does this match your expectations?

commented

@andsteing Actually I spotted this issue while investigating the quality of cross entropy loss implementations in JAX. Since Flax is one of the most popular NN libraries for JAX, and the Annotated MNIST is the default example in the Flax documentation, most users will simply copy the code from this example into their projects, which will lead to wrong results.

For the implementation, I believe it is more common to normalise the logits in the loss function. This may not make things better for the MNIST example, but for large models, we usually adapt the model to various tasks by adding more layers on top of the model output. In such cases, the model output is not supposed to be normalised. Moreover, in popular NN libraries such as PyTorch, the cross entropy function expects the logits to be unnormalised.

Besides, the implementation has a performance issue because it converts the labels into one hot encoding. The correct implementation (using jax.vmap) should be like the one I wrote above.

I am going to write a blog article to comment on the implementations of cross entropy loss in JAX.

commented

Besides, the implementation has a performance issue because it converts the labels into one hot encoding. The correct implementation (using jax.vmap) should be like the one I wrote above.

Whether a one_hot or vmap-of-index is faster depends on dimensions and platform. Typically the one-hot solution is very XLA-friendly and easy to optimize; it also makes it easier to support soft labels.
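
To illustrate the soft-label point: with the one-hot formulation, the hard targets can be replaced by any probability distribution (e.g. label smoothing) without changing the rest of the loss. A minimal sketch, with an illustrative function name and smoothing factor:

import jax
import jax.numpy as jnp
from jax import nn

def smoothed_cross_entropy(logits, labels, num_classes=10, smoothing=0.1):
    log_probs = nn.log_softmax(logits)
    targets = jax.nn.one_hot(labels, num_classes)
    # Mix the one-hot targets with a uniform distribution (label smoothing).
    targets = targets * (1.0 - smoothing) + smoothing / num_classes
    return -jnp.mean(jnp.sum(targets * log_probs, axis=-1))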

Thanks for the additional details, @ayaka14732 !

I didn't write the original example, so I'm not sure what the motivation was at that point. Before changing anything, though, I'd like to discuss these topics in more detail. Maybe @levskaya @jheek @mtthss want to chime in.

The open questions are:

  1. Should we rewrite Annotated MNIST to return non-normalized logits?
  2. Should the loss implementation use vmap instead of converting to one hot encoding?

(feel free to add more points with additional numbers)

As for 1, I'm not sure that adding the normalization step restricts what can be built on top of the model's output. But I see the point about people copying the incorrect loss by accident. We could rename the loss to cross_entropy_loss_from_normalized_logits(), or remove the nn.log_softmax() from the model output and use a standard cross entropy loss (probably best to use optax.softmax_cross_entropy in that case).

commented

Besides, the implementation has a performance issue because it converts the labels into one hot encoding. The correct implementation (using jax.vmap) should be like the one I wrote above.

Whether a one_hot or vmap-of-index is faster depends on dimensions and platform. Typically the one-hot solution is very XLA-friendly and easy to optimize; it also makes it easier to support soft labels.

@jheek Wow! I hadn't noticed this.

I can confirm that the performance is almost identical on CPU, GPU and TPU.

Test script:
import jax
from jax import nn
import jax.numpy as jnp
from operator import getitem

# jax.config.update('jax_platform_name', 'cpu')

x = jnp.array([[-1.18020773e-02, 2.12021172e-04, 1.54857263e-02, -2.69062817e-02, 2.48291939e-02, -1.57771826e-01, 1.92937106e-02, 8.00435692e-02, -3.22007202e-02, -9.29394960e-02],
              [-7.89603218e-03, 1.77750289e-02, -4.46040742e-02, -6.05713278e-02, -2.72103678e-02, -8.59001800e-02, 9.83414203e-02, 7.55275413e-02, 2.83549689e-02, -1.53297465e-02],
              [-1.39507279e-02, -2.87675336e-02, -6.21193871e-02, -1.39774792e-02, 5.88746369e-02, -1.25838742e-01, 5.08899838e-02, 3.97215132e-03, 5.62930740e-02, 1.32357161e-02],
              [-2.65202336e-02, 4.41155769e-02, -3.11193094e-02, -2.00038608e-02, 2.85324380e-02, -6.48770258e-02, 8.85064006e-02, 1.30649462e-01, 7.65995309e-02, 1.04247741e-02],
              [-4.47686166e-02, 1.18546493e-01, 2.52062641e-03, 5.05243726e-02, 3.30978036e-02, -1.43625394e-01, 6.45868406e-02, 2.03764439e-01, 3.71213034e-02, -1.07339084e-01],
              [9.10869241e-03, 4.17448469e-02, -7.88132250e-02, -8.66309926e-03, 7.24829659e-02, -2.05243602e-01, 8.08372796e-02, 1.47233456e-01, 2.36027334e-02, -5.98792732e-03],
              [-1.08229872e-02, -1.97286233e-02, -5.83424866e-02, -1.43472329e-02, 6.09962158e-02, -1.09460413e-01, 8.39179903e-02, -1.33977272e-02, 5.15223742e-02, -2.19145101e-02],
              [-4.08720411e-02, 4.43317667e-02, 5.42337447e-03, -4.50615957e-03, -1.77606083e-02, -1.12251990e-01, 5.72303906e-02, 4.27244902e-02, 1.29846334e-02, -6.36750609e-02],
              [-2.33511683e-02, -4.19441313e-02, -2.53099240e-02, -3.40913311e-02, 2.91062817e-02, -1.17673308e-01, 4.52934727e-02, 6.69038519e-02, 1.93193313e-02, 9.32972878e-04],
              [-8.95771384e-03, -1.89850703e-02, -4.84827980e-02, 6.48929551e-02, 3.58220935e-02, -2.13164091e-01, 1.48798272e-01, 2.58073777e-01, 1.91853680e-02, -3.26515697e-02],
              [4.63725962e-02, 2.34173127e-02, -1.20111704e-01, 3.28311101e-02, 3.72643247e-02, -3.24501187e-01, 1.72371522e-01, 1.61464572e-01, 3.20537128e-02, -6.84871152e-02],
              [-5.78307211e-02, 9.58937407e-02, -4.97079864e-02, 2.08359621e-02, -1.53612345e-02, -1.99896753e-01, 1.58067212e-01, 1.99744061e-01, 2.05481090e-02, -8.61398652e-02],
              [-3.60106044e-02, -2.96902433e-02, -1.36097893e-01, -8.01160187e-03, -3.87010686e-02, -1.14177659e-01, 1.08693138e-01, 7.54408836e-02, 1.20038494e-01, -4.08867598e-02],
              [-4.42547724e-02, 2.65968330e-02, -1.16404049e-01, 3.99116948e-02, 3.70342433e-02, -2.52804637e-01, 1.20865166e-01, 1.91994041e-01, 1.31512567e-01, -5.34878895e-02],
              [-1.04445070e-01, -1.25995502e-02, -3.07911113e-02, -2.90571153e-03, -6.47646636e-02, -7.31705278e-02, 5.22068962e-02, 8.08887705e-02, -5.31015918e-03, -1.43367410e-01],
              [5.92986569e-02, 1.65142119e-05, -1.42436139e-02, -3.14359590e-02, 8.34358111e-02, -1.19490743e-01, 7.41921589e-02, 9.27548930e-02, -2.43497770e-02, -1.97705049e-02],
              [5.49561530e-03, -1.56130977e-02, -4.27389368e-02, 4.68986072e-02, 6.23839833e-02, -2.25198716e-01, 9.64858308e-02, 1.23882063e-01, -7.69507885e-03, 5.38528115e-02],
              [-2.27768160e-02, -5.26964292e-03, -2.93163955e-03, 6.27545714e-02, -1.27458237e-02, -2.02232912e-01, 1.19752392e-01, 1.84115887e-01, 4.13061678e-02, -1.24206990e-01],
              [-1.88333243e-02, 6.86625615e-02, -2.11398304e-03, 8.18483531e-03, 8.05698037e-02, -1.55029133e-01, 8.26403499e-02, 2.03767478e-01, -2.83125415e-03, -3.98058593e-02],
              [-1.26937091e-01, 7.40949661e-02, 4.41245735e-03, 1.11442685e-01, 1.80361643e-02, -2.52694398e-01, 1.81960687e-01, 2.68796027e-01, 2.17940658e-02, -1.80268660e-01],
              [-1.80708244e-04, -2.12510191e-02, -7.65096098e-02, 7.51659274e-03, -9.01457481e-03, -1.21640250e-01, 9.88654792e-02, 8.79775360e-02, 7.35239983e-02, -7.77667202e-03],
              [-4.61105630e-02, -4.42445502e-02, -6.50017634e-02, -8.33700225e-03, -3.89694385e-02, -9.54467058e-02, 7.01751038e-02, 9.09299105e-02, 4.63261157e-02, -7.00642243e-02],
              [-9.73308459e-02, 8.19174573e-03, -7.40732402e-02, 2.13738866e-02, 8.34189821e-03, -1.76222131e-01, 1.13145962e-01, 1.78131312e-01, 6.00471720e-02, -1.10414103e-01],
              [4.41795588e-02, 3.11944820e-02, -4.57068756e-02, 9.16974433e-03, 7.96715915e-02, -1.37639463e-01, 1.50509447e-01, 1.48718774e-01, 5.20974174e-02, 3.41124311e-02],
              [3.03848479e-02, -2.04634238e-02, -5.18144295e-03, -7.55520612e-02, 9.20162909e-03, -1.17662966e-01, 7.82268792e-02, 1.07894555e-01, 2.12373957e-03, 1.11250095e-02],
              [-2.81126983e-02, -8.13692510e-02, -1.24993384e-01, 1.56849958e-02, 8.87469202e-02, -1.63082436e-01, 1.06209114e-01, 1.34258643e-01, 6.51819557e-02, 5.11457995e-02],
              [-1.58369094e-02, 2.42988374e-02, -8.84073824e-02, -5.85790724e-03, 9.41534638e-02, -2.78038144e-01, 9.62805152e-02, 1.56973749e-01, 3.26635092e-02, 1.13205947e-02],
              [9.15303826e-03, 1.70634557e-02, -5.74565530e-02, -2.42397413e-02, 6.22252002e-02, -2.31550604e-01, 7.48545304e-02, 2.16238573e-01, -5.76255880e-02, -5.96193373e-02],
              [1.40953138e-02, 2.27493793e-03, -3.40058431e-02, -2.38920618e-02, 6.47072494e-02, -1.27346247e-01, 1.30044132e-01, 4.72338349e-02, 4.18094844e-02, 1.18074603e-02],
              [-8.56964849e-03, -2.67159306e-02, -5.96427359e-02, -2.79492103e-02, 4.75161448e-02, -1.16145663e-01, 9.01559144e-02, -2.28071418e-02, 5.35263345e-02, -1.84544232e-02],
              [2.33726948e-03, 1.16003156e-01, -4.96490225e-02, 3.59052308e-02, 2.15967186e-02, -2.02133447e-01, 9.41444337e-02, 1.96441934e-01, 1.34492517e-01, -1.13289990e-01],
              [-5.45443073e-02, -2.34426185e-03, -1.12764567e-01, -1.28814131e-02, -3.77303138e-02, -9.22404602e-02, 1.06957100e-01, 1.24134108e-01, 9.78331864e-02, -1.30443811e-01]])
y = jnp.array([9, 5, 1, 5, 0, 6, 1, 4, 8, 6, 2, 0, 7, 8, 7, 0, 2, 4, 0, 4, 7, 7, 4, 3, 6, 1, 6, 6, 8, 1, 0, 7])

@jax.jit
def f1(logits, labels):
    logits = nn.log_softmax(logits)
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

@jax.jit
def f2(logits, labels):
    logits = nn.log_softmax(logits)
    loss = jax.vmap(getitem)(logits, labels)
    loss = -loss.mean()
    return loss

a = f1(x, y)
b = f2(x, y)
assert jnp.allclose(a, b)

import timeit
print(timeit.timeit('f1(x, y).block_until_ready()', globals=globals(), number=100000))
print(timeit.timeit('f2(x, y).block_until_ready()', globals=globals(), number=100000))

commented

Since @jheek pointed out that the one-hot solution does not hurt performance, I am now thinking that optax.softmax_cross_entropy is the best approach.

commented

The jax.vmap(getitem) is a nice trick! However, IMO something like:

n_classes = logits.shape[-1]
loss = optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, n_classes)).mean()

is easier to understand for beginners and experts alike.
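
For completeness, a self-contained sketch of that suggestion, assuming the model returns unnormalized logits (i.e. without the final nn.log_softmax):

import jax
import jax.numpy as jnp
import optax

@jax.jit
def cross_entropy_loss(logits, labels):
    # logits: (batch, num_classes) raw scores from the model, labels: (batch,) integer class ids.
    n_classes = logits.shape[-1]
    return optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, n_classes)).mean()

(Recent optax versions also provide optax.softmax_cross_entropy_with_integer_labels, which skips the explicit one-hot step.)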

commented

I am sorry that my previous benchmark was wrong. The arrays I used were not large enough, so what I measured was mostly the host-device communication overhead rather than the actual computation.

Corrected test script:
import jax
from jax import nn
import jax.numpy as jnp
from operator import getitem

# jax.config.update('jax_platform_name', 'cpu')

x = [[-1.18020773e-02, 2.12021172e-04, 1.54857263e-02, -2.69062817e-02, 2.48291939e-02, -1.57771826e-01, 1.92937106e-02, 8.00435692e-02, -3.22007202e-02, -9.29394960e-02],
    [-7.89603218e-03, 1.77750289e-02, -4.46040742e-02, -6.05713278e-02, -2.72103678e-02, -8.59001800e-02, 9.83414203e-02, 7.55275413e-02, 2.83549689e-02, -1.53297465e-02],
    [-1.39507279e-02, -2.87675336e-02, -6.21193871e-02, -1.39774792e-02, 5.88746369e-02, -1.25838742e-01, 5.08899838e-02, 3.97215132e-03, 5.62930740e-02, 1.32357161e-02],
    [-2.65202336e-02, 4.41155769e-02, -3.11193094e-02, -2.00038608e-02, 2.85324380e-02, -6.48770258e-02, 8.85064006e-02, 1.30649462e-01, 7.65995309e-02, 1.04247741e-02],
    [-4.47686166e-02, 1.18546493e-01, 2.52062641e-03, 5.05243726e-02, 3.30978036e-02, -1.43625394e-01, 6.45868406e-02, 2.03764439e-01, 3.71213034e-02, -1.07339084e-01],
    [9.10869241e-03, 4.17448469e-02, -7.88132250e-02, -8.66309926e-03, 7.24829659e-02, -2.05243602e-01, 8.08372796e-02, 1.47233456e-01, 2.36027334e-02, -5.98792732e-03],
    [-1.08229872e-02, -1.97286233e-02, -5.83424866e-02, -1.43472329e-02, 6.09962158e-02, -1.09460413e-01, 8.39179903e-02, -1.33977272e-02, 5.15223742e-02, -2.19145101e-02],
    [-4.08720411e-02, 4.43317667e-02, 5.42337447e-03, -4.50615957e-03, -1.77606083e-02, -1.12251990e-01, 5.72303906e-02, 4.27244902e-02, 1.29846334e-02, -6.36750609e-02],
    [-2.33511683e-02, -4.19441313e-02, -2.53099240e-02, -3.40913311e-02, 2.91062817e-02, -1.17673308e-01, 4.52934727e-02, 6.69038519e-02, 1.93193313e-02, 9.32972878e-04],
    [-8.95771384e-03, -1.89850703e-02, -4.84827980e-02, 6.48929551e-02, 3.58220935e-02, -2.13164091e-01, 1.48798272e-01, 2.58073777e-01, 1.91853680e-02, -3.26515697e-02],
    [4.63725962e-02, 2.34173127e-02, -1.20111704e-01, 3.28311101e-02, 3.72643247e-02, -3.24501187e-01, 1.72371522e-01, 1.61464572e-01, 3.20537128e-02, -6.84871152e-02],
    [-5.78307211e-02, 9.58937407e-02, -4.97079864e-02, 2.08359621e-02, -1.53612345e-02, -1.99896753e-01, 1.58067212e-01, 1.99744061e-01, 2.05481090e-02, -8.61398652e-02],
    [-3.60106044e-02, -2.96902433e-02, -1.36097893e-01, -8.01160187e-03, -3.87010686e-02, -1.14177659e-01, 1.08693138e-01, 7.54408836e-02, 1.20038494e-01, -4.08867598e-02],
    [-4.42547724e-02, 2.65968330e-02, -1.16404049e-01, 3.99116948e-02, 3.70342433e-02, -2.52804637e-01, 1.20865166e-01, 1.91994041e-01, 1.31512567e-01, -5.34878895e-02],
    [-1.04445070e-01, -1.25995502e-02, -3.07911113e-02, -2.90571153e-03, -6.47646636e-02, -7.31705278e-02, 5.22068962e-02, 8.08887705e-02, -5.31015918e-03, -1.43367410e-01],
    [5.92986569e-02, 1.65142119e-05, -1.42436139e-02, -3.14359590e-02, 8.34358111e-02, -1.19490743e-01, 7.41921589e-02, 9.27548930e-02, -2.43497770e-02, -1.97705049e-02],
    [5.49561530e-03, -1.56130977e-02, -4.27389368e-02, 4.68986072e-02, 6.23839833e-02, -2.25198716e-01, 9.64858308e-02, 1.23882063e-01, -7.69507885e-03, 5.38528115e-02],
    [-2.27768160e-02, -5.26964292e-03, -2.93163955e-03, 6.27545714e-02, -1.27458237e-02, -2.02232912e-01, 1.19752392e-01, 1.84115887e-01, 4.13061678e-02, -1.24206990e-01],
    [-1.88333243e-02, 6.86625615e-02, -2.11398304e-03, 8.18483531e-03, 8.05698037e-02, -1.55029133e-01, 8.26403499e-02, 2.03767478e-01, -2.83125415e-03, -3.98058593e-02],
    [-1.26937091e-01, 7.40949661e-02, 4.41245735e-03, 1.11442685e-01, 1.80361643e-02, -2.52694398e-01, 1.81960687e-01, 2.68796027e-01, 2.17940658e-02, -1.80268660e-01],
    [-1.80708244e-04, -2.12510191e-02, -7.65096098e-02, 7.51659274e-03, -9.01457481e-03, -1.21640250e-01, 9.88654792e-02, 8.79775360e-02, 7.35239983e-02, -7.77667202e-03],
    [-4.61105630e-02, -4.42445502e-02, -6.50017634e-02, -8.33700225e-03, -3.89694385e-02, -9.54467058e-02, 7.01751038e-02, 9.09299105e-02, 4.63261157e-02, -7.00642243e-02],
    [-9.73308459e-02, 8.19174573e-03, -7.40732402e-02, 2.13738866e-02, 8.34189821e-03, -1.76222131e-01, 1.13145962e-01, 1.78131312e-01, 6.00471720e-02, -1.10414103e-01],
    [4.41795588e-02, 3.11944820e-02, -4.57068756e-02, 9.16974433e-03, 7.96715915e-02, -1.37639463e-01, 1.50509447e-01, 1.48718774e-01, 5.20974174e-02, 3.41124311e-02],
    [3.03848479e-02, -2.04634238e-02, -5.18144295e-03, -7.55520612e-02, 9.20162909e-03, -1.17662966e-01, 7.82268792e-02, 1.07894555e-01, 2.12373957e-03, 1.11250095e-02],
    [-2.81126983e-02, -8.13692510e-02, -1.24993384e-01, 1.56849958e-02, 8.87469202e-02, -1.63082436e-01, 1.06209114e-01, 1.34258643e-01, 6.51819557e-02, 5.11457995e-02],
    [-1.58369094e-02, 2.42988374e-02, -8.84073824e-02, -5.85790724e-03, 9.41534638e-02, -2.78038144e-01, 9.62805152e-02, 1.56973749e-01, 3.26635092e-02, 1.13205947e-02],
    [9.15303826e-03, 1.70634557e-02, -5.74565530e-02, -2.42397413e-02, 6.22252002e-02, -2.31550604e-01, 7.48545304e-02, 2.16238573e-01, -5.76255880e-02, -5.96193373e-02],
    [1.40953138e-02, 2.27493793e-03, -3.40058431e-02, -2.38920618e-02, 6.47072494e-02, -1.27346247e-01, 1.30044132e-01, 4.72338349e-02, 4.18094844e-02, 1.18074603e-02],
    [-8.56964849e-03, -2.67159306e-02, -5.96427359e-02, -2.79492103e-02, 4.75161448e-02, -1.16145663e-01, 9.01559144e-02, -2.28071418e-02, 5.35263345e-02, -1.84544232e-02],
    [2.33726948e-03, 1.16003156e-01, -4.96490225e-02, 3.59052308e-02, 2.15967186e-02, -2.02133447e-01, 9.41444337e-02, 1.96441934e-01, 1.34492517e-01, -1.13289990e-01],
    [-5.45443073e-02, -2.34426185e-03, -1.12764567e-01, -1.28814131e-02, -3.77303138e-02, -9.22404602e-02, 1.06957100e-01, 1.24134108e-01, 9.78331864e-02, -1.30443811e-01]]
y = [9, 5, 1, 5, 0, 6, 1, 4, 8, 6, 2, 0, 7, 8, 7, 0, 2, 4, 0, 4, 7, 7, 4, 3, 6, 1, 6, 6, 8, 1, 0, 7]

x *= 10000  # replicate the lists so the arrays are large enough for the computation to dominate
y *= 10000

x = jnp.array(x)
y = jnp.array(y)

@jax.jit
def f1(logits, labels):
    logits = nn.log_softmax(logits)
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

@jax.jit
def f2(logits, labels):
    logits = nn.log_softmax(logits)
    loss = jax.vmap(getitem)(logits, labels)
    loss = -loss.mean()
    return loss

a = f1(x, y)
b = f2(x, y)
assert jnp.allclose(a, b)

import timeit
print(timeit.timeit('f1(x, y).block_until_ready()', globals=globals(), number=10000))
print(timeit.timeit('f2(x, y).block_until_ready()', globals=globals(), number=10000))

Results (total seconds for 10,000 calls):

Platform   one-hot   vmap
CPU        78.54     46.45
GPU        7.02      6.29
TPU        3.96      124.31

@ayaka14732 thanks for the update. So it seems that on GPU/TPU the one-hot solution is preferred, and we can keep using the current implementation?

Is there anything actionable left in this issue, or do you think we can close it? Sorry if I missed anything!

commented

It seems that on GPU/TPU the one-hot solution is preferred

On GPU, the one-hot solution is slightly slower than vmap. However, on TPU, vmap is much slower than one-hot. So we will keep the one-hot solution.

Is there anything actionable left in this issue

There are still two things unresolved:

  1. Do not normalize the logits inside the model; instead, normalize them when calculating the loss.
  2. Remove the hand-written cross entropy loss function and use the one in optax directly.

Resolved in #2071.

@ayaka14732 thanks for the discussion, please feel free to continue it on the PR...