google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Adding a new transposed convolution function (similar to `torch.nn.ConvTranspose2d()`)

andsteing opened this issue

Adding a transposed convolution as proposed in jax-ml/jax#5772 would also be very useful when porting models from PyTorch to Flax (as in #1848).

Hey, is this an active issue that is being worked on?

We actually have an implementation already: https://github.com/google/flax/blob/main/flax/linen/linear.py#L447

I don't really understand why we have this issue.

@andsteing can you please clarify?

According to the docs:

torch.nn.ConvTranspose2d and nn.ConvTranspose are not compatible. nn.ConvTranspose is a wrapper around jax.lax.conv_transpose which computes a fractionally strided convolution, while torch.nn.ConvTranspose2d computes a gradient based transposed convolution.
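The difference described in the docs can be illustrated on the JAX side alone. The following is a minimal sketch, not a proposed implementation: it assumes the `transpose_kernel` flag of `jax.lax.conv_transpose` and uses the VJP of a strided forward convolution as the "gradient based" reference. The shapes, stride, and variable names are made up for illustration, and no comparison against PyTorch is run here.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative layout and stride choices (assumptions, not taken from the issue).
DIMS = ("NHWC", "HWIO", "NHWC")
STRIDES = (2, 2)


def forward_conv(x, kernel):
    """Plain strided convolution whose gradient defines the transposed conv."""
    return lax.conv_general_dilated(
        x, kernel, window_strides=STRIDES, padding="VALID",
        dimension_numbers=DIMS)


x = jnp.zeros((1, 7, 7, 1))                    # only the shape matters here
kernel = jnp.arange(9.0).reshape(3, 3, 1, 1)   # deliberately asymmetric kernel
g = jnp.arange(1.0, 10.0).reshape(1, 3, 3, 1)  # same shape as forward_conv(x, kernel)

# Gradient-based reference: pull g back through the forward convolution.
_, vjp_fn = jax.vjp(lambda inp: forward_conv(inp, kernel), x)
(grad_reference,) = vjp_fn(g)

# Fractionally strided convolution (the default behaviour of conv_transpose).
fractional = lax.conv_transpose(
    g, kernel, STRIDES, "VALID", dimension_numbers=DIMS)

# Same call with transpose_kernel=True, which flips the kernel's spatial axes
# and swaps its input/output channel axes before convolving.
gradient_style = lax.conv_transpose(
    g, kernel, STRIDES, "VALID", dimension_numbers=DIMS, transpose_kernel=True)

print("matches gradient with transpose_kernel=True:",
      bool(jnp.allclose(gradient_style, grad_reference)))
print("matches gradient with default settings:",
      bool(jnp.allclose(fractional, grad_reference)))
```

With a single channel the two variants differ only by the spatial flip of the kernel, so a symmetric kernel would hide the discrepancy; the asymmetric `kernel` above is chosen to make it visible.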

@marcvanzee this might be the reason, and also why I created this issue.

Ahh sorry, I didn't see that, thanks for noting! No, it isn't being worked on. Do you want to work on it?

Yes, I would be interested in working on it.

I just noticed this issue has the "blocked" label. @jheek could you please explain this? I suppose it is blocked on the JAX issue jax-ml/jax#5772?

@codeboy5 In that case, I guess we have to wait to work on this issue until that one is merged.

Oh okay. I looked at that issue too; there haven't been any new updates for a year.
Thanks

Hmm I see, maybe you can reply to that issue and ask whether they are planning to merge it soon? Otherwise you could ask them if you can pick up that issue if you are really interested!