Intro to JAX for the Bamb!2022 summer school.
This repo contains Python notebooks that train simple models via (stochastic) gradient descent. The example notebooks use the Python library JAX and its neural network library stax.
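As a rough idea of what "training via stochastic gradient descent in JAX" looks like, here is a minimal sketch of linear regression on synthetic data. It is not taken from the notebooks. The data (y = 2x - 1 plus noise), the learning rate, the batch size and the helper names (`predict`, `mse_loss`, `sgd_step`) are all hypothetical choices for illustration. It uses only core JAX: `jax.grad` to differentiate the loss, `jax.jit` to compile the update step, and `jax.random` to draw minibatches.

```python
import jax
import jax.numpy as jnp

# Hypothetical hyperparameters chosen for this toy problem.
LEARNING_RATE = 0.1
BATCH_SIZE = 20
NUM_STEPS = 500


def predict(params, x):
    """Linear model: y_hat = x @ w + b."""
    w, b = params
    return x @ w + b


def mse_loss(params, x, y):
    """Mean squared error between predictions and targets."""
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def sgd_step(params, x_batch, y_batch):
    """One SGD step: move each parameter against its gradient."""
    grads = jax.grad(mse_loss)(params, x_batch, y_batch)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


# Synthetic data from y = 2x - 1 plus a little Gaussian noise.
key = jax.random.PRNGKey(0)
key, key_x, key_noise = jax.random.split(key, 3)
x = jax.random.normal(key_x, (100, 1))
y = 2.0 * x[:, 0] - 1.0 + 0.1 * jax.random.normal(key_noise, (100,))

params = (jnp.zeros(1), jnp.array(0.0))  # (w, b)
for _ in range(NUM_STEPS):
    # Sampling a random minibatch is what makes the descent "stochastic".
    key, subkey = jax.random.split(key)
    idx = jax.random.choice(subkey, x.shape[0], shape=(BATCH_SIZE,), replace=False)
    params = sgd_step(params, x[idx], y[idx])

w, b = params
print(f"w = {float(w[0]):.2f}, b = {float(b):.2f}, loss = {float(mse_loss(params, x, y)):.4f}")
```

The learned `w` and `b` should end up close to 2 and -1. Replacing the minibatch `x[idx], y[idx]` with the full `x, y` turns this into plain (full-batch) gradient descent.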
Models include:
- linear regression.
- logistic regression.
- image classification with neural networks.
- image classification with convolutional neural networks.
Other material:
- NumPy's introduction for MATLAB users.
- A NumPy cheat sheet for MATLAB users.
- Lecture notes from a course I once taught on Python for scientific computing.
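For the classification models listed above, a stax network is built by chaining layers with `stax.serial`, which returns an `(init_fn, apply_fn)` pair. The sketch below is a hypothetical stand-in, not the notebooks' code. It trains a small multilayer perceptron on three 2D Gaussian blobs instead of images, and the layer sizes, learning rate and step count are illustrative assumptions. It assumes a recent JAX release, where stax lives in `jax.example_libraries`; older releases exposed it as `jax.experimental.stax`. A convolutional variant would swap in stax layers such as `Conv` and `Flatten`.

```python
import jax
import jax.numpy as jnp
# Recent JAX releases ship stax here; older ones used jax.experimental.stax.
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, LogSoftmax, Relu

# Hypothetical hyperparameters for this toy problem.
LEARNING_RATE = 0.1
NUM_STEPS = 500
NUM_CLASSES = 3

# A small multilayer perceptron: 2 inputs -> 32 hidden units -> 3 log-probabilities.
init_fn, apply_fn = stax.serial(Dense(32), Relu, Dense(NUM_CLASSES), LogSoftmax)


def cross_entropy(params, x, labels):
    """Mean negative log-likelihood of the true labels."""
    log_probs = apply_fn(params, x)
    return -jnp.mean(log_probs[jnp.arange(labels.shape[0]), labels])


@jax.jit
def sgd_step(params, x, labels):
    """One full-batch gradient step over the (nested) stax parameter tree."""
    grads = jax.grad(cross_entropy)(params, x, labels)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


# Toy data: three Gaussian blobs in 2D, 50 points each.
key = jax.random.PRNGKey(0)
key_init, key_data = jax.random.split(key)
centers = jnp.array([[0.0, 3.0], [-3.0, -2.0], [3.0, -2.0]])
labels = jnp.repeat(jnp.arange(NUM_CLASSES), 50)
x = centers[labels] + jax.random.normal(key_data, (labels.shape[0], 2))

# stax init functions take an RNG key and an input shape (-1 = any batch size).
_, params = init_fn(key_init, (-1, 2))
for _ in range(NUM_STEPS):
    params = sgd_step(params, x, labels)

accuracy = jnp.mean(jnp.argmax(apply_fn(params, x), axis=1) == labels)
print(f"training accuracy: {float(accuracy):.2f}")
```

Because the parameters are an ordinary pytree (parameterless layers such as `Relu` contribute empty tuples), the same `tree_map` update used for linear regression works unchanged here.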
You need Python >= 3.7.
You can install all the dependencies using:
pip install -r requirements.txt