patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

Conv2dTranspose possible bug??

haroldle opened this issue

Dear Equinox team,
I think I just found a bug in Conv2dTranspose.

  • Hypothesis:
  1. Given the same kernel weights and the same bias for Conv2d and Conv2dTranspose, the outputs of the two layers should differ, because they perform different operations.
  • Actual:
  1. In Equinox, both layers give the same output. (Against the hypothesis)
  2. In PyTorch, the two layers' outputs differ. (Supports the hypothesis)
  3. In Keras 3.0 with JAX as the backend, the two layers' outputs differ. (Supports the hypothesis)

I use the latest versions of Equinox, Keras, PyTorch, JAX, and jaxlib.
For NumPy I used version 1.26.4.
I attach a Google Colab link to support my claim.
https://colab.research.google.com/drive/18TIlVnUClwlZ6MzzVR0LmOuaPqMQ_rwH?usp=sharing

That's because you have to permute the weights, just like you did with torch/keras. See the FAQ section of https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose for how to do that.

Dear @lockwo and Equinox team,
Thank you for replying to my issue. I am still a newbie in Deep Learning.
Just FYI, the code in the FAQ section of https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose is wrong. There is no eqx.Conv or eqx.ConvTranspose in Equinox; it should be eqx.nn.Conv and eqx.nn.ConvTranspose.
I followed that section and it gave me a new error.
Apparently, swapping the axes changes the input and output dimensions of cnn_t. PyTorch and Keras with JAX do not behave like that.
cnn_t = eqx.tree_at(lambda x: x.weight, cnn_t, jnp.flip(cnn.weight, axis=tuple(range(2, cnn.weight.ndim))).swapaxes(0, 1))
This is the google colab link for this issue:
https://colab.research.google.com/drive/18TIlVnUClwlZ6MzzVR0LmOuaPqMQ_rwH?usp=sharing

That specific code is for the exact transpose operation. That means it maps a 3-channel input to a 2-channel output (the opposite of the forward convolution), so for a 2-channel input it errors. I'm not totally sure what operation you are trying to accomplish, but that direction of reshaping is what I thought was necessary (not exactly that set of reshapes per se). You can see more info on transpose testing here: #728.

I'm not totally sure how PyTorch works, so I'm not sure why/how it is computing the forward pass, but in general the transpose should reverse the channel direction like this, so idk why it doesn't error (keras/torch often try to help the user much more, so they might be fixing/implicitly doing something hidden).

I see.
To my understanding, in PyTorch, when I create the Convolution Transpose and Convolution layers, the kernel weight in Convolution Transpose has shape (In_dim, Out_dim, H, W) and in Convolution it has shape (Out_dim, In_dim, H, W). That was the reason I swapped the axes.
Apparently, in Equinox, the kernel weights in Convolution Transpose and Convolution both have shape (Out_dim, In_dim, H, W).
After reading #728, I see that I just need to flip the kernel weights of either Convolution Transpose or Convolution to get the same behavior as in PyTorch and Keras with JAX.
I was trying to port the Equinox Convolution Transpose weight to a PyTorch Convolution Transpose. After both flipping and swapping the axes of the Equinox kernel weight, the PyTorch Convolution Transpose gave the same output as the Equinox Convolution Transpose. Before that, I was only swapping the axes before transferring the weight to PyTorch, and PyTorch did not give the same output as Equinox.
Thank you @lockwo for taking the time to help me.
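
The flip-and-swap step described above can be checked in a small NumPy model, without either library. This is a sketch under stated assumptions: stride 1, no padding options, no bias, and two hypothetical helpers (`correlate2d`, `scatter_transpose2d`) written for this illustration. The first models a cross-correlation with an (Out_dim, In_dim, H, W) kernel, the second models the scatter-add definition of a transposed convolution with an (In_dim, Out_dim, H, W) kernel. Neither is the actual Equinox or PyTorch implementation.

```python
import numpy as np

def correlate2d(x, w):
    """Stride-1 'valid' cross-correlation. x: (C_in, H, W); w: (C_out, C_in, kh, kw)."""
    c_out, _, kh, kw = w.shape
    _, h, wd = x.shape
    out = np.zeros((c_out, h - kh + 1, wd - kw + 1))
    for i in range(out.shape[1]):
        for j in range(out.shape[2]):
            patch = x[:, i:i + kh, j:j + kw]               # (C_in, kh, kw)
            out[:, i, j] = np.tensordot(w, patch, axes=3)  # contract C_in, kh, kw
    return out

def scatter_transpose2d(x, w_t):
    """Stride-1 transposed convolution. x: (C_in, H, W); w_t: (C_in, C_out, kh, kw)."""
    _, c_out, kh, kw = w_t.shape
    _, h, wd = x.shape
    out = np.zeros((c_out, h + kh - 1, wd + kw - 1))
    for i in range(h):
        for j in range(wd):
            # Each input pixel stamps a weighted copy of the kernel into the output.
            out[:, i:i + kh, j:j + kw] += np.tensordot(x[:, i, j], w_t, axes=1)
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 5))
w_t = rng.normal(size=(2, 3, 3, 3))                 # (In_dim, Out_dim, H, W)

# Flip the spatial axes and swap the channel axes: (Out_dim, In_dim, H, W).
w_corr = np.flip(w_t, axis=(2, 3)).swapaxes(0, 1)

a = scatter_transpose2d(x, w_t)
# Zero-pad by kernel_size - 1 on each side, then cross-correlate.
b = correlate2d(np.pad(x, ((0, 0), (2, 2), (2, 2))), w_corr)
only_swapped = correlate2d(np.pad(x, ((0, 0), (2, 2), (2, 2))), w_t.swapaxes(0, 1))

print(np.allclose(a, b), np.allclose(a, only_swapped))
```

In this model, swapping the axes alone is not enough for a random kernel; the spatial flip is also needed, which matches the porting experience described in the thread.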