Compatibility with Torch LSTM
mttga opened this issue
Currently, the Torch definition of an LSTM cell (https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) is:

$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$
However, the LSTM cell definition in Flax (https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.LSTMCell.html) is:

$$
\begin{aligned}
i &= \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
f &= \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
g &= \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
o &= \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
c' &= f \odot c + i \odot g \\
h' &= o \odot \tanh(c')
\end{aligned}
$$
You can see that the biases differ: Torch has two bias terms per gate, one on the input projection ($b_{i*}$) and one on the recurrent projection ($b_{h*}$), while Flax only has the recurrent one. Adding an optional bias to the input projections makes the two parameterizations line up:
```diff
 class LSTMCell(RNNCellBase):
   features: int
   gate_fn: Callable[..., Any] = sigmoid
   activation_fn: Callable[..., Any] = tanh
   kernel_init: Initializer = default_kernel_init
   recurrent_kernel_init: Initializer = initializers.orthogonal()
   bias_init: Initializer = initializers.zeros_init()
   dtype: Optional[Dtype] = None
   param_dtype: Dtype = jnp.float32
   carry_init: Initializer = initializers.zeros_init()
+  # When True, the input projections also get a bias, matching Torch's
+  # b_i* terms (initialized with Dense's default zeros init).
+  bias_all: bool = False

   @compact
   def __call__(self, carry, inputs):
     c, h = carry
     hidden_features = h.shape[-1]
     dense_h = partial(
         Dense,
         features=hidden_features,
         use_bias=True,
         kernel_init=self.recurrent_kernel_init,
         bias_init=self.bias_init,
         dtype=self.dtype,
         param_dtype=self.param_dtype,
     )
     dense_i = partial(
         Dense,
         features=hidden_features,
-        use_bias=False,
+        use_bias=self.bias_all,
         kernel_init=self.kernel_init,
         dtype=self.dtype,
         param_dtype=self.param_dtype,
     )
     i = self.gate_fn(dense_i(name='ii')(inputs) + dense_h(name='hi')(h))
     f = self.gate_fn(dense_i(name='if')(inputs) + dense_h(name='hf')(h))
     g = self.activation_fn(dense_i(name='ig')(inputs) + dense_h(name='hg')(h))
     o = self.gate_fn(dense_i(name='io')(inputs) + dense_h(name='ho')(h))
     new_c = f * c + i * g
     new_h = o * self.activation_fn(new_c)
     return (new_c, new_h), new_h
```
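For illustration, a minimal usage sketch, assuming the patched cell above is in scope as `LSTMCell` (the shapes in the comments are what I'd expect for `features=32` and 16 input features):

```python
import jax
import jax.numpy as jnp

cell = LSTMCell(features=32, bias_all=True)
x = jnp.ones((1, 16))
carry = cell.initialize_carry(jax.random.PRNGKey(0), x.shape)
variables = cell.init(jax.random.PRNGKey(1), carry, x)
# With bias_all=True, the input-side gate layers now carry a bias too, e.g.:
#   variables['params']['ii'] -> {'kernel': (16, 32), 'bias': (32,)}
#   variables['params']['hi'] -> {'kernel': (32, 32), 'bias': (32,)}
new_carry, y = cell.apply(variables, carry, x)
```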
I've tested it, and with this simple change, bringing the Torch weights to Flax becomes straightforward and works perfectly. I'd be happy to open a pull request if you think this feature would be beneficial.
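For anyone who wants to replicate the conversion, here is a minimal sketch of the mapping I have in mind. It assumes a single-layer, unidirectional `torch.nn.LSTM` and `bias_all=True`; the helper `torch_lstm_to_flax` and its argument names are just illustrative:

```python
def torch_lstm_to_flax(state_dict, hidden_size):
  """Hypothetical helper: map a torch.nn.LSTM state_dict to the cell's params."""
  # Torch stacks the four gates along axis 0 in the order i, f, g, o.
  w_ih = state_dict['weight_ih_l0'].detach().cpu().numpy()  # (4*hidden, input)
  w_hh = state_dict['weight_hh_l0'].detach().cpu().numpy()  # (4*hidden, hidden)
  b_ih = state_dict['bias_ih_l0'].detach().cpu().numpy()    # (4*hidden,)
  b_hh = state_dict['bias_hh_l0'].detach().cpu().numpy()    # (4*hidden,)
  params = {}
  for k, gate in enumerate('ifgo'):
    s = slice(k * hidden_size, (k + 1) * hidden_size)
    # Flax Dense kernels are (in_features, out_features), so transpose.
    params[f'i{gate}'] = {'kernel': w_ih[s].T, 'bias': b_ih[s]}
    params[f'h{gate}'] = {'kernel': w_hh[s].T, 'bias': b_hh[s]}
  return {'params': params}
```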
Hey @mttga, you can always add Torch's 2 biases to get Flax's single bias:

```python
flax_b_hi = torch_b_ii + torch_b_hi
flax_b_hf = torch_b_if + torch_b_hf
```
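Extended to all four gates, a sketch of the full conversion to the stock `flax.linen.LSTMCell` (reusing the assumed `w_ih`/`w_hh`/`b_ih`/`b_hh` layout from the snippet earlier in the thread):

```python
params = {}
for k, gate in enumerate('ifgo'):  # Torch gate order: i, f, g, o
  s = slice(k * hidden_size, (k + 1) * hidden_size)
  params[f'i{gate}'] = {'kernel': w_ih[s].T}  # no input-side bias in Flax
  # Fold Torch's input bias into the recurrent one.
  params[f'h{gate}'] = {'kernel': w_hh[s].T, 'bias': b_ih[s] + b_hh[s]}
```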
🤦‍♂️