stergiosba / kanx

Fast Kolmogorov-Arnold Network in JAX, initial experiments

KANX: Fast Implementation (Approximation) of Kolmogorov-Arnold Network in JAX

Work in progress

Introduction

KANX is a fast Kolmogorov-Arnold Network (KAN) implementation in JAX, built with Equinox and based on fast-kan.

The original implementation of KAN is pykan.
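To make the fast-kan idea concrete: as we understand it, fast-kan replaces the B-spline basis of the original KAN with Gaussian radial basis functions (RBFs) on a fixed grid. Each layer then reduces to an RBF feature expansion followed by a linear map, plus a SiLU "base" branch. The NumPy sketch below illustrates that forward pass under stated assumptions. It is not KANX's code. The grid range [-2, 2], the 8 grid points, the parameter-free layer norm and every function name (`rbf_basis`, `fast_kan_layer`, ...) are hypothetical choices made for illustration. Replacing `numpy` with `jax.numpy` should give an equivalent JAX version.

```python
import numpy as np

def rbf_basis(x, grid_min=-2.0, grid_max=2.0, num_grids=8):
    """Gaussian RBFs on a uniform grid (assumed fast-kan-style defaults)."""
    grid = np.linspace(grid_min, grid_max, num_grids)   # (G,)
    h = (grid_max - grid_min) / (num_grids - 1)          # grid spacing, used as RBF width
    # Broadcast (batch, in_dim, 1) against (G,) -> (batch, in_dim, G)
    return np.exp(-(((x[..., None] - grid) / h) ** 2))

def silu(x):
    """SiLU activation used by the base branch."""
    return x / (1.0 + np.exp(-x))

def layer_norm(x, eps=1e-5):
    """Layer norm over features, without learnable scale/shift (simplification)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def fast_kan_layer(x, spline_w, base_w, base_b, num_grids=8):
    """Hypothetical layer: RBF-expanded 'spline' term plus a SiLU base term."""
    batch, in_dim = x.shape
    phi = rbf_basis(layer_norm(x), num_grids=num_grids)          # (batch, in_dim, G)
    spline = phi.reshape(batch, in_dim * num_grids) @ spline_w    # (batch, out_dim)
    base = silu(x) @ base_w + base_b                               # (batch, out_dim)
    return spline + base

# Usage: a random 4 -> 3 layer applied to a batch of 5 inputs.
rng = np.random.default_rng(0)
in_dim, out_dim, G = 4, 3, 8
x = rng.normal(size=(5, in_dim))
spline_w = rng.normal(scale=0.1, size=(in_dim * G, out_dim))
base_w = rng.normal(scale=0.1, size=(in_dim, out_dim))
base_b = np.zeros(out_dim)
y = fast_kan_layer(x, spline_w, base_w, base_b, num_grids=G)
print(y.shape)  # (5, 3)
```

Because every input is expanded into a fixed number of Gaussian bumps, the whole layer is a dense matrix product, which is the kind of operation that JAX compiles and runs efficiently on a GPU.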

Installation

pip install .
pip install -r requirements.txt

Example

KANX comes with an example on MNIST:

python examples/train_mnist.py

Benchmark

We benchmarked the implementation on MNIST and report the following wall-clock times for 3000 epochs:

Hardware              Wall time (s)
CPU (i5-1135G7)       130.51
CPU (i9-12900K)        67.85
GPU (RTX 3070 Ti)      13.55
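The reported times translate into relative speedups as follows. This short script only re-derives ratios from the numbers in the benchmark table and adds no new measurements.

```python
# Wall-clock times (seconds) for 3000 epochs on MNIST, copied from the benchmark table.
wall_time_s = {
    "CPU (i5-1135G7)": 130.51,
    "CPU (i9-12900K)": 67.85,
    "GPU (RTX 3070 Ti)": 13.55,
}

gpu_time = wall_time_s["GPU (RTX 3070 Ti)"]

# Speedup of the GPU run relative to each device (1.0 for the GPU itself).
speedup = {device: t / gpu_time for device, t in wall_time_s.items()}

for device, t in wall_time_s.items():
    print(f"{device:<18} {t:>7.2f} s   GPU speedup x{speedup[device]:.2f}")
```

By this arithmetic the GPU run is roughly 9.6x faster than the i5 laptop CPU and about 5.0x faster than the i9 desktop CPU.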

Plots from the GPU experiment:

[Figures: two plots, both titled "mlp_kan_compare"]

More experiments to come...
