Add HOWTO explaining how to load MNIST from torchvision and HuggingFace datasets
marcvanzee opened this issue · comments
Marc van Zee commented
According to #325, for torch it is as easy as replacing
import jax.numpy as jnp
import tensorflow_datasets as tfds

def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds
With
import jax.numpy as jnp
import numpy as onp
import torchvision

def get_datasets():
    """Load MNIST train and test datasets into memory."""
    train_ds = torchvision.datasets.MNIST('./data', train=True, download=True)
    test_ds = torchvision.datasets.MNIST('./data', train=False, download=True)
    train_ds = {'image': onp.expand_dims(train_ds.data.numpy(), 3),
                'label': train_ds.targets.numpy()}
    test_ds = {'image': onp.expand_dims(test_ds.data.numpy(), 3),
               'label': test_ds.targets.numpy()}
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds
For HuggingFace it should not be much more difficult.
Also, we could link to this JAX tutorial, which is more general: https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html
Marc van Zee commented
Closing this in favor of #2116.