flax counterpart for `torch.nn.Conv1d`
Liyang90 opened this issue
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): GCP Cloud TPU VM
- Flax 0.8.5, jax 0.4.31, jaxlib 0.4.31
- Python version: 3.10
- GPU/TPU model and memory: v5p
- CUDA version (if applicable):
Problem you have encountered:
I'm trying to figure out the Flax counterpart for `torch.nn.Conv1d`. But I find that the implementations below produce the same output, yet different gradients after the backward pass.
`conv_torch.py`:

from torch import nn


class BlockTorch(nn.Module):
    def __init__(self):
        super().__init__()
        # depthwise 1D conv: groups == in_channels == out_channels
        self.conv1d = nn.Conv1d(
            in_channels=5120,
            out_channels=5120,
            bias=True,
            kernel_size=4,
            groups=5120,
            padding=3,
        )

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        # (B, L, D) -> (B, D, L) for Conv1d; drop the 3 trailing positions
        # introduced by padding=3 so the conv stays causal
        x = self.conv1d(x.transpose(1, 2))[..., :seq_len]
        return x.transpose(1, 2)
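For reference, a quick shape check of the torch module (batch and sequence length are arbitrary; the feature dimension must stay 5120 to match the module): the `padding=3` conv produces 3 extra trailing positions, which the `[..., :seq_len]` slice removes, so the module is shape-preserving and causal.

import torch
from conv_torch import BlockTorch

blk = BlockTorch()
x = torch.randn(2, 16, 5120)   # (batch, seq_len, features)
print(blk(x).shape)            # torch.Size([2, 16, 5120])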
`conv_jax.py`:

from flax import linen as nn
import jax
import jax.numpy as jnp


class BlockJAX_0(nn.Module):
    kernel: jax.Array
    bias: jax.Array

    def setup(self):
        # initializers that simply return the externally supplied weights
        def kernel_init(key, shape, dtype):
            assert self.kernel.shape == shape
            return self.kernel.astype(dtype)

        def bias_init(key, shape, dtype):
            assert self.bias.shape == shape
            return self.bias.astype(dtype)

        self.conv1d = nn.Conv(
            features=5120,
            kernel_size=[4],
            feature_group_count=5120,
            padding='CAUSAL',
            use_bias=True,
            kernel_init=kernel_init,
            bias_init=bias_init,
        )

    def __call__(self, x):
        x = self.conv1d(x)
        return x


class BlockJAX_1(nn.Module):
    kernel: jax.Array
    bias: jax.Array

    def setup(self):
        def kernel_init(key, shape, dtype):
            assert self.kernel.shape == shape
            return self.kernel.astype(dtype)

        def bias_init(key, shape, dtype):
            assert self.bias.shape == shape
            return self.bias.astype(dtype)

        # same conv, but with explicit padding=3 on both sides; the extra
        # trailing positions are sliced off in __call__
        self.conv1d = nn.Conv(
            features=5120,
            kernel_size=[4],
            feature_group_count=5120,
            padding=3,
            use_bias=True,
            kernel_init=kernel_init,
            bias_init=bias_init,
        )

    def __call__(self, x):
        (b, l, d) = x.shape
        x = self.conv1d(x)[:, :l, :]
        return x
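Both JAX variants are intended to match the torch module's causal behavior: `'CAUSAL'` pads only on the left by `kernel_size - 1`, while `padding=3` pads both sides and needs the `[:, :l, :]` slice. A quick shape check (illustrative sizes, smaller than in the test):

import jax
import jax.numpy as jnp
from flax import linen as nn

x = jnp.ones((2, 16, 8))  # (batch, seq_len, features)
causal = nn.Conv(features=8, kernel_size=[4], feature_group_count=8, padding='CAUSAL')
padded = nn.Conv(features=8, kernel_size=[4], feature_group_count=8, padding=3)
print(causal.apply(causal.init(jax.random.key(0), x), x).shape)  # (2, 16, 8)
print(padded.apply(padded.init(jax.random.key(0), x), x).shape)  # (2, 19, 8), needs slicing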
My main test script `test.py`:

import numpy as np
from numpy.random import MT19937
from numpy.random import RandomState, SeedSequence

rs = RandomState(MT19937(SeedSequence(123456789)))

import jax
import jax.numpy as jnp
import torch

from conv_jax import BlockJAX_0, BlockJAX_1
from conv_torch import BlockTorch

# prepare common weights and inputs
kernel = rs.normal(size=(4, 1, 5120))    # flax layout: (kernel_size, in/groups, features)
bias = rs.normal(size=(5120,))
input = rs.normal(size=(4, 4096, 5120))  # (batch, seq_len, features)

# torch module forward and backward
torch.set_printoptions(precision=7)
conv_torch = BlockTorch()
state_dict = conv_torch.state_dict()
# torch expects (out_channels, in/groups, kernel_size), so swap the first and last axes
state_dict["conv1d.weight"] = torch.from_numpy(kernel).to(torch.float32).transpose(0, 2)
state_dict["conv1d.bias"] = torch.from_numpy(bias).to(torch.float32)
conv_torch.load_state_dict(state_dict)
conv_torch.zero_grad()
output_torch = conv_torch(torch.from_numpy(input).to(torch.float32))
loss_torch = output_torch.mean()
loss_torch.backward()

# flax module forward and backward
def jax_forward_backward(model, params, input):
    def forward(params, input):
        output = model.apply(params, input)
        loss = jnp.mean(output)
        return loss, output

    forward_backward_fn = jax.value_and_grad(forward, has_aux=True)
    (loss, output), grad = forward_backward_fn(params, input)
    return loss, output, grad

input_jax = jnp.array(input)
kernel_jax = jnp.array(kernel)
bias_jax = jnp.array(bias)

conv_jax_0 = BlockJAX_0(kernel_jax, bias_jax)
rng = jax.random.key(0)
params_jax_0 = conv_jax_0.init(rng, input_jax)
loss_jax_0, output_jax_0, grad_jax_0 = jax_forward_backward(conv_jax_0, params_jax_0, input_jax)

conv_jax_1 = BlockJAX_1(kernel_jax, bias_jax)
params_jax_1 = conv_jax_1.init(rng, input_jax)
loss_jax_1, output_jax_1, grad_jax_1 = jax_forward_backward(conv_jax_1, params_jax_1, input_jax)

print("================================================")
print(f"conv_torch.conv1d.weight.grad.T shape: {conv_torch.conv1d.weight.grad.T.shape}")
print(conv_torch.conv1d.weight.grad.T)
print("================================================")
print(f'grad_jax_0["params"]["conv1d"]["kernel"] {grad_jax_0["params"]["conv1d"]["kernel"].shape}')
print(grad_jax_0["params"]["conv1d"]["kernel"])
print("================================================")

def wmape(a, b):
    # weighted mean absolute percentage error, with a as the reference
    return np.sum(np.abs(a - b)) / np.sum(np.abs(a))

print(f"losses: {(loss_torch, loss_jax_0, loss_jax_1)}")
print("Outputs WMAPE:")
print(wmape(output_torch.detach().numpy(), np.array(output_jax_0)))
print(wmape(output_torch.detach().numpy(), np.array(output_jax_1)))
print("Grads WMAPE:")
print(wmape(conv_torch.conv1d.weight.grad.T.detach().numpy(), np.array(grad_jax_0["params"]["conv1d"]["kernel"])))
print(wmape(conv_torch.conv1d.weight.grad.T.detach().numpy(), np.array(grad_jax_1["params"]["conv1d"]["kernel"])))
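For completeness, the layout correspondence the comparison relies on (my assumption, based on the two libraries' documented weight shapes): `torch.nn.Conv1d` stores its weight as `(out_channels, in_channels // groups, kernel_size)` = `(5120, 1, 4)`, while `flax.linen.Conv` stores its kernel as `(kernel_size, in_features // groups, features)` = `(4, 1, 5120)`, so reversing the axes (the `.transpose(0, 2)` when loading and the `.T` when comparing gradients) should map one layout onto the other. A small check that can be appended to `test.py`:

# sanity check: after loading, the two parameter sets should be numerically identical
w_torch = conv_torch.conv1d.weight.detach().numpy()             # (5120, 1, 4)
w_flax = np.array(params_jax_0["params"]["conv1d"]["kernel"])   # (4, 1, 5120)
print(wmape(w_torch.transpose(2, 1, 0), w_flax))                # expected ~0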
All the modules in the test script load the same conv weight and bias. The output:
conv_torch.conv1d.weight.grad.T shape: torch.Size([4, 1, 5120])
tensor([[[-1.0676654e-06, -6.8133699e-07, 6.4320474e-08, ...,
-6.2070239e-07, -1.3589821e-06, -4.0437203e-06]],
[[-1.0596159e-06, -6.6314635e-07, 6.1584302e-08, ...,
-6.3736798e-07, -1.4141298e-06, -4.0575105e-06]],
[[-1.0563341e-06, -6.9564146e-07, 1.6096948e-08, ...,
-6.5698919e-07, -1.4593004e-06, -4.0748873e-06]],
[[-1.0387610e-06, -6.8761653e-07, 7.2919590e-09, ...,
-6.7382234e-07, -1.4510101e-06, -4.0458808e-06]]])
================================================
grad_jax_0["params"]["conv1d"]["kernel"] (4, 1, 5120)
[[[-1.0676664e-06 -6.8133664e-07 6.4321057e-08 ... 5.7032690e-07
3.1477473e-06 1.3152874e-06]]
[[-1.0596164e-06 -6.6314561e-07 6.1584501e-08 ... 5.5595456e-07
3.1550721e-06 1.3070267e-06]]
[[-1.0563347e-06 -6.9564129e-07 1.6097601e-08 ... 6.2066499e-07
3.1343752e-06 1.2927190e-06]]
[[-1.0387618e-06 -6.8761614e-07 7.2924422e-09 ... 5.8497437e-07
3.1675968e-06 1.2691942e-06]]]
================================================
losses: (tensor(0.0189537, grad_fn=<MeanBackward0>), Array(0.01895385, dtype=float32), Array(0.01895385, dtype=float32))
Outputs WMAPE:
0.0013432346
0.0013432346
Grads WMAPE:
0.7058088
0.7058088
It shows a small error between the conv layer outputs, but a large difference between the gradients.
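To narrow down where the disagreement comes from, a per-channel breakdown could be appended to the script (a sketch, reusing the names from `test.py`); since the conv is depthwise, each channel's kernel gradient is independent, so this should show whether the mismatch is uniform or concentrated in particular channels:

# per-channel WMAPE of the kernel gradient, with the torch gradient as the reference
g_torch = conv_torch.conv1d.weight.grad.detach().numpy().transpose(2, 1, 0)  # -> (4, 1, 5120)
g_jax = np.array(grad_jax_0["params"]["conv1d"]["kernel"])                   # (4, 1, 5120)
per_channel = np.sum(np.abs(g_torch - g_jax), axis=(0, 1)) / np.sum(np.abs(g_torch), axis=(0, 1))
print(per_channel[:5])    # leading channels (the leading printed values above agree)
print(per_channel[-5:])   # trailing channels (the trailing printed values above differ)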
What you expected to happen:
The WMAPE of both the outputs and the gradients should be small.
Steps to reproduce:
Run `test.py` above, with `conv_torch.py` and `conv_jax.py` in the same directory.