google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page:https://flax.readthedocs.io

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

flax counterpart for `torch.nn.Conv1d`

Liyang90 opened this issue · comments

Provide as much information as possible. At least, this should include a description of your issue and steps to reproduce the problem. If possible also provide a summary of what steps or workarounds you have already tried.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): GCP Cloud TPU VM
  • Flax 0.8.5, jax 0.4.31, jaxlib 0.4.31
  • Python version: 3.10
  • GPU/TPU model and memory: v5p
  • CUDA version (if applicable):

Problem you have encountered:

I'm trying to find the flax counterpart of `torch.nn.Conv1d`. However, the implementations below produce nearly identical outputs but very different gradients after the backward pass.

conv_torch.py:

from torch import nn


class BlockTorch(nn.Module):
    """PyTorch reference: depthwise causal 1-D convolution on channels-last input.

    ``forward`` takes ``x`` of shape ``(batch, seq_len, 5120)``. ``nn.Conv1d``
    expects channels-first data, so the input is transposed, convolved, and
    transposed back. ``padding=3`` pads both ends of the time axis, so with
    ``kernel_size=4`` the conv produces ``seq_len + 3`` steps. Keeping only the
    first ``seq_len`` steps makes output step ``t`` depend only on input steps
    ``t-3 .. t`` (causal).
    """

    def __init__(self):
        super().__init__()
        # groups == in_channels == out_channels: each channel gets its own
        # 4-tap filter (depthwise), so the weight has shape (5120, 1, 4).
        self.conv1d = nn.Conv1d(
            in_channels=5120,
            out_channels=5120,
            bias=True,
            kernel_size=4,
            groups=5120,
            padding=3,
        )

    def forward(
        self,
        x,
    ):
        # Unpacking enforces a 3-D (batch, seq_len, channels) input; only
        # seq_len is needed.
        _, seq_len, _ = x.shape
        y = self.conv1d(x.transpose(1, 2))[..., :seq_len]
        return y.transpose(1, 2)

conv_jax.py:

from flax import linen as nn
import jax
import jax.numpy as jnp


class BlockJAX_0(nn.Module):
  """Depthwise 1-D conv that relies on flax's ``padding='CAUSAL'``.

  The conv parameters are not random: at ``init`` time the kernel and bias
  are filled with the arrays passed in as module attributes.

  Attributes:
    kernel: array copied into the conv kernel at init; its shape must equal
      the shape flax requests for the kernel (``(4, 1, 5120)`` in this repro).
    bias: array copied into the conv bias at init; its shape must equal the
      shape flax requests for the bias (``(5120,)`` in this repro).
  """
  kernel: jax.Array
  bias: jax.Array

  def setup(self):
    def _init_from(attr_name):
      # Build a flax initializer that ignores its PRNG key and returns the
      # module attribute `attr_name`, cast to the requested dtype.
      def _init(key, shape, dtype):
        value = getattr(self, attr_name)
        assert value.shape == shape
        return value.astype(dtype)
      return _init

    self.conv1d = nn.Conv(
        features=5120,
        kernel_size=[4],
        feature_group_count=5120,
        padding='CAUSAL',
        use_bias=True,
        kernel_init=_init_from('kernel'),
        bias_init=_init_from('bias'),
    )

  def __call__(self, x):
    return self.conv1d(x)
  

class BlockJAX_1(nn.Module):
  """Depthwise 1-D conv with symmetric integer padding, sliced to be causal.

  This follows the same construction as the PyTorch reference in this repro.
  ``padding=3`` pads both ends of the time axis, so with ``kernel_size=4`` the
  conv output has ``seq_len + 3`` steps. Keeping the first ``seq_len`` steps
  leaves output step ``t`` depending only on input steps ``t-3 .. t``.

  Attributes:
    kernel: array copied into the conv kernel at init; its shape must equal
      the shape flax requests for the kernel.
    bias: array copied into the conv bias at init; its shape must equal the
      shape flax requests for the bias.
  """
  kernel: jax.Array
  bias: jax.Array

  def setup(self):

    def kernel_init(key, shape, dtype):
      # `key` is unused: the kernel is fixed to the supplied array.
      assert self.kernel.shape == shape
      return self.kernel.astype(dtype)

    def bias_init(key, shape, dtype):
      # `key` is unused: the bias is fixed to the supplied array.
      assert self.bias.shape == shape
      return self.bias.astype(dtype)

    # NOTE(review): `precision` is left at its default. On TPU, JAX's default
    # precision for float32 convolutions may be lower than full float32. That
    # could account for part of the output gap against the float32 torch
    # reference. Consider precision=jax.lax.Precision.HIGHEST when comparing
    # numerics.
    self.conv1d = nn.Conv(features=5120,
                          kernel_size=[4],
                          feature_group_count=5120,
                          padding=3,
                          use_bias=True,
                          kernel_init=kernel_init,
                          bias_init=bias_init,
                          )

  def __call__(self, x):
    # Unpacking enforces a 3-D input; axis 1 is assumed to be time
    # (channels-last, as in test.py).
    _, seq_len, _ = x.shape
    return self.conv1d(x)[:, :seq_len, :]

My main test script, test.py:

import numpy as np
from numpy.random import MT19937
from numpy.random import RandomState, SeedSequence
# Seeded legacy RandomState: every run draws the same kernel, bias and input.
rs = RandomState(MT19937(SeedSequence(123456789)))

import jax
import jax.numpy as jnp

import torch

from conv_jax import BlockJAX_0, BlockJAX_1
from conv_torch import BlockTorch

# prepare common weights and inputs
# The values depend on the draw order (kernel, bias, input); do not reorder.
# `kernel` is in the flax Conv kernel layout (kernel_size, 1, features).
kernel = rs.normal(size=(4, 1, 5120))
bias = rs.normal(size=(5120,))
# Channels-last (batch, seq_len, channels). float64 from numpy: this is
# 4 * 4096 * 5120 = 83,886,080 values (~671 MB).
# NOTE(review): the name `input` shadows the Python builtin.
input = rs.normal(size=(4, 4096, 5120))

# torch module forward and backward
torch.set_printoptions(precision=7)
conv_torch = BlockTorch()
state_dict = conv_torch.state_dict()
# torch stores the Conv1d weight as (out_channels, in_channels // groups,
# kernel_size), so axes 0 and 2 of the flax-layout kernel are swapped:
# (4, 1, 5120) -> (5120, 1, 4).
state_dict["conv1d.weight"] = torch.from_numpy(kernel).to(torch.float32).transpose(0, 2)
state_dict["conv1d.bias"] = torch.from_numpy(bias).to(torch.float32)
conv_torch.load_state_dict(state_dict)
# Defensive: the module was just built, so there are no gradients yet.
conv_torch.zero_grad()
output_torch = conv_torch(torch.from_numpy(input).to(torch.float32))
# The scalar objective is the mean output, mirrored by jnp.mean on the JAX side.
loss_torch = output_torch.mean()
loss_torch.backward()
# flax module forward and backward
def jax_forward_backward(model, params, input):
  """Run one forward/backward pass of a flax module.

  The scalar objective is the mean of the module output.

  Args:
    model: object with an ``apply(params, x)`` method (a flax module here).
    params: parameter pytree passed to ``model.apply``.
    input: the module input.

  Returns:
    ``(loss, output, grad)``, where ``grad`` is the gradient of the loss with
    respect to ``params`` and has the same pytree structure.
  """
  def mean_loss(p, x):
    y = model.apply(p, x)
    # The output rides along as auxiliary data so it is returned with the grad.
    return jnp.mean(y), y

  grad_fn = jax.value_and_grad(mean_loss, has_aux=True)
  (loss, output), grad = grad_fn(params, input)
  return loss, output, grad

# Move the shared numpy data into JAX. NOTE(review): unless JAX's x64 mode is
# enabled, these float64 arrays become float32 here; confirm if x64 is set.
input_jax = jnp.array(input)
kernel_jax = jnp.array(kernel)
bias_jax = jnp.array(bias)

conv_jax_0 = BlockJAX_0(kernel_jax, bias_jax)
rng = jax.random.key(0)
# The custom kernel_init/bias_init ignore their key and return the supplied
# arrays, so `rng` does not affect the resulting parameters.
params_jax_0 = conv_jax_0.init(rng, input_jax)

loss_jax_0, output_jax_0, grad_jax_0 = jax_forward_backward(conv_jax_0, params_jax_0, input_jax)

# Same weights, but padding=3 plus slicing instead of padding='CAUSAL'.
conv_jax_1 = BlockJAX_1(kernel_jax, bias_jax)
params_jax_1 = conv_jax_1.init(rng, input_jax)

loss_jax_1, output_jax_1, grad_jax_1 = jax_forward_backward(conv_jax_1, params_jax_1, input_jax)

# Show the torch and flax kernel gradients side by side, in the same layout.
# torch's weight grad is (out_channels, 1, kernel_size). permute(2, 1, 0)
# reverses its axes into flax's (kernel_size, 1, features) layout. This gives
# the same values as `.T`, which PyTorch deprecates for non-2-D tensors.
# The printed label still reads ".T" so it matches the recorded output.
torch_kernel_grad_t = conv_torch.conv1d.weight.grad.permute(2, 1, 0)
print("================================================")
print(f"conv_torch.conv1d.weight.grad.T shape: {torch_kernel_grad_t.shape}")
print(torch_kernel_grad_t)
print("================================================")
print(f'grad_jax_0["params"]["conv1d"]["kernel"] {grad_jax_0["params"]["conv1d"]["kernel"].shape}')
print(grad_jax_0["params"]["conv1d"]["kernel"])
print("================================================")

def wmape(a, b):
  """Return the weighted mean absolute percentage error of ``b`` against ``a``.

  Computed as ``sum(|a - b|) / sum(|a|)``: the total absolute deviation,
  normalized by the total magnitude of the reference ``a``.

  Args:
    a: reference values (array-like).
    b: values to compare; must have exactly the same shape as ``a``.

  Returns:
    A numpy scalar. If ``a`` is all zeros the result is ``nan`` or ``inf``
    (numpy emits a RuntimeWarning).

  Raises:
    ValueError: if ``a`` and ``b`` have different shapes. Without this check,
      numpy broadcasting would silently compare mismatched layouts, e.g.
      ``(4, 1, 5120)`` against ``(1, 5120)``.
  """
  a = np.asarray(a)
  b = np.asarray(b)
  if a.shape != b.shape:
    raise ValueError(f"wmape: shape mismatch {a.shape} vs {b.shape}")
  return np.sum(np.abs(a - b)) / np.sum(np.abs(a))

print(f"losses: {(loss_torch, loss_jax_0, loss_jax_1)}")

# In every comparison the torch result is the reference (first argument).
print("Outputs WMAPE:")
output_torch_np = output_torch.detach().numpy()
print(wmape(output_torch_np, np.array(output_jax_0)))
print(wmape(output_torch_np, np.array(output_jax_1)))

print("Grads WMAPE:")
# permute(2, 1, 0) reorders torch's (out_channels, 1, kernel_size) grad into
# flax's (kernel_size, 1, features) layout. It gives the same values as `.T`,
# which PyTorch deprecates for non-2-D tensors.
torch_kernel_grad_np = conv_torch.conv1d.weight.grad.permute(2, 1, 0).detach().numpy()
print(wmape(torch_kernel_grad_np, np.array(grad_jax_0["params"]["conv1d"]["kernel"])))
print(wmape(torch_kernel_grad_np, np.array(grad_jax_1["params"]["conv1d"]["kernel"])))

All modules in the test script load the same conv weight and bias. The output is:

conv_torch.conv1d.weight.grad.T shape: torch.Size([4, 1, 5120])
tensor([[[-1.0676654e-06, -6.8133699e-07,  6.4320474e-08,  ...,
          -6.2070239e-07, -1.3589821e-06, -4.0437203e-06]],

        [[-1.0596159e-06, -6.6314635e-07,  6.1584302e-08,  ...,
          -6.3736798e-07, -1.4141298e-06, -4.0575105e-06]],

        [[-1.0563341e-06, -6.9564146e-07,  1.6096948e-08,  ...,
          -6.5698919e-07, -1.4593004e-06, -4.0748873e-06]],

        [[-1.0387610e-06, -6.8761653e-07,  7.2919590e-09,  ...,
          -6.7382234e-07, -1.4510101e-06, -4.0458808e-06]]])
================================================
grad_jax_0["params"]["conv1d"]["kernel"] (4, 1, 5120)
[[[-1.0676664e-06 -6.8133664e-07  6.4321057e-08 ...  5.7032690e-07
    3.1477473e-06  1.3152874e-06]]

 [[-1.0596164e-06 -6.6314561e-07  6.1584501e-08 ...  5.5595456e-07
    3.1550721e-06  1.3070267e-06]]

 [[-1.0563347e-06 -6.9564129e-07  1.6097601e-08 ...  6.2066499e-07
    3.1343752e-06  1.2927190e-06]]

 [[-1.0387618e-06 -6.8761614e-07  7.2924422e-09 ...  5.8497437e-07
    3.1675968e-06  1.2691942e-06]]]
================================================

losses: (tensor(0.0189537, grad_fn=<MeanBackward0>), Array(0.01895385, dtype=float32), Array(0.01895385, dtype=float32))

Outputs WMAPE:
0.0013432346
0.0013432346

Grads WMAPE:
0.7058088
0.7058088

The conv layer outputs differ only slightly (WMAPE ≈ 1.3e-3), but the kernel gradients differ substantially (WMAPE ≈ 0.71).

What you expected to happen:

The WMAPE of both outputs and grads should be small.

Steps to reproduce:

Whenever possible, please provide a minimal example. Please consider submitting it as a Colab link.