astanziola / siren-flax

SIREN neural networks in Flax

SIREN in Flax

Unofficial implementation of SIREN neural networks in Flax, using the Linen Module system.

This repository also implements the modulated SIREN described in "Modulated Periodic Activations for Generalizable Local Functional Representations" [2].

Examples

An image-fitting example is provided in the Example notebook.

[Figure: image-fitting results]

Defining a single SIREN layer

Creates a fully connected layer with a sinusoidal activation function, initialized according to the original SIREN paper [1].

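import jax.numpy as jnp

layer = SirenLayer(
    features=32,        # number of output features
    w0=1.0,             # frequency factor omega_0 applied inside the activation
    c=6.0,              # constant c in the SIREN weight-init bound (6 in the paper)
    is_first=False,     # True for the first layer, which uses a different init
    use_bias=True,
    act=jnp.sin,        # periodic activation function
    precision=None,
    dtype=jnp.float32,
)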
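For readers who want to see what such a layer computes, here is a minimal NumPy sketch of the initialization and forward pass described in the SIREN paper [1]. It is not this repository's code: the names siren_init and siren_layer are hypothetical, the placement of w0 as sin(w0 * (x @ W + b)) follows the official SIREN reference implementation, and the zero bias is a simplifying assumption. Hidden layers draw weights from U(-sqrt(c/fan_in)/w0, sqrt(c/fan_in)/w0) and the first layer from U(-1/fan_in, 1/fan_in); we assume the c and is_first arguments of SirenLayer play these roles.

```python
import numpy as np


def siren_init(rng, fan_in, fan_out, w0=1.0, c=6.0, is_first=False):
    """Sample a (fan_in, fan_out) weight matrix with the SIREN scheme."""
    if is_first:
        bound = 1.0 / fan_in                 # first layer: U(-1/n, 1/n)
    else:
        bound = np.sqrt(c / fan_in) / w0     # hidden layers: U(-sqrt(c/n)/w0, sqrt(c/n)/w0)
    return rng.uniform(-bound, bound, size=(fan_in, fan_out))


def siren_layer(x, weight, bias, w0=1.0, act=np.sin):
    """Dense layer over the last axis of x, followed by act(w0 * pre-activation)."""
    return act(w0 * (x @ weight + bias))


rng = np.random.default_rng(0)
coords = rng.uniform(-1.0, 1.0, size=(4, 2))          # four 2-D input coordinates
W = siren_init(rng, fan_in=2, fan_out=32, w0=30.0, is_first=True)
b = np.zeros(32)                                       # zero bias (simplifying assumption)
h = siren_layer(coords, W, b, w0=30.0)                 # shape (4, 32), values in [-1, 1]
```

Dividing the hidden-layer bound by w0 means that w0 * W is distributed as U(-sqrt(c/fan_in), sqrt(c/fan_in)) whatever w0 is, so a large w0 raises the frequency of the first layer without blowing up the pre-activations of deeper layers.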

How to use a SIREN neural network

import jax

random_key = jax.random.PRNGKey(0)
SirenNN = Siren(hidden_dim=512, output_dim=1, final_activation=jax.nn.sigmoid)
params = SirenNN.init(random_key, sample_input)["params"]
output = SirenNN.apply({"params": params}, sample_input)
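As a rough picture of what a network like Siren(hidden_dim=512, output_dim=1, final_activation=sigmoid) evaluates, the NumPy sketch below stacks sine layers and ends with a linear layer passed through final_activation. It is an illustration under assumptions, not the repository's implementation: init_siren_mlp and siren_mlp are hypothetical names, the output layer is assumed to be linear before final_activation is applied, and num_layers is taken to count the sine layers only.

```python
import numpy as np


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def init_siren_mlp(rng, in_dim, hidden_dim, output_dim, num_layers, w0=30.0, c=6.0):
    """Return (W, b) pairs: num_layers sine layers plus one linear output layer."""
    params = []
    fan_in = in_dim
    for i in range(num_layers):
        # SIREN init: U(-1/n, 1/n) for the first layer, U(-sqrt(c/n)/w0, ...) otherwise
        bound = 1.0 / fan_in if i == 0 else np.sqrt(c / fan_in) / w0
        params.append((rng.uniform(-bound, bound, (fan_in, hidden_dim)), np.zeros(hidden_dim)))
        fan_in = hidden_dim
    bound = np.sqrt(c / fan_in) / w0
    params.append((rng.uniform(-bound, bound, (fan_in, output_dim)), np.zeros(output_dim)))
    return params


def siren_mlp(params, x, w0=30.0, final_activation=sigmoid):
    """Sine hidden layers, then a linear layer followed by final_activation."""
    h = x
    for W, b in params[:-1]:
        h = np.sin(w0 * (h @ W + b))
    W, b = params[-1]
    return final_activation(h @ W + b)


rng = np.random.default_rng(0)
params = init_siren_mlp(rng, in_dim=2, hidden_dim=512, output_dim=1, num_layers=5)
sample_input = rng.uniform(-1.0, 1.0, size=(8, 2))   # 8 coordinates in [-1, 1]^2
output = siren_mlp(params, sample_input)              # shape (8, 1), values in (0, 1)
```

A sigmoid output is a natural choice when the target is an image with intensities normalized to [0, 1].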

Approximate image on a grid

This can easily be done using the built-in broadcasting of jax.numpy functions: the layers act on the last axis of the input, so a whole grid of coordinates can be evaluated in a single call. The repository provides a useful initializer, grid_init, that generates a coordinate grid to use as input.

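import jax
import jax.numpy as jnp

SirenDef = Siren(num_layers=5)

key = jax.random.PRNGKey(0)
grid = grid_init(grid_dimension, jnp.float32)()  # coordinate grid of size grid_dimension
params = SirenDef.init(key, grid)["params"]

image = SirenDef.apply({"params": params}, grid)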
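The broadcasting argument can be checked without the library. The NumPy sketch below uses make_grid, a hypothetical stand-in for grid_init that returns an array of shape (height, width, 2) with coordinates in [-1, 1]; this range is a common convention for SIREN inputs, but the exact layout produced by grid_init is an assumption here. Because matrix multiplication contracts only the last axis, a layer applied to the whole grid gives the same result as applying it to a flat list of coordinates and reshaping back.

```python
import numpy as np


def make_grid(height, width):
    """Hypothetical stand-in for grid_init: (height, width, 2) coordinates in [-1, 1]."""
    ys = np.linspace(-1.0, 1.0, height)
    xs = np.linspace(-1.0, 1.0, width)
    yy, xx = np.meshgrid(ys, xs, indexing="ij")   # each of shape (height, width)
    return np.stack([yy, xx], axis=-1).astype(np.float32)


grid = make_grid(64, 48)                          # shape (64, 48, 2)
rng = np.random.default_rng(0)
W = rng.uniform(-0.5, 0.5, size=(2, 16)).astype(np.float32)
b = np.zeros(16, dtype=np.float32)

# matmul contracts only the last axis, so the layer is applied to every pixel at once
features = np.sin(30.0 * (grid @ W + b))          # shape (64, 48, 16)

# equivalent: flatten to a list of coordinates, apply the layer, reshape back
flat = np.sin(30.0 * (grid.reshape(-1, 2) @ W + b)).reshape(64, 48, 16)
```

Keeping the grid shape intact means the network output already has the layout of the image, with no reshaping needed afterwards.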

Use Modulated SIREN

import jax
import jax.numpy as jnp

SirenDef = ModulatedSiren(num_layers=5)

key = jax.random.PRNGKey(0)
grid = grid_init(grid_dimension, jnp.float32)()
params = SirenDef.init(key, grid)["params"]

image = SirenDef.apply({"params": params}, grid)
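The modulated variant in [2] pairs the sine (synthesis) network with a ReLU modulator network that maps a latent code z to per-layer amplitudes. The NumPy sketch below shows that amplitude modulation as we read the paper: each sine layer's output is multiplied elementwise by the matching modulator activation, and every modulator layer after the first receives the previous modulator output concatenated with z. Function names, shapes and the uniform initialization are illustrative assumptions, and since the example above does not show how ModulatedSiren receives its latent code, this sketch need not match its interface.

```python
import numpy as np


def relu(v):
    return np.maximum(v, 0.0)


def modulated_siren(x, z, synth, mod, w0=30.0):
    """Amplitude-modulated SIREN forward pass (illustrative sketch).

    x: (..., in_dim) coordinates; z: (latent_dim,) latent code.
    synth: (W, b) pairs for the sine layers, plus one final linear layer.
    mod: (W, b) pairs for the ReLU modulator, one per sine layer.
    """
    h, alpha = x, None
    for i, ((W, b), (Wm, bm)) in enumerate(zip(synth[:-1], mod)):
        # Modulator input: z for the first layer, [alpha_{i-1}, z] afterwards.
        m_in = z if i == 0 else np.concatenate([alpha, z])
        alpha = relu(m_in @ Wm + bm)              # per-layer amplitudes, shape (hidden,)
        h = alpha * np.sin(w0 * (h @ W + b))      # broadcasts over the leading axes of x
    W, b = synth[-1]
    return h @ W + b                              # final linear layer


rng = np.random.default_rng(0)
in_dim, hidden, latent, out_dim, num_layers = 2, 32, 8, 1, 3


def init_uniform(shape, bound):
    # Illustrative initialization only, not the SIREN scheme.
    return rng.uniform(-bound, bound, size=shape)


synth = [(init_uniform((in_dim if i == 0 else hidden, hidden), 0.5), np.zeros(hidden))
         for i in range(num_layers)]
synth.append((init_uniform((hidden, out_dim), 0.1), np.zeros(out_dim)))
mod = [(init_uniform((latent if i == 0 else hidden + latent, hidden), 0.5), np.zeros(hidden))
       for i in range(num_layers)]

coords = rng.uniform(-1.0, 1.0, size=(10, in_dim))
z = rng.normal(size=latent)
out = modulated_siren(coords, z, synth, mod)      # shape (10, 1)
```

In this sketch a single latent code z modulates every coordinate in x in the same way; changing z changes the represented signal without retraining the synthesis weights.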

References

  1. Implicit Neural Representations with Periodic Activation Functions
  2. Modulated Periodic Activations for Generalizable Local Functional Representations
