Add dropout to mitigate overfitting
velezbeltran opened this issue
Description
Hello! Thanks for all of the work in this library. I was wondering if there was a straightforward way of including dropout in some of the neural networks used by conditional flows. I am currently implementing an application that uses them and I am having some trouble dealing with overfitting.
Implementation
Not sure exactly how to go about it here.
Hello @velezbeltran,
at the moment, you have to define a new flow to do that. You can do so by inheriting from an existing flow implementation and overriding the MLP in it.
As an example, let's take an altered coupling flow:
```python
import torch

from math import prod
from torch import BoolTensor, Size
from torch.distributions import Transform
from typing import Callable, Sequence
from zuko.flows import GeneralCouplingTransform
from zuko.transforms import MonotonicAffineTransform


class MyGeneralCouplingTransform(GeneralCouplingTransform):
    def __init__(
        self,
        features: int,
        context: int = 0,
        mask: BoolTensor = None,
        univariate: Callable[..., Transform] = MonotonicAffineTransform,
        shapes: Sequence[Size] = ((), ()),
        **kwargs,
    ):
        # Skip GeneralCouplingTransform.__init__ (it would rebuild the default
        # MLP) and initialize the parent nn.Module machinery instead.
        super(GeneralCouplingTransform, self).__init__()

        # Univariate transformation
        self.univariate = univariate
        self.shapes = shapes
        self.total = sum(prod(s) for s in shapes)

        # Mask
        self.register_buffer('mask', None)

        if mask is None:
            self.mask = torch.arange(features) % 2 == 1
        else:
            self.mask = mask

        features_a = self.mask.sum().item()
        features_b = features - features_a

        # Hyper network
        self.hyper = MyMLP(features_a + context, features_b * self.total, **kwargs)
```
I copied the whole `__init__` method and just exchanged the MLP.
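Here, `MyMLP` stands for your own dropout-enabled network, which is not defined in this thread. A minimal sketch of such a network, with `torch.nn.Dropout` after each hidden activation, could look like this (the `hidden_features` and `dropout` arguments are illustrative defaults, not part of zuko's API):

```python
import torch.nn as nn


class MyMLP(nn.Sequential):
    # A plain MLP with dropout after each hidden activation. This is a
    # sketch, not zuko code; adapt the defaults to your application.
    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features=(64, 64),
        dropout: float = 0.1,
        **kwargs,
    ):
        layers = []

        for h in hidden_features:
            layers.extend([nn.Linear(in_features, h), nn.ReLU(), nn.Dropout(dropout)])
            in_features = h

        layers.append(nn.Linear(in_features, out_features))

        super().__init__(*layers)
```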
After that, you just need to define a flow the same way.
```python
from zuko.distributions import DiagNormal
from zuko.flows import NICE, Unconditional


class MyNICE(NICE):
    def __init__(
        self,
        features: int,
        context: int = 0,
        transforms: int = 3,
        randmask: bool = False,
        **kwargs,
    ):
        temp = []

        for i in range(transforms):
            if randmask:
                mask = torch.randperm(features) % 2 == i % 2
            else:
                mask = torch.arange(features) % 2 == i % 2

            temp.append(
                MyGeneralCouplingTransform(
                    features=features,
                    context=context,
                    mask=mask,
                    **kwargs,
                )
            )

        base = Unconditional(
            DiagNormal,
            torch.zeros(features),
            torch.ones(features),
            buffer=True,
        )

        # Skip NICE.__init__ (it would rebuild the default transforms) and
        # call Flow.__init__ directly with our custom transforms.
        super(NICE, self).__init__(temp, base)
```
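Assuming the `MyMLP` sketch above, its hypothetical `dropout` argument is forwarded through `**kwargs`, so usage could look like:

```python
flow = MyNICE(features=5, context=2, transforms=3, dropout=0.1)

flow.train()  # dropout active while fitting
flow.eval()   # remember to disable dropout before evaluation or sampling
```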
We could make this process easier by adding generating functions to all classes, so that one only has to override these functions and not the whole `__init__` methods.
What do you think @francois-rozet?
@velezbeltran, I hope that helps!
Best
Simon
Hello @velezbeltran 👋 As @simonschnake said, it is not currently possible to use dropout in pre-built flows, but you can build your own transformations with dropout in the hyper-network. By the way, I think you can simplify `MyGeneralCouplingTransform` as:
```python
class MyGeneralCouplingTransform(GeneralCouplingTransform):
    def __init__(
        self,
        features: int,
        context: int = 0,
        mask: BoolTensor = None,
        univariate: Callable[..., Transform] = MonotonicAffineTransform,
        shapes: Sequence[Size] = ((), ()),
        **kwargs,
    ):
        super().__init__(features, context, mask, univariate, shapes, **kwargs)
        self.hyper = MyMLP(self.hyper.in_features, self.hyper.out_features, **kwargs)
```
That said, we could easily add a dropout option to `zuko.nn.MLP`, which would allow dropout in some pre-built flows (mainly `NICE`). However, we cannot add dropout to `MaskedMLP` and `MonotonicMLP`. The reason is that a transformation must be a pure function, meaning that calling it several times with the same input should always produce the same output. Otherwise, the transformation is not invertible/bijective. This is not the case with dropout (and batch normalization), which is why we don't support it.
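A small illustration of the problem (plain PyTorch, not zuko code): with dropout active, a hyper-network returns different parameters on each call, so an inverse computed later cannot undo the forward pass.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy hyper-network with dropout, in training mode.
hyper = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 4))
hyper.train()

x = torch.randn(1, 4)
y = x + hyper(x)      # "forward" pass: y = f(x) = x + shift
x_rec = y - hyper(x)  # "inverse" pass samples a fresh dropout mask

print(torch.allclose(x, x_rec))  # False: the round trip does not recover x
```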
Now, concerning overfitting, I have a few recommendations:
- If possible, get more data. That sounds obvious, but it goes a long way. If you cannot get more data, try data augmentation (e.g. add noise to $x$ or to the context $c$, randomly drop features in $c$, etc.).
- Try to increase the weight decay (`AdamW(lr=1e-3, weight_decay=1e-2)`; see the sketch after this list). Most transformations become identities ($f(x) = x$) when the weight decay is large.
- Limit the number of transformations. 3 autoregressive transformations are usually more than enough. In theory, if the univariate transformation is a universal function approximator, a single autoregressive transformation is as expressive as a thousand.
- Use small/shallow hyper-networks to reduce the capacity of the flow. Less capacity, less overfitting.
- Avoid neural transformations (`NAF`, `UNAF`). They have a stronger tendency to overfit in my experience.
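For example, a minimal training loop combining noise augmentation with weight decay might look like the following sketch (the `MAF` configuration, dummy data, and noise scale are placeholders to adapt):

```python
import torch
import zuko

# Dummy data for illustration: 256 samples with 5 features and 2 context features.
data = torch.utils.data.TensorDataset(torch.randn(256, 5), torch.randn(256, 2))
loader = torch.utils.data.DataLoader(data, batch_size=64, shuffle=True)

flow = zuko.flows.MAF(features=5, context=2, transforms=3)
optimizer = torch.optim.AdamW(flow.parameters(), lr=1e-3, weight_decay=1e-2)

for x, c in loader:
    x = x + 1e-2 * torch.randn_like(x)  # augmentation: add a little noise to x
    loss = -flow(c).log_prob(x).mean()  # negative log-likelihood

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```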
Oh wow. Thank you @francois-rozet and @simonschnake for all of this help; this is extremely helpful, and I really appreciate how clean it is to implement! I will give some of these tricks a try, @francois-rozet. I was already using augmentations because I am unfortunately also dealing with discrete data, and those have been really helpful. I was using NSF mostly because that is what I am used to and what has given me the best results in the past, but I will try using others.
Again, many thanks for the prompt response and these comments. I really appreciate it!
You're welcome @velezbeltran. By the way, `NSF` is an autoregressive flow (it is a subclass of `MAF`), so what I recommend applies. You can also use `NICE` (or `MyNICE` to try dropout) with rational-quadratic spline transformations:
```python
flow = NICE(
    features,
    context,
    univariate=zuko.transforms.MonotonicRQSTransform,  # spline transformation
    shapes=([8], [8], [7]),  # spline parameter shapes for 8 bins
)
```
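For a rational-quadratic spline with 8 bins, the univariate transformation takes 8 bin widths, 8 bin heights, and 7 knot derivatives, hence the `([8], [8], [7])` shapes.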
Hey @velezbeltran, thinking back, one should actually not use dropout in coupling transformations either. If you do, the transformation is not pure anymore, and hence not bijective/invertible. Using impure hyper-networks (dropout or batch-norm) could lead to silent errors that are very hard to detect, so I don't think it is a good idea to provide dropout as an option.
Ah, that makes sense now that I think about it! Thanks for telling me. I have implemented dropout before, and although it has worked, I can totally see now how it would lead to issues. Again, thanks for helping so much!
You're welcome. I will be closing this issue, but feel free to open a new one or a discussion if you have other questions.