liutianlin0121 / jax-deep-models

A collection of deep learning models written in JAX/Flax.


Deep learning models in JAX and Flax

Welcome to this repository of deep learning models written in JAX and Flax! JAX is a numerical computing library that runs on a range of hardware accelerators, including CPUs, GPUs, and TPUs. Flax is a neural-network library built on top of JAX that provides a flexible way to define and train machine learning models.
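To give a flavor of JAX itself (this sketch is illustrative and not taken from the notebooks): JAX programs are plain Python functions that can be transformed, e.g. differentiated with `jax.grad` and compiled with `jax.jit`.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean-squared error of a linear model x @ w against targets y.
    return jnp.mean((x @ w - y) ** 2)

# Compose transformations: compile the gradient function with XLA.
grad_fn = jax.jit(jax.grad(loss))

w = jnp.zeros(3)
x = jnp.ones((5, 3))
y = jnp.ones(5)
g = grad_fn(w, x, y)
print(g.shape)  # (3,)
```

Here every gradient component equals -2, since the residual is -1 for all five samples.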

This repository contains a collection of deep learning models, such as multilayer perceptrons, convolutional neural networks, and autoencoders. The training pipelines of these models are demonstrated on Google Colab.

This repository is inspired by Sebastian Raschka's Deep Learning Model Zoo, which is written in PyTorch and TensorFlow.

Multilayer Perceptron (MLP)

| Title | Dataset | Notebooks |
| --- | --- | --- |
| Basic MLP | MNIST | nbviewer · Open In Colab |

Convolutional neural networks (ConvNets)

| Title | Dataset | Notebooks |
| --- | --- | --- |
| Basic ConvNet | MNIST | nbviewer · Open In Colab |
| Basic ConvNet | CIFAR-10 | nbviewer · Open In Colab |
| Basic ConvNet with dropout | CIFAR-10 | nbviewer · Open In Colab |
| Basic ConvNet with batchnorm | CIFAR-10 | nbviewer · Open In Colab |
| ResNet | CIFAR-10 | nbviewer · Open In Colab |

Autoencoders

| Title | Dataset | Notebooks |
| --- | --- | --- |
| MLP autoencoder | MNIST | nbviewer · Open In Colab |
| Conv autoencoder | MNIST | nbviewer · Open In Colab |
| Variational MLP autoencoder | MNIST | nbviewer · Open In Colab |
| Variational Conv autoencoder | MNIST | nbviewer · Open In Colab |

Acknowledgement

This repository includes code that has been adapted from various sources, including the Flax examples, the UvA DL tutorials, and the JAXopt examples.

Disclaimer

All notebooks in this repository are written for didactic purposes and are not intended to serve as performance benchmarks.
