flax.linen.Conv needs better error checking of 'padding' argument.
untom opened this issue
Thomas Unterthiner commented
Hi!
The following code leads to a mysterious error message, `RuntimeError: UNKNOWN: -:4:130: error: expected '['`:
```python
import jax
import numpy as np
import flax.linen as nn

x = np.random.normal(size=(7, 48, 48, 96)).astype(np.float32)
model_def = nn.Conv(
    features=96, kernel_size=(7, 7),
    strides=(4, 4),
    padding=(3, 3))  # the mistake: should be ((3, 3), (3, 3))
variables = model_def.init({'params': jax.random.PRNGKey(42)}, x)
out = model_def.apply({'params': variables['params']}, x)
```
The mistake here is that I was using `padding=(3, 3)` instead of `padding=((3, 3), (3, 3))`, but the error message gives no hint that the `padding` argument is the problem. It would be great if that could be improved. Ideally, a simpler padding specification like `padding=(3, 3)` or even `padding=3` could be supported directly.
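A minimal sketch of how the requested check and shorthand could work. It uses a hypothetical helper, `canonicalize_padding` (the name, signature and `PaddingLike` alias are assumptions, not Flax API), that maps `3`, `(3, 3)` and `((3, 3), (3, 3))` to the same per-dimension `(low, high)` pairs and raises an error naming `padding` when the spec does not fit the number of spatial dimensions. It assumes the caller passes that number (2 for the example above) and, as a simplification, accepts only `'SAME'` and `'VALID'` as strings.

```python
from typing import Sequence, Tuple, Union

# Hypothetical alias for the padding forms this sketch accepts.
PaddingLike = Union[str, int, Sequence[Union[int, Sequence[int]]]]


def canonicalize_padding(
    padding: PaddingLike, num_spatial_dims: int
) -> Union[str, Tuple[Tuple[int, int], ...]]:
    """Hypothetical helper (not part of Flax): validate `padding` and return
    either a padding string or one (low, high) pair per spatial dimension."""
    if isinstance(padding, str):
        if padding not in ("SAME", "VALID"):
            raise ValueError(
                f"padding string must be 'SAME' or 'VALID', got {padding!r}")
        return padding
    if isinstance(padding, int):
        # A single int pads every spatial dimension by the same amount on both sides.
        return tuple((padding, padding) for _ in range(num_spatial_dims))
    if not isinstance(padding, (tuple, list)):
        raise TypeError(
            "padding must be a str, an int, or a sequence, "
            f"got {type(padding).__name__}")
    if len(padding) != num_spatial_dims:
        raise ValueError(
            f"padding has {len(padding)} entries but the input has "
            f"{num_spatial_dims} spatial dimensions: {padding!r}")
    pairs = []
    for dim, entry in enumerate(padding):
        if isinstance(entry, int):
            # An int inside the sequence pads that dimension symmetrically.
            pairs.append((entry, entry))
        elif (isinstance(entry, (tuple, list)) and len(entry) == 2
              and all(isinstance(p, int) for p in entry)):
            pairs.append((entry[0], entry[1]))
        else:
            raise ValueError(
                f"padding entry for spatial dimension {dim} must be an int or "
                f"a (low, high) pair of ints, got {entry!r}")
    return tuple(pairs)


# All three spellings from this issue normalize to the same pairs.
print(canonicalize_padding((3, 3), 2))            # ((3, 3), (3, 3))
print(canonicalize_padding(3, 2))                 # ((3, 3), (3, 3))
print(canonicalize_padding(((3, 3), (3, 3)), 2))  # ((3, 3), (3, 3))
```

One caveat with this design: for a 1-D convolution, `padding=(3, 3)` is ambiguous (one int per dimension for two dimensions, or a single `(low, high)` pair for one). This sketch treats a flat sequence as one int per dimension, so in the 1-D case it raises a length error rather than guessing.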