probabilists / zuko

Normalizing flows in PyTorch

Home Page: https://zuko.readthedocs.io

Generating samples with their log-probability

simonschnake opened this issue

Description

There are some use cases (at least I have one) where one also needs the ladj (log absolute determinant of the Jacobian)
while computing the inverse operation. In my use case, I am not only using the normalizing flow
to generate samples, but I also want to know the likelihood of the produced samples.

Implementation

Implementation could be somewhat hard, because in principle every transformation would need another
method, inverse_and_ladj. I am willing to help and contribute pull requests for that. My main focus at the moment is neural spline flows.
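For a single invertible transformation, the inverse_and_ladj method suggested above would return the inverse value and its ladj in one call. A minimal sketch, assuming a toy scalar SinhTransform (hypothetical, pure Python rather than zuko's tensor-based Transform), where the inverse ladj log|dx/dy| is computed directly from y and equals the negative of the forward ladj at the corresponding point:

```python
import math
from typing import Tuple


class SinhTransform:
    """Toy scalar transform y = sinh(x); a hypothetical illustration, not a zuko class."""

    def call_and_ladj(self, x: float) -> Tuple[float, float]:
        # Forward value and log|dy/dx| = log(cosh(x)), computed together
        return math.sinh(x), math.log(math.cosh(x))

    def inverse_and_ladj(self, y: float) -> Tuple[float, float]:
        # Hypothetical method from the issue: inverse value and log|dx/dy|,
        # computed from y alone without re-running the forward pass
        return math.asinh(y), -0.5 * math.log1p(y * y)


t = SinhTransform()
y, ladj = t.call_and_ladj(0.8)
x, inv_ladj = t.inverse_and_ladj(y)
print(x, ladj + inv_ladj)  # x is about 0.8 and the two ladj cancel
```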

I'm mainly opening this issue to find out whether this is wanted; one would also have to decide how to expose the functionality in the classes that consume the transformations.

This feature would be particularly necessary for applications like importance sampling, where samples and their probabilities should be computed in as few steps as possible.
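As an illustration of that use case, here is a self-normalized importance sampling sketch in which the proposal returns each sample together with its log-density, both derived from the same base draw. Everything here is assumed for illustration: AffineProposal is a hypothetical one-layer affine flow standing in for a trained flow, and the target is an unnormalized standard normal, so the estimates should approach sqrt(2π) and 1.

```python
import math
import random
from typing import Tuple

LOG_SQRT_2PI = 0.5 * math.log(2 * math.pi)


def std_normal_log_prob(z: float) -> float:
    return -0.5 * z * z - LOG_SQRT_2PI


class AffineProposal:
    """Hypothetical toy flow x = mu + sigma * z with z ~ N(0, 1)."""

    def __init__(self, mu: float, sigma: float):
        self.mu, self.sigma = mu, sigma

    def rsample_and_log_prob(self, rng: random.Random) -> Tuple[float, float]:
        z = rng.gauss(0.0, 1.0)
        x = self.mu + self.sigma * z
        # log q(x) = log N(z) - log|dx/dz|, reusing the same z (no second pass)
        return x, std_normal_log_prob(z) - math.log(self.sigma)


def importance_sampling(log_target, proposal, n: int, seed: int = 0):
    """Returns (estimate of the normalizing constant Z, self-normalized E_p[x^2])."""
    rng = random.Random(seed)
    xs, log_ws = [], []
    for _ in range(n):
        x, log_q = proposal.rsample_and_log_prob(rng)
        xs.append(x)
        log_ws.append(log_target(x) - log_q)
    m = max(log_ws)  # subtract the max log-weight for numerical stability
    ws = [math.exp(lw - m) for lw in log_ws]
    z_hat = math.exp(m) * sum(ws) / n
    second_moment = sum(w * x * x for w, x in zip(ws, xs)) / sum(ws)
    return z_hat, second_moment


# Unnormalized standard normal target: Z = sqrt(2 * pi), E[x^2] = 1
z_hat, m2 = importance_sampling(lambda x: -0.5 * x * x, AffineProposal(0.5, 1.5), n=50_000)
print(z_hat, m2)
```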

I'm not sure I have understood the zuko code well enough, but would it work to just change this line? https://github.com/francois-rozet/zuko/blob/master/zuko/distributions.py#L121 from

def rsample(self, shape: Size = ()) -> Tensor:
    if self.base.has_rsample:
        z = self.base.rsample(shape)
    else:
        z = self.base.sample(shape)

    return self.transform.inv(z)

to

def rsample(self, shape: Size = ()) -> Tensor:
    if self.base.has_rsample:
        z = self.base.rsample(shape)
    else:
        z = self.base.sample(shape)

    return self.transform.inv.call_and_ladj(z)  # with a proper inversion of the prob.

I would be happy to help adding this feature 👍

Hello @simonschnake and @valsdav, thank you for your interest in Zuko 🔥 This is a sensible request. I just have a few questions:

  1. Do you need the ladj of the inverse or the log_prob of your samples?

The former would require modifying all transformations, while the latter might only require modifying the NormalizingFlow class.

  2. Do you need this as a convenience enhancement or as a performance improvement?

This could dramatically change the way the feature is implemented.

I would need the log_prob of the produced samples. To my understanding, for a sample y = f(x) with x drawn from the base distribution, log_prob(y) = log_prob(x) - ladj(x), where ladj is that of the sampling map f, so calculating the ladj of the inverse transformation would be necessary.
Other libraries, such as nflows and Distrax, compute the ladj in both directions.
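The sign convention can be checked numerically. For a sample y = f(x) with x drawn from the base, the change of variables gives log p(y) = log p(x) - log|det ∂f/∂x|. A sketch assuming a standard normal base and the toy sampling map f(x) = exp(x) (chosen only because its density has a closed form, the log-normal), so both sides can be compared:

```python
import math

LOG_SQRT_2PI = 0.5 * math.log(2 * math.pi)


def base_log_prob(x: float) -> float:
    # Standard normal N(0, 1)
    return -0.5 * x * x - LOG_SQRT_2PI


def sample_log_prob(x: float) -> float:
    # Sampling map y = f(x) = exp(x), so ladj = log|dy/dx| = x
    ladj = x
    return base_log_prob(x) - ladj


def lognormal_log_prob(y: float) -> float:
    # Closed-form log-density of y = exp(x) with x ~ N(0, 1)
    log_y = math.log(y)
    return -log_y - 0.5 * log_y * log_y - LOG_SQRT_2PI


for x in (-1.0, 0.3, 2.0):
    y = math.exp(x)
    print(x, sample_log_prob(x), lognormal_log_prob(y))  # the last two columns agree
```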

I have written a rough example using the NSF flow and could provide a pull request.

Thanks a lot for your help
Simon

Before submitting a PR, we need to decide how the feature should be implemented. It is a major API change so I don't want to rush it. IMHO, nflows does this in a bad way: transformations always compute their ladj, which is a waste of resources most of the time.

I like @valsdav's idea of using the call_and_ladj of the inverse transform instead of adding a new inverse_and_ladj method. However, as is, it would not be more efficient, since the call_and_ladj method of Inverse first applies _inverse and then computes log_abs_det_jacobian. Hence, in terms of performance, it would be equivalent to a convenience method

def rsample_and_log_prob(self, shape: Size = ()) -> Tuple[Tensor, Tensor]:
    x = self.rsample(shape)
    log_p = self.log_prob(x)
    return x, log_p
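To see why the generic wrapper saves nothing, here is a toy sketch (pure Python, hypothetical classes mimicking the names in this thread, not zuko's implementation) where the inverse of a composition goes through a generic Inverse wrapper. It assumes the composition's log_abs_det_jacobian only receives the endpoints x and y, so it must recompute the intermediate values with a forward pass; one call_and_ladj then costs an inverse pass plus a forward pass, the same as rsample followed by log_prob.

```python
import math
from typing import Tuple

CALLS = {"forward": 0, "inverse": 0}


class Sinh:
    """Hypothetical elementwise transform y = sinh(x) that counts its evaluations."""

    def __call__(self, x: float) -> float:
        CALLS["forward"] += 1
        return math.sinh(x)

    def inverse(self, y: float) -> float:
        CALLS["inverse"] += 1
        return math.asinh(y)

    def log_abs_det_jacobian(self, x: float, y: float) -> float:
        return math.log(math.cosh(x))


class ComposedTransform:
    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x: float) -> float:
        for t in self.transforms:
            x = t(x)
        return x

    def inverse(self, y: float) -> float:
        for t in reversed(self.transforms):
            y = t.inverse(y)
        return y

    def log_abs_det_jacobian(self, x: float, y: float) -> float:
        # Only the endpoints are given, so the intermediate values are
        # recomputed with a full forward pass
        total = 0.0
        for t in self.transforms:
            x_next = t(x)
            total += t.log_abs_det_jacobian(x, x_next)
            x = x_next
        return total


class Inverse:
    """Generic inverse wrapper: inverse pass first, then the ladj separately."""

    def __init__(self, t: ComposedTransform):
        self.t = t

    def call_and_ladj(self, y: float) -> Tuple[float, float]:
        x = self.t.inverse(y)
        return x, -self.t.log_abs_det_jacobian(x, y)


flow = ComposedTransform([Sinh(), Sinh(), Sinh()])
x, ladj = Inverse(flow).call_and_ladj(1.3)
print(CALLS)  # {'forward': 3, 'inverse': 3}
```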

My idea is to modify the inv property of ComposedTransform such that it does not return an Inverse transform, but a new ComposedTransform with its transformations reversed. With that modification, calling inv.call_and_ladj(y) should be more efficient. However, this "performance bump" will likely be negligible with respect to the cost of inverting autoregressive transformations anyway.
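The proposed override can be sketched in the same toy setting (hypothetical pure-Python classes, not zuko's code): ComposedTransform.inv returns a new composition of the inverted pieces in reverse order, so inv.call_and_ladj accumulates the ladj during a single inverse pass. The elementwise pieces keep a generic Inverse wrapper, which is cheap for them because, in this toy, their ladj has a closed form in terms of the inputs.

```python
import math
from typing import Tuple

CALLS = {"forward": 0, "inverse": 0}


class Sinh:
    """Hypothetical elementwise transform y = sinh(x) that counts its evaluations."""

    def __call__(self, x: float) -> float:
        CALLS["forward"] += 1
        return math.sinh(x)

    def inverse(self, y: float) -> float:
        CALLS["inverse"] += 1
        return math.asinh(y)

    def log_abs_det_jacobian(self, x: float, y: float) -> float:
        # Closed form from the input, no extra forward evaluation
        return math.log(math.cosh(x))

    def call_and_ladj(self, x: float) -> Tuple[float, float]:
        y = self(x)
        return y, self.log_abs_det_jacobian(x, y)

    @property
    def inv(self) -> "Inverse":
        return Inverse(self)


class Inverse:
    """Generic wrapper, fine for a single elementwise piece."""

    def __init__(self, t: Sinh):
        self.t = t

    def call_and_ladj(self, y: float) -> Tuple[float, float]:
        x = self.t.inverse(y)
        return x, -self.t.log_abs_det_jacobian(x, y)


class ComposedTransform:
    def __init__(self, transforms):
        self.transforms = list(transforms)

    def call_and_ladj(self, x: float) -> Tuple[float, float]:
        total = 0.0
        for t in self.transforms:
            x, ladj = t.call_and_ladj(x)
            total += ladj
        return x, total

    @property
    def inv(self) -> "ComposedTransform":
        # Proposed override: reverse the order and invert each piece, so that
        # inv.call_and_ladj accumulates the ladj during one inverse pass
        return ComposedTransform(t.inv for t in reversed(self.transforms))


flow = ComposedTransform([Sinh(), Sinh(), Sinh()])
y, fwd_ladj = flow.call_and_ladj(0.4)
CALLS.update(forward=0, inverse=0)
x, inv_ladj = flow.inv.call_and_ladj(y)
print(x, fwd_ladj + inv_ladj, CALLS)  # no forward calls during the inverse
```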

My proposal is something like:

diff --git a/zuko/distributions.py b/zuko/distributions.py
index b975e76..82a02e1 100644
--- a/zuko/distributions.py
+++ b/zuko/distributions.py
@@ -120,6 +120,18 @@ class NormalizingFlow(Distribution):
 
         return self.transform.inv(z)
 
+    def rsample_and_log_prob(self, shape: Size = ()) -> Tuple[Tensor, Tensor]:
+        if self.base.has_rsample:
+            z = self.base.rsample(shape)
+        else:
+            z = self.base.sample(shape)
+
+        log_p = self.base.log_prob(z)
+        x, ladj = self.transform.inv.call_and_ladj(z)
+        ladj = _sum_rightmost(ladj, self.reinterpreted)
+
+        return x, log_p - ladj
+
 
 class Joint(Distribution):
     r"""Creates a distribution for a multivariate random variable :math:`X` which
diff --git a/zuko/transforms.py b/zuko/transforms.py
index e926ee6..e4c087b 100644
--- a/zuko/transforms.py
+++ b/zuko/transforms.py
@@ -107,6 +107,17 @@ class ComposedTransform(Transform):
             x = t(x)
         return x
 
+    @property
+    def inv(self):
+        new = self.__new__(ComposedTransform)
+        new.transforms = [t.inv for t in reversed(self.transforms)]
+        new.domain_dim = self.codomain_dim
+        new.codomain_dim = self.domain_dim
+
+        Transform.__init__(new)
+
+        return new
+
     def _inverse(self, y: Tensor) -> Tensor:
         for t in reversed(self.transforms):
             y = t.inv(y)

It provides a slight boost in performance, and I think it would make back-propagation slightly more stable.
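The rsample_and_log_prob method in the diff above can be checked with a scalar toy. Assumptions: a hypothetical ToyNormalizingFlow with a standard normal base, a single asinh/sinh transform pair, and no reinterpreted dimensions, so _sum_rightmost is not needed. The log-density returned at sampling time should equal log_prob evaluated on the same sample.

```python
import math
import random
from typing import Tuple

LOG_SQRT_2PI = 0.5 * math.log(2 * math.pi)


class AsinhTransform:
    """Hypothetical data -> latent transform z = asinh(x)."""

    def call_and_ladj(self, x: float) -> Tuple[float, float]:
        # log|dz/dx| = -0.5 * log(1 + x^2)
        return math.asinh(x), -0.5 * math.log1p(x * x)

    @property
    def inv(self) -> "SinhTransform":
        return SinhTransform()


class SinhTransform:
    """Latent -> data map x = sinh(z)."""

    def call_and_ladj(self, z: float) -> Tuple[float, float]:
        # log|dx/dz| = log(cosh(z))
        return math.sinh(z), math.log(math.cosh(z))


class ToyNormalizingFlow:
    """Scalar stand-in for NormalizingFlow with a standard normal base."""

    def __init__(self, transform: AsinhTransform):
        self.transform = transform

    @staticmethod
    def base_log_prob(z: float) -> float:
        return -0.5 * z * z - LOG_SQRT_2PI

    def log_prob(self, x: float) -> float:
        z, ladj = self.transform.call_and_ladj(x)
        return self.base_log_prob(z) + ladj

    def rsample_and_log_prob(self, rng: random.Random) -> Tuple[float, float]:
        z = rng.gauss(0.0, 1.0)
        x, ladj = self.transform.inv.call_and_ladj(z)
        # Same sign convention as the diff: log p(x) = log p(z) - log|dx/dz|
        return x, self.base_log_prob(z) - ladj


flow = ToyNormalizingFlow(AsinhTransform())
x, log_p = flow.rsample_and_log_prob(random.Random(0))
print(log_p, flow.log_prob(x))  # the two values agree up to rounding
```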

That seems very elegant and nice, but I have the feeling that it can get very complicated.
We would need to design a second inverse class for each transformation class.

I quickly sketched a matching version of MonotonicAffineTransform
here: simonschnake/zuko@master...simonschnake-patch-1

It could be that I misunderstand the architecture and it is a lot simpler.

Sorry, I was not clear enough. Only transformations that can profit from a faster .inv.call_and_ladj should have their inv property overridden. Currently, this is only the case for ComposedTransform and maybe FreeFormJacobianTransform. For univariate transformations it would be useless, as they are often used with AutoregressiveTransform.

Thanks a lot @francois-rozet ! 💯