google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Compatibility with Torch LSTM

mttga opened this issue

Currently, the Torch definition of an LSTM cell (https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) is:

$$ \begin{array}{ll} \\ i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ h_t = o_t \odot \tanh(c_t) \\ \end{array} $$

However, the LSTM cell definition in Flax (https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.LSTMCell.html) is:

$$ \begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split} $$

You can see that the biases $b_{ii}, b_{if}, b_{ig}, b_{io}$ are missing in Flax. I understand the choice. However, it would be nice to make it possible to decide whether or not to use these biases. That way, the weights of a pretrained Torch LSTM cell can be converted directly into Flax, which is very useful for a number of applications. The change would be very simple to implement:

class LSTMCell(RNNCellBase):
    features: int
    gate_fn: Callable[..., Any] = sigmoid
    activation_fn: Callable[..., Any] = tanh
    kernel_init: Initializer = default_kernel_init
    recurrent_kernel_init: Initializer = initializers.orthogonal()
    bias_init: Initializer = initializers.zeros_init()
    dtype: Optional[Dtype] = None
    param_dtype: Dtype = jnp.float32
    carry_init: Initializer = initializers.zeros_init()
+   bias_all: bool = False  # if True, also add input-side biases (b_ii, b_if, b_ig, b_io), as in Torch

    @compact
    def __call__(self, carry, inputs):

        c, h = carry
        hidden_features = h.shape[-1]
        dense_h = partial(
          Dense,
          features=hidden_features,
          use_bias=True,
          kernel_init=self.recurrent_kernel_init,
          bias_init=self.bias_init,
          dtype=self.dtype,
          param_dtype=self.param_dtype,
        )
        dense_i = partial(
          Dense,
          features=hidden_features,
-         use_bias=False,
+         use_bias=self.bias_all,  # enable Torch-style input biases when requested
          kernel_init=self.kernel_init,
          dtype=self.dtype,
          param_dtype=self.param_dtype,
        )
        i = self.gate_fn(dense_i(name='ii')(inputs) + dense_h(name='hi')(h))
        f = self.gate_fn(dense_i(name='if')(inputs) + dense_h(name='hf')(h))
        g = self.activation_fn(dense_i(name='ig')(inputs) + dense_h(name='hg')(h))
        o = self.gate_fn(dense_i(name='io')(inputs) + dense_h(name='ho')(h))
        new_c = f * c + i * g
        new_h = o * self.activation_fn(new_c)
        return (new_c, new_h), new_h
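
For reference, a minimal usage sketch of the proposed flag (the `bias_all` argument comes from the patch above and is not part of the current Flax API; the class is assumed to be in scope, and the shapes are arbitrary):

import jax
import jax.numpy as jnp

# Hypothetical usage of the patched cell defined above.
cell = LSTMCell(features=32, bias_all=True)
carry = cell.initialize_carry(jax.random.PRNGKey(0), (1, 16))  # batch=1, input dim=16
variables = cell.init(jax.random.PRNGKey(1), carry, jnp.ones((1, 16)))

# With bias_all=True the input projections also carry a bias, mirroring Torch:
# variables['params']['ii'] now holds both 'kernel' and 'bias', and likewise for
# 'if', 'ig' and 'io'; the recurrent projections 'hi', ..., 'ho' are unchanged.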

I've tested it, and with this simple change, bringing the Torch weights to Flax becomes straightforward and works perfectly. I'd be happy to open a pull request if you think this feature would be beneficial.

Hey @mttga, you can always add Torch's 2 biases to get Flax's single bias:

flax_b_hi = torch_b_ii + torch_b_hi
flax_b_hf = torch_b_if + torch_b_hf
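
Spelled out, a hedged conversion sketch along these lines (assuming a torch.nn.LSTMCell `torch_cell` and a mutable param dict `flax_params` from a stock Flax LSTMCell with the gate names used above; Torch stacks the gates in the order i, f, g, o along the first axis, and Flax Dense kernels are the transpose of Torch weights):

import jax.numpy as jnp

H = torch_cell.hidden_size
w_ih = torch_cell.weight_ih.detach().numpy()  # (4H, input_size)
w_hh = torch_cell.weight_hh.detach().numpy()  # (4H, H)
b_ih = torch_cell.bias_ih.detach().numpy()    # (4H,)
b_hh = torch_cell.bias_hh.detach().numpy()    # (4H,)

for k, gate in enumerate('ifgo'):
    sl = slice(k * H, (k + 1) * H)
    flax_params['params'][f'i{gate}']['kernel'] = jnp.asarray(w_ih[sl].T)
    flax_params['params'][f'h{gate}']['kernel'] = jnp.asarray(w_hh[sl].T)
    # Stock Flax keeps a single bias per gate, so fold the two Torch biases into it.
    flax_params['params'][f'h{gate}']['bias'] = jnp.asarray(b_ih[sl] + b_hh[sl])

With the bias_all flag proposed above, the two Torch biases could instead be copied separately into the 'i…' and 'h…' modules rather than summed.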

🤦‍♂️