google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

PReLU activation implementation

isaaccorley opened this issue

I wanted to gauge interest in adding a PReLU activation. I noticed that the flax.linen.activations are simply aliases of the jax.nn activation functions, which also don't have a PReLU implementation.

To add some background: PReLU is simply Leaky ReLU where the alpha (slope) parameter is trainable rather than fixed. This makes it simple to implement as a Module if desired.

Here's an example implementation from another project of mine.

from functools import partial
from typing import Any, Sequence

import jax.numpy as jnp
import flax.linen as nn


# Nearly identical to jnp.ones, except that the output is multiplied by the given constant value
def constant(key, shape: Sequence[int], value: Any, dtype: Any = jnp.float32) -> jnp.ndarray:
    value = jnp.asarray(value, dtype)
    return jnp.ones(shape, dtype) * value


class PReLU(nn.Module):
    negative_slope_init: float = 0.01
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = jnp.asarray(x, self.dtype)
        negative_slope = self.param(
            "negative_slope",
            partial(constant, value=self.negative_slope_init, dtype=self.dtype),
            (1,)
        )
        return jnp.where(x >= 0, x, negative_slope * x)
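
A quick usage sketch of the module above, just for illustration (the variable names here are mine, not part of the proposal):

import jax

prelu = PReLU(negative_slope_init=0.25)
x = jnp.ones((4, 3))
variables = prelu.init(jax.random.PRNGKey(0), x)   # creates the (1,)-shaped slope param
y = prelu.apply(variables, x)                      # elementwise PReLU with the learned slope
# variables["params"]["negative_slope"] holds the trainable slope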

Given that all the current activation functions reside in JAX, it seems more fitting to add this to JAX. Do you want to file an issue against their repo?

Thanks for the suggestion. The main reason I filed the issue here is that PReLU seems to be a special case: it has a trainable param and, if I'm not mistaken, all other jax activations do not.

I'm not sure if this changes your suggestion, but it's something to consider.

@isaaccorley - hey so sorry for the slow feedback on your suggestion here.

Two points:

  • Instead of defining a constant init func, we can just declare a jnp scalar array of the correct dtype.
  • I think an activation "function" should strictly follow the dtype of its argument, so no dtype attribute; just derive it from x.

So what if we added something like this?

class PReLU(nn.Module):
    negative_slope_init: float = 0.01
    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        negative_slope = self.param(
            "negative_slope",
            lambda k: jnp.array(self.negative_slope_init, x.dtype)
        )
        return jnp.where(x >= 0, x, negative_slope * x)
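
As a quick check of the "follow the dtype of x" behaviour, a small sketch (assuming the class above is in scope; the bfloat16 choice is just an example):

import jax
import jax.numpy as jnp

variables = PReLU().init(jax.random.PRNGKey(0), jnp.ones((2,), jnp.bfloat16))
# The slope parameter picks up the dtype of the input it was initialized with:
assert variables["params"]["negative_slope"].dtype == jnp.bfloat16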

I'm indifferent on the implementation. The only thing I'd point out: since we are inheriting from Module, and other Modules have a dtype param, should we stray from that standard even though this is an activation function?

I created a constant init func because jax itself seemed to be lacking one, but I haven't received a response to the issue I posted in the jax repo requesting it, so I'm fine with just using a lambda.

  • Other Modules have a dtype param to control the precision of their intermediate values, and a simple activation function like this doesn't have intermediates. We don't require modules to surface a dtype= attribute; it's just a convention for the core layers, to give users the ability to control the floating-point types of the "insides".

  • The "constant" functions you're looking for already exist: jnp.full and jnp.full_like

  1. Makes sense, thanks for clarifying that.
  2. Thanks for pointing me to jnp.full. I wasn't aware of that.

Shall I make a PR then?

Yeah, if you'd like to make a PR, we could add the above to activations.py I think (after all the passthrough function imports). But no pressure: if you don't have time, we can add it ourselves soon.

I'll try to take a first stab at it since it will be my first time contributing to flax.

The current implementation of PReLU does not work the same way as the other activation functions.

The following example code raises an error at initialization:

from typing import Sequence

import flax.linen as nn


class MLP(nn.Module):
    """Definition of an MLP.

    Attributes
    ----------
    :param hidden_sizes: list of int, the number of neurons in each hidden layer.
    :param out_size: int, the number of neurons in the output layer.
    """
    hidden_sizes: Sequence[int]
    out_size: int
    
    @nn.compact
    def __call__(self, x, **kwargs):
        name = kwargs.pop('name', 'fc')
        for e, size in enumerate(self.hidden_sizes, 1):
            x = nn.Dense(size, name=name + str(e))(x)
            # x = nn.silu(x)
            x = nn.PReLU(x)
        x = nn.Dense(self.out_size, name=name + '_output')(x)
        return x 

The traceback ends inside nn.Dense:

    186 @compact
    187 def __call__(self, inputs: Array) -> Array:
    188   """Applies a linear transformation to the inputs along the last dimension.
    189 
    190   Args:
   (...)
    194     The transformed input.
    195   """
    196   kernel = self.param('kernel',
    197                       self.kernel_init,
--> 198                       (jnp.shape(inputs)[-1], self.features),
    199                       self.param_dtype)
    200   if self.use_bias:
    201     bias = self.param('bias', self.bias_init, (self.features,),
    202                       self.param_dtype)

IndexError: tuple index out of range

PReLU doesn't follow the same calling convention as the other activations.

In the previous example, it needs to be

x = nn.PReLU()(x) 

This was not obvious to me from the documentation.

Because PReLU initializes a trainable scalar parameter, it has to be treated as a layer. I've added clarification and example usage to the docs in #3122.
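
For reference, here is a minimal sketch of the corrected usage (the TinyMLP name and the layer sizes are just illustrative):

import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyMLP(nn.Module):
    hidden_size: int = 8

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_size)(x)
        x = nn.PReLU()(x)   # construct the module first, then apply it to x
        return nn.Dense(1)(x)


params = TinyMLP().init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
# The trainable slope appears in the parameter tree, e.g. under
# params["params"]["PReLU_0"]["negative_slope"].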