Habush / nam_jax

JAX-based implementation of Neural Additive Models

Neural Additive Models in JAX

This repo contains a JAX-based implementation of the model introduced in "Neural Additive Models: Interpretable Machine Learning with Neural Nets" by R. Agarwal et al. (2021).

NAM Architecture
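
As a rough illustration of the architecture, a Neural Additive Model learns one small network per input feature and sums their outputs with a bias, so each feature's contribution to a prediction can be inspected on its own. The sketch below is written in plain JAX and is not the code in this repo: the function names (`init_feature_net`, `feature_net`, `init_nam`, `nam_forward`), the single ReLU hidden layer, and the hidden size are all assumptions chosen for brevity.

```python
import jax
import jax.numpy as jnp


def init_feature_net(key, hidden=32):
    """Initialise a one-hidden-layer MLP mapping a scalar feature to a scalar."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (1, hidden)) * 0.5,
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, 1)) * 0.5,
        "b2": jnp.zeros(1),
    }


def feature_net(params, x):
    """Apply one feature network to a column of values with shape (batch,)."""
    h = jax.nn.relu(x[:, None] @ params["w1"] + params["b1"])
    return (h @ params["w2"] + params["b2"])[:, 0]


def init_nam(key, num_features, hidden=32):
    """Create one independent feature network per input feature, plus a bias."""
    keys = jax.random.split(key, num_features)
    return {
        "nets": [init_feature_net(k, hidden) for k in keys],
        "bias": jnp.zeros(()),
    }


def nam_forward(params, x):
    """Return (predictions, per-feature contributions) for x of shape (batch, num_features)."""
    contribs = jnp.stack(
        [feature_net(p, x[:, i]) for i, p in enumerate(params["nets"])], axis=1
    )
    # The prediction is the bias plus the sum of the per-feature outputs.
    return params["bias"] + contribs.sum(axis=1), contribs


params = init_nam(jax.random.PRNGKey(0), num_features=3)
x = jnp.ones((4, 3))
preds, contribs = nam_forward(params, x)
print(preds.shape, contribs.shape)
```

Because each network only ever sees its own feature, column `i` of `contribs` depends on `x[:, i]` alone; plotting it against that feature gives the per-feature shape function that makes the model interpretable.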

Dependencies

  • jax
  • optax
  • haiku # used for implementing NN model
  • torch # used for creating mini-batches
  • numpy
  • scikit-learn

Examples

Check out the nam_regression_example.ipynb notebook for an example of using the model on the California Housing dataset.

About

License: Apache License 2.0


Languages

Jupyter Notebook 95.6%, Python 4.4%