google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home page: https://flax.readthedocs.io

flax.linen.Conv needs better error checking of 'padding' argument.

untom opened this issue

Hi!

The following code fails with the cryptic error message "RuntimeError: UNKNOWN: -:4:130: error: expected '['":

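import jax
import numpy as np
import flax.linen as nn

x = np.random.normal(size=(7, 48, 48, 96)).astype(np.float32)
model_def = nn.Conv(
    features=96, kernel_size=(7, 7),
    strides=(4, 4),
    padding=(3, 3))  # triggers the error; the explicit form is ((3, 3), (3, 3))
variables = model_def.init({'params': jax.random.PRNGKey(42)}, x)
conv_params = variables['params']  # works whether init returns a dict or a FrozenDict

out = model_def.apply({"params": conv_params}, x)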
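
For reference, here is what the intended call should produce. With explicit (low, high) pairs, padding=((3, 3), (3, 3)) pads each spatial dimension by 3 on both sides, and each output length follows the standard convolution formula (n + low + high - kernel) // stride + 1, assuming no dilation. The pure-Python sketch below uses a hypothetical helper, conv_output_size, rather than any Flax API:

```python
def conv_output_size(n, kernel, stride, low, high):
    """Output length of one spatial dimension for a convolution with
    explicit (low, high) padding and no dilation (hypothetical helper)."""
    padded = n + low + high
    if padded < kernel:
        raise ValueError(f"padded size {padded} is smaller than kernel {kernel}")
    return (padded - kernel) // stride + 1


# Shapes from the example above: NHWC input (7, 48, 48, 96), 7x7 kernel,
# stride 4, and the intended padding ((3, 3), (3, 3)).
batch, height, width, _ = (7, 48, 48, 96)
padding = ((3, 3), (3, 3))
out_h = conv_output_size(height, 7, 4, *padding[0])
out_w = conv_output_size(width, 7, 4, *padding[1])
out_shape = (batch, out_h, out_w, 96)  # features=96
print(out_shape)  # (7, 12, 12, 96)
```

For the 48x48 input this gives (48 + 6 - 7) // 4 + 1 = 12 per spatial dimension, so the output should have shape (7, 12, 12, 96).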

The mistake is that I passed padding=(3, 3) instead of padding=((3, 3), (3, 3)), but the error message gives no hint of that. It would be great if the padding argument were validated and a more informative error raised. Ideally, simpler padding specifications such as padding=(3, 3) or even padding=3 could be supported directly.
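
One way the requested behavior could look is a small normalization step that runs before the padding is handed to the underlying convolution op: it turns every accepted form into explicit (low, high) pairs and raises a readable error otherwise. This is a minimal sketch, not Flax's implementation; normalize_padding and its rank argument are hypothetical names:

```python
from collections.abc import Sequence


def normalize_padding(padding, rank):
    """Return `rank` explicit (low, high) pairs for a padding spec.

    Hypothetical helper. Accepts a string mode such as 'SAME' or 'VALID'
    (returned unchanged), a single int (same padding on every side), or a
    sequence with exactly one entry per spatial dimension, where each entry
    is an int or a (low, high) pair.
    """
    if isinstance(padding, str):
        return padding  # string modes are left to the convolution op
    if isinstance(padding, int):
        return [(padding, padding)] * rank
    if not isinstance(padding, Sequence):
        raise TypeError(
            f"padding must be a str, int or sequence, got {type(padding).__name__}")
    if len(padding) != rank:
        raise ValueError(
            f"padding={padding!r} has {len(padding)} entries but the convolution "
            f"has {rank} spatial dimensions; pass an int or {rank} entries, "
            f"each an int or a (low, high) pair")
    pairs = []
    for entry in padding:
        if isinstance(entry, int):
            pairs.append((entry, entry))  # symmetric padding for this dimension
        elif (isinstance(entry, Sequence) and not isinstance(entry, str)
              and len(entry) == 2 and all(isinstance(v, int) for v in entry)):
            pairs.append((entry[0], entry[1]))  # explicit (low, high)
        else:
            raise ValueError(
                f"invalid padding entry {entry!r}; expected an int or a (low, high) pair")
    return pairs


print(normalize_padding(3, rank=2))                 # [(3, 3), (3, 3)]
print(normalize_padding((3, 3), rank=2))            # [(3, 3), (3, 3)]
print(normalize_padding(((3, 3), (3, 3)), rank=2))  # [(3, 3), (3, 3)]
```

Requiring exactly one entry per spatial dimension keeps padding=(3, 3) unambiguous for a 2-D convolution; the trade-off is that a 1-D convolution with asymmetric padding has to be written as ((3, 4),) rather than (3, 4).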