google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

How To Use max_pool In Flax

nikhilanayak opened this issue

Would you like to see a new example implemented by the Flax core team or the community? Please let us know by filling out the following template.

Description of the model to be implemented

I'm trying to use Flax's max_pool function in a CNN.

Dataset the model could be trained on

Dummy data created with jnp.zeros should work: the example isn't about the data itself but about its shape.

Reference implementations in other frameworks

When I use tf.keras.layers.MaxPooling3D(pool_size=(2,2,2)), it downsamples all three spatial axes of the data by a factor of 2. I want to achieve the same result in Flax.