How To Use max_pool In Flax
nikhilanayak opened this issue
NikhilNayak commented
Description of the model to be implemented
I'm trying to use Flax's max_pool function with a CNN.
Dataset the model could be trained on
Using jnp.zeros should work; the example isn't about the data itself, but about the shape of the data.
Reference implementations in other frameworks
When I do tf.keras.layers.MaxPooling3D(pool_size=(2,2,2)), it scales down all three spatial axes of the data by a factor of 2. I want to achieve a similar result using Flax.