probabilists / zuko

Normalizing flows in PyTorch

Home page: https://zuko.readthedocs.io

Add dropout to mitigate overfitting

velezbeltran opened this issue

Description

Hello! Thanks for all the work on this library. I was wondering if there is a straightforward way to include dropout in some of the neural networks of conditional flows. I am currently implementing an application that uses them, and I am having some trouble with overfitting.

Implementation

Not sure exactly how to go about it here.

Hello @velezbeltran,

At the moment, you have to define a new flow to do that.
You can do this by inheriting from an existing flow implementation and overriding the MLP in it.

As an example, let's take an altered coupling transformation:

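import torch
from math import prod
from typing import Callable, Sequence

from torch import BoolTensor, Size
from torch.distributions import Transform

# Import paths assume zuko 1.x; adjust them to your installed version.
from zuko.flows import GeneralCouplingTransform
from zuko.transforms import MonotonicAffineTransform

# MyMLP is your own hyper-network class; it is not part of zuko.


class MyGeneralCouplingTransform(GeneralCouplingTransform):
    def __init__(
        self,
        features: int,
        context: int = 0,
        mask: BoolTensor = None,
        univariate: Callable[..., Transform] = MonotonicAffineTransform,
        shapes: Sequence[Size] = ((), ()),
        **kwargs,
    ):
        # Skip GeneralCouplingTransform.__init__ (it requires arguments and would
        # build the default MLP) and run the next initializer in the MRO instead.
        super(GeneralCouplingTransform, self).__init__()

        # Univariate transformation
        self.univariate = univariate
        self.shapes = shapes
        self.total = sum(prod(s) for s in shapes)

        # Mask
        self.register_buffer('mask', None)

        if mask is None:
            self.mask = torch.arange(features) % 2 == 1
        else:
            self.mask = mask

        features_a = self.mask.sum().item()
        features_b = features - features_a

        # Hyper network (the only change: MyMLP instead of zuko's MLP)
        self.hyper = MyMLP(features_a + context, features_b * self.total, **kwargs)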

I copied the whole __init__ method and just swapped the MLP.
After that, you just need to define a flow the same way:

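import torch

# Import paths assume zuko 1.x; adjust them to your installed version.
from zuko.distributions import DiagNormal
from zuko.flows import NICE, Unconditional

# MyGeneralCouplingTransform is the class defined above.


class MyNICE(NICE):
    def __init__(
        self,
        features: int,
        context: int = 0,
        transforms: int = 3,
        randmask: bool = False,
        **kwargs,
    ):
        temp = []

        for i in range(transforms):
            if randmask:
                mask = torch.randperm(features) % 2 == i % 2
            else:
                mask = torch.arange(features) % 2 == i % 2

            temp.append(
                MyGeneralCouplingTransform(
                    features=features,
                    context=context,
                    mask=mask,
                    **kwargs,
                )
            )

        base = Unconditional(
            DiagNormal,
            torch.zeros(features),
            torch.ones(features),
            buffer=True,
        )

        # Skip NICE.__init__ (it would build its own transformations) and
        # call Flow.__init__(transform, base) directly.
        super(NICE, self).__init__(temp, base)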

We could make this process easier by adding generating (factory) functions to all classes, so that users only have to override these functions rather than the whole __init__ methods.
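A minimal sketch of that idea in plain PyTorch. The class and method names (CouplingWithFactory, make_hyper, MyCoupling) are hypothetical and not part of zuko's API: the base class builds its hyper-network through an overridable method, so a subclass only replaces that method and inherits __init__ and forward unchanged.

```python
import torch
import torch.nn as nn


class CouplingWithFactory(nn.Module):
    """Hypothetical affine coupling layer; the hyper-network comes from make_hyper."""

    def __init__(self, features: int, **kwargs):
        super().__init__()
        self.split = features // 2  # x[..., :split] conditions x[..., split:]
        out = 2 * (features - self.split)  # one log-scale and one shift per feature
        # Subclasses customize the network by overriding make_hyper, not __init__
        self.hyper = self.make_hyper(self.split, out, **kwargs)

    def make_hyper(self, in_features: int, out_features: int, hidden: int = 64) -> nn.Module:
        return nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_a, x_b = x[..., : self.split], x[..., self.split :]
        log_scale, shift = self.hyper(x_a).chunk(2, dim=-1)
        return torch.cat([x_a, x_b * log_scale.exp() + shift], dim=-1)


class MyCoupling(CouplingWithFactory):
    # Only the factory changes: a deeper network with ELU activations
    def make_hyper(self, in_features: int, out_features: int, hidden: int = 32) -> nn.Module:
        return nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ELU(),
            nn.Linear(hidden, hidden),
            nn.ELU(),
            nn.Linear(hidden, out_features),
        )
```

For example, `MyCoupling(5)(torch.randn(4, 5))` runs with the custom network, while extra keyword arguments such as `hidden=16` are forwarded to whichever make_hyper is in use.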

What do you think @francois-rozet?

@velezbeltran, I hope that helps!

Best
Simon

Hello @velezbeltran 👋 As @simonschnake said, it is not currently possible to use dropout in pre-built flows, but you can build your own transformations with dropout in the hyper-network. By the way, I think you can simplify MyGeneralCouplingTransform as follows:

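from typing import Callable, Sequence
from torch import BoolTensor, Size
from torch.distributions import Transform
# Import paths assume zuko 1.x; MyMLP is your own class, not part of zuko.
from zuko.flows import GeneralCouplingTransform
from zuko.transforms import MonotonicAffineTransform

class MyGeneralCouplingTransform(GeneralCouplingTransform):
    def __init__(
        self,
        features: int,
        context: int = 0,
        mask: BoolTensor = None,
        univariate: Callable[..., Transform] = MonotonicAffineTransform,
        shapes: Sequence[Size] = ((), ()),
        **kwargs,
    ):
        # Build the default transformation (including zuko's MLP) first. If MyMLP
        # takes options that zuko's MLP does not accept, pop them from kwargs here.
        super().__init__(features, context, mask, univariate, shapes, **kwargs)

        # Replace the hyper-network, keeping the input/output sizes of the default one
        self.hyper = MyMLP(self.hyper.in_features, self.hyper.out_features, **kwargs)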
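MyMLP itself is never defined in this thread. Assuming it should take the same in_features, out_features and hidden_features arguments as zuko.nn.MLP, plus a hypothetical dropout rate, a minimal version could look like this:

```python
from typing import Sequence

import torch
import torch.nn as nn


class MyMLP(nn.Sequential):
    """Hypothetical MLP with dropout, loosely mirroring zuko.nn.MLP's signature."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Sequence[int] = (64, 64),
        dropout: float = 0.1,
    ):
        sizes = [in_features, *hidden_features, out_features]
        layers = []

        for i, (before, after) in enumerate(zip(sizes[:-1], sizes[1:])):
            layers.append(nn.Linear(before, after))
            # Activation and dropout after every hidden layer, not after the output layer
            if i < len(sizes) - 2:
                layers.append(nn.ReLU())
                layers.append(nn.Dropout(dropout))

        super().__init__(*layers)

        # Expose the sizes so the module can stand in for self.hyper
        self.in_features = in_features
        self.out_features = out_features
```

Keep in mind that, in training mode, nn.Dropout samples a new mask on every call, so the network no longer returns the same output for the same input.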

That said, we could easily add a dropout option to zuko.nn.MLP, which would allow dropout in some pre-built flows (mainly NICE). However, we cannot add dropout to MaskedMLP and MonotonicMLP. The reason is that a transformation must be a pure function, meaning that calling it several times with the same input should always produce the same output; otherwise, the transformation is not invertible/bijective. This is not the case with dropout (or batch normalization), which is why we don't support it.
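To see why purity matters, here is a small self-contained experiment in plain PyTorch (not zuko code): an affine coupling layer whose hypothetical hyper-network contains nn.Dropout. The inverse re-evaluates the hyper-network on the same unchanged half of the input, so in training mode it draws a different dropout mask and fails to recover x. In evaluation mode, dropout is disabled and the round trip is exact up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical hyper-network with dropout: maps x[:, :2] to (log-scale, shift) for x[:, 2:]
hyper = nn.Sequential(
    nn.Linear(2, 64),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(64, 4),
)


def forward(x: torch.Tensor) -> torch.Tensor:
    x_a, x_b = x[:, :2], x[:, 2:]
    log_s, t = hyper(x_a).chunk(2, dim=-1)
    return torch.cat([x_a, x_b * log_s.exp() + t], dim=-1)


def inverse(y: torch.Tensor) -> torch.Tensor:
    y_a, y_b = y[:, :2], y[:, 2:]
    log_s, t = hyper(y_a).chunk(2, dim=-1)  # y_a == x_a: same input as in forward
    return torch.cat([y_a, (y_b - t) * (-log_s).exp()], dim=-1)


x = torch.randn(8, 4)

hyper.train()  # dropout active: every call samples a new mask
train_error = (inverse(forward(x)) - x).abs().max().item()

hyper.eval()  # dropout disabled: the hyper-network is a pure function again
eval_error = (inverse(forward(x)) - x).abs().max().item()

print(f"train-mode reconstruction error: {train_error:.3g}")
print(f"eval-mode reconstruction error:  {eval_error:.3g}")
```

Nothing raises an error in training mode; the inverse is simply wrong, which is the kind of silent failure an impure hyper-network can cause.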

Now, concerning overfitting, I have a few recommendations:

  1. If possible, get more data. That sounds obvious, but it goes a long way. If you cannot get more data, try data augmentation (e.g. add noise to $x$ or to the context $c$, randomly drop features in $c$, etc.).
  2. Try to increase the weight decay (AdamW(lr=1e-3, weight_decay=1e-2)). Most transformations become identities ($f(x) = x$) when the weight decay is large.
  3. Limit the number of transformations. Three autoregressive transformations are usually more than enough. In theory, if the univariate transformation is a universal function approximator, a single autoregressive transformation is as expressive as a thousand.
  4. Use small/shallow hyper-networks to reduce the capacity of the flow. Less capacity, less overfitting.
  5. Avoid neural transformations (NAF, UNAF). They have a stronger tendency to overfit in my experience.
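A rough sketch of recommendations 1 and 2. The helpers augment and make_optimizer are hypothetical, the noise and drop rates are illustrative, and "dropping" a context feature is assumed to mean zeroing it. Note that 1e-2 is also torch.optim.AdamW's default weight decay; unlike plain Adam, AdamW applies it decoupled from the gradient step.

```python
import torch
import torch.nn as nn


def augment(x: torch.Tensor, c: torch.Tensor, noise: float = 0.05, drop: float = 0.1):
    """Hypothetical augmentation: jitter x and c, then zero out random context features."""
    x = x + noise * torch.randn_like(x)
    c = c + noise * torch.randn_like(c)
    # Each context entry is kept with probability 1 - drop
    keep = (torch.rand_like(c) >= drop).to(c.dtype)
    return x, c * keep


def make_optimizer(flow: nn.Module, lr: float = 1e-3, weight_decay: float = 1e-2):
    # Decoupled weight decay pulls the weights toward zero, and with them
    # most transformations toward the identity
    return torch.optim.AdamW(flow.parameters(), lr=lr, weight_decay=weight_decay)


# Usage with a small zuko flow (recommendations 3 and 4), e.g.:
#   flow = zuko.flows.NSF(features, context, transforms=3, hidden_features=(64, 64))
#   optimizer = make_optimizer(flow)
#   x_aug, c_aug = augment(x, c)
#   loss = -flow(c_aug).log_prob(x_aug).mean()
```

Augmentation is applied afresh at every training step, so the flow never sees exactly the same (x, c) pair twice.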

Oh wow, thank you @francois-rozet and @simonschnake for all of this help! It is extremely helpful, and I really appreciate how clean it is to implement. I will give some of these tricks a try, @francois-rozet. I was already using augmentations, because unfortunately I am also dealing with discrete data, and those have been really helpful. I was using NSF mostly because that is what I am used to and what has given me the best results in the past, but I will try others.

Again, many thanks for the prompt response and these comments. I really appreciate it!

You're welcome, @velezbeltran. By the way, NSF is an autoregressive flow (it is a subclass of MAF), so my recommendations apply to it as well. You can also use NICE (or MyNICE to try dropout) with rational-quadratic spline transformations:

import zuko

features, context = 5, 3  # example sizes

flow = zuko.flows.NICE(
    features,
    context,
    univariate=zuko.transforms.MonotonicRQSTransform,  # spline transformation
    shapes=([8], [8], [7]),                            # widths, heights and inner derivatives for 8 bins
)

Hey @velezbeltran, thinking back, one should actually not use dropout in coupling transformations either. If you do, the transformation is no longer pure, and hence not bijective/invertible. Using impure hyper-networks (dropout or batch normalization) could lead to silent errors that are very hard to detect, which is why I don't think it is a good idea to provide dropout as an option.

Ah, that makes sense! Thanks for telling me. I have used dropout before and, although it worked, I can totally see now how it could lead to issues. Again, thanks for helping so much!

You're welcome. I will be closing this issue, but feel free to open a new one or a discussion if you have other questions.