fattorib / Flax-ResNets

CIFAR10 ResNets implemented in JAX+Flax

ResNets: JAX+Flax vs. PyTorch

Addendum (April 2022): Coming back to this repo now, I realize that this project misses a few of the key ideas that make JAX so much more interesting than PyTorch or TensorFlow, mainly the vmap and pmap transformations. I'd like to come back to another JAX project in the future when I have some free time!

This is a full implementation, in both JAX+Flax and PyTorch, of the CIFAR10 ResNets from "Deep Residual Learning for Image Recognition" by He et al. Since this is my first JAX project, I rewrote an older project of mine that was originally written in PyTorch.

Both models are in the "Models" folder.

To train a ResNet20 in Flax run:

python main_flax.py --workers 4 --epochs 180 --batch-size 128 --weight-decay 1e-4 --model ResNet20 --CIFAR10 True

To train a ResNet20 in PyTorch run:

python main_torch.py --workers 4 --epochs 180 --batch-size 128 --weight-decay 1e-4 --model ResNet20 --CIFAR10 True

The following is an overview of the main ideas I learned while working with Flax:

Model Construction

Defining Modules

With Flax's Linen API, we can define modules using the @nn.compact decorator, which lets us declare submodules inline inside __call__. I found writing modules this way to be very simple! For a basic residual block in Flax, we would write:

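from typing import Any, Callable

import flax.linen as nn
import jax.numpy as jnp

# Type aliases used in the annotations below
ModuleDef = Any
dtypedef = Any

class ResidualBlock(nn.Module):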
    # Define collection of datafields here
    in_channels: int

    # For batchnorm, we can pass it as a ModuleDef
    norm: ModuleDef

    # dtype for fp16/32 training
    dtype: dtypedef = jnp.float32

    # define init for conv layers
    kernel_init: Callable = nn.initializers.kaiming_normal()

    @nn.compact
    def __call__(self, x):
        residual = x

        x = nn.Conv(
            kernel_size=(3, 3),
            strides=1,
            features=self.in_channels,
            padding="SAME",
            use_bias=False,
            kernel_init=self.kernel_init,
            dtype=self.dtype,
        )(x)
        x = self.norm()(x)
        x = nn.relu(x)
        x = nn.Conv(
            kernel_size=(3, 3),
            strides=1,
            features=self.in_channels,
            padding="SAME",
            use_bias=False,
            kernel_init=self.kernel_init,
            dtype=self.dtype,
        )(x)
        x = self.norm()(x)

        x = x + residual

        return nn.relu(x)

To do the same thing in PyTorch, we would write:

import torch.nn as nn
import torch.nn.functional as F


# _weights_init lives elsewhere in the repo; as an assumption, a typical
# version applies Kaiming-normal init to conv and linear weights
def _weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)


class ResidualBlock(nn.Module):
    # One full block of a given filter size
    def __init__(self, in_filters):
        super(ResidualBlock, self).__init__()
        self.in_filters = in_filters
        self.conv_block = nn.Sequential(
            nn.Conv2d(
                self.in_filters,
                self.in_filters,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(self.in_filters),
            nn.ReLU(),
            nn.Conv2d(
                self.in_filters,
                self.in_filters,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(self.in_filters),
        )

        # Initialize conv weights with _weights_init
        self.apply(_weights_init)

    def forward(self, x):
        residual = x
        x = self.conv_block(x)
        x += residual
        return F.relu(x)

While it felt awkward at first, using Linen's API leads to shorter module definitions and easier-to-follow forward-pass code. Something else to note is that the version of Flax I used didn't have a Sequential constructor like PyTorch's nn.Sequential (recent Flax releases do provide flax.linen.Sequential). While one is easy to add, I found I didn't need it in Flax despite relying on it heavily in PyTorch.

Train/Test Behaviour + State

ResNets employ Batch Normalization following convolutional layers. The BatchNorm layer is interesting as it:

  • Has trainable parameters ($\alpha$ and $\beta$, called scale and bias in Flax) and non-trainable variables (running averages of the batch statistics)
  • Has different train and test behaviour
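For reference, the standard BatchNorm computation for a feature $x$ with batch mean $\mu_B$ and batch variance $\sigma_B^2$ is written below using the document's $\alpha$ (scale) and $\beta$ (bias). The running-average update is written in Flax's convention, where the `momentum` argument $m$ is the decay rate of the running average; PyTorch's `momentum` $p$ corresponds to $m = 1 - p$:

```latex
\begin{aligned}
\text{train:}\quad & y = \alpha\,\frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta, \\
& \mu_{\mathrm{run}} \leftarrow m\,\mu_{\mathrm{run}} + (1 - m)\,\mu_B, \qquad
  \sigma^2_{\mathrm{run}} \leftarrow m\,\sigma^2_{\mathrm{run}} + (1 - m)\,\sigma_B^2, \\
\text{test:}\quad & y = \alpha\,\frac{x - \mu_{\mathrm{run}}}{\sqrt{\sigma^2_{\mathrm{run}} + \epsilon}} + \beta.
\end{aligned}
```

Only $\alpha$ and $\beta$ receive gradients; $\mu_{\mathrm{run}}$ and $\sigma^2_{\mathrm{run}}$ are the non-trainable state that has to be carried between steps.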

Because of this, BatchNorm layers need special care. The trainable and non-trainable variables are handled at model initialization: calling model.init(...) returns a PyTree of all the model's variables, grouped into collections. Trainable parameters live in the "params" collection, and since BatchNorm's running statistics are the only non-trainable variables, they sit on their own in the "batch_stats" collection, so we can split them off as follows:

...
# train must be passed here too, since it is a required argument of __call__
variables = model.init(rng, jnp.ones(input_shape), train=False)
params, batch_stats = variables["params"], variables["batch_stats"]
...

To manage train/eval behaviour, we first add a train boolean argument to the __call__ method of the top-level model (in this case, the ResNet module). Inside it, we build a partially applied BatchNorm constructor with functools.partial, setting use_running_average from train, and pass it to every submodule that needs it. The following is a small section of the ResNet code:

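from functools import partial
...
@nn.compact
def __call__(self, x, train):

    # Flax's momentum is the decay rate of the running average
    # (running = momentum * running + (1 - momentum) * batch), so 0.9
    # matches PyTorch's default BatchNorm momentum of 0.1
    norm = partial(
        nn.BatchNorm,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        dtype=jnp.float32,
    )
    ...
    x = ResidualBlock(in_channels=16, norm=norm, dtype=jnp.float32)(x)
    ...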

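The final step is to pass the appropriate arguments to the model's .apply() method. During training, mutable=["batch_stats"] marks the running statistics as updatable, so .apply() returns the updated variables alongside the output: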

# Training: BatchNorm uses batch statistics and updates its running averages
logits, new_state = state.apply_fn(
    {"params": params, "batch_stats": state.batch_stats},
    batch,
    mutable=["batch_stats"],
    train=True,
)
# new_state["batch_stats"] holds the updated running averages

# Evaluation: use the running averages of the batch statistics
logits = state.apply_fn(
    {"params": state.params, "batch_stats": state.batch_stats},
    batch,
    mutable=False,
    train=False,
)

Data Loading

In JAX/Flax, we can reuse an existing PyTorch data pipeline and modify it slightly to return NumPy/JAX arrays instead of PyTorch tensors. See here for more details. PyTorch's data loading and augmentation capabilities are excellent, so being able to use them directly with Flax is a big plus.
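One way to do this, sketched under the assumption that the dataset's transforms already yield NumPy arrays (for example, a final transform that converts each image to `np.float32` instead of `ToTensor()`), is a custom `collate_fn` that stacks samples into NumPy arrays. `numpy_collate` is a hypothetical helper name:

```python
import numpy as np


def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays instead of torch Tensors.

    Intended use (assumption):
    torch.utils.data.DataLoader(dataset, batch_size=128, collate_fn=numpy_collate)
    """
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    if isinstance(batch[0], (tuple, list)):
        # [(image, label), ...] -> [stacked images, stacked labels]
        return [numpy_collate(samples) for samples in zip(*batch)]
    # Scalars such as integer labels
    return np.array(batch)


samples = [
    (np.zeros((32, 32, 3), np.float32), 3),
    (np.ones((32, 32, 3), np.float32), 7),
]
images, labels = numpy_collate(samples)
print(images.shape, labels)  # (2, 32, 32, 3) [3 7]
```

JAX functions accept NumPy arrays directly, so the batches can be fed to a jitted train step without any further conversion.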

One issue I noticed was that my code always failed with an out-of-memory (OOM) error when I set pin_memory=True in the DataLoader. I suspect this is because, by default, JAX allocates arrays directly in GPU memory rather than in pinned host memory.

Model Training

TrainState

In Flax, the training state is usually bundled into a TrainState object (flax.training.train_state.TrainState), which holds the model's .apply() function, the optimizer and its state, the model parameters, and any other attributes we wish to include. In the Flax version of this project, I created a subclass of TrainState that adds batch statistics, weight decay, and a dynamic loss scale (for mixed-precision training) as extra attributes.

Optimizers

The Flax docs recommend using Optax for optimizers and learning rate scheduling. Among the common optimizers, AdamW takes a weight decay argument directly, whereas plain optax.sgd does not.

Weight decay/L2 regularization can get a bit tricky depending on the optimizer used (see Adam vs. AdamW). In our case, with SGD, we can add an L2 regularization term manually to our loss function. It is common practice to exclude certain parameters from regularization, including $\alpha$ and $\beta$ in BatchNorm layers and bias terms in Dense/Linear layers.
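Written out, the regularized objective, its gradient for a regularized weight $w$, and the resulting plain-SGD update (learning rate $\eta$) are shown below. Here $\mathcal{W}$ is the set of weights that are not BatchNorm $\alpha$/$\beta$ or biases, and $\lambda$ is the weight decay (`--weight-decay 1e-4` in the commands above). I've assumed the $\lambda/2$ convention, under which the gradient term $\lambda w$ matches what PyTorch's SGD adds for `weight_decay`:

```latex
\mathcal{L}(\theta) = \mathcal{L}_{\mathrm{CE}}(\theta) + \frac{\lambda}{2} \sum_{w \in \mathcal{W}} \lVert w \rVert_2^2,
\qquad
\nabla_w \mathcal{L} = \nabla_w \mathcal{L}_{\mathrm{CE}} + \lambda w,
\qquad
w \leftarrow w - \eta \nabla_w \mathcal{L} = (1 - \eta\lambda)\, w - \eta \nabla_w \mathcal{L}_{\mathrm{CE}}.
```

The last form shows why, for plain SGD, adding L2 to the loss and shrinking the weights by $(1 - \eta\lambda)$ each step are the same thing.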

In PyTorch, we can filter these parameters through model.named_parameters():

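import torch
# Adapted from timm (https://github.com/rwightman/pytorch-image-models); assumes `model` exists
decay, no_decay = [], []
for key, value in model.named_parameters():
    # 1-D params are biases and BatchNorm weights (timm's rule)
    if value.ndim <= 1 or key.endswith(".bias"):
        no_decay.append(value)  # exclude params from weight decay
    else:
        decay.append(value)  # include params in weight decay
optimizer = torch.optim.SGD(
    [{"params": no_decay, "weight_decay": 0.0},
     {"params": decay, "weight_decay": 1e-4}],
    lr=0.1, momentum=0.9,
)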

In Flax, we can use the following section of code:

weight_decay_params_filter = flax.traverse_util.ModelParamTraversal(
        lambda path, _: ("bias" not in path and "scale" not in path)
)

weight_decay_params = weight_decay_params_filter.iterate(params)

Adding a learning rate schedule is easy, and Optax supports many of the common ones. Since the schedule is passed to the optimizer as a function of the step count, learning-rate updates are handled internally, whereas PyTorch requires calling scheduler.step() manually.

