google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Add HOWTO explaining how to load MNIST from torchvision and HuggingFace datasets

marcvanzee opened this issue · comments

According to #325, using torch is as easy as replacing

```python
import jax.numpy as jnp
import tensorflow_datasets as tfds

def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds
```

with

```python
import jax.numpy as jnp
import numpy as onp
import torchvision

def get_datasets():
    """Load MNIST train and test datasets into memory."""
    train_ds = torchvision.datasets.MNIST('./data', train=True, download=True)
    test_ds = torchvision.datasets.MNIST('./data', train=False, download=True)
    # Expand the (N, 28, 28) tensors to (N, 28, 28, 1) to match the TFDS layout.
    train_ds = {'image': onp.expand_dims(train_ds.data.numpy(), 3),
                'label': train_ds.targets.numpy()}
    test_ds = {'image': onp.expand_dims(test_ds.data.numpy(), 3),
               'label': test_ds.targets.numpy()}
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds
```

For HuggingFace datasets it should not be much more difficult.

Also, we could link to this JAX tutorial, which is more general: https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html

Closing this in favor of #2116.