ctargon / TSPG

Transcriptome State Perturbation Generator

Incorporate target distribution sampling into TF graph

bentsherman opened this issue

The AdvGAN computes a "target distribution loss", which compares the perturbed output to samples from the target distribution. These target samples are currently computed via NumPy (np.random.normal and np.random.multivariate_normal); however, there are equivalent methods in TensorFlow:

https://www.tensorflow.org/api_docs/python/tf/random/normal

https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/MultivariateNormalFullCovariance

Therefore, we should be able to move these computations to TF and leverage the GPU, both for computing mean/covariance and for random normal sampling.
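For reference, sampling from a full-covariance multivariate normal reduces to two steps: factor the covariance as L Lᵀ (Cholesky) and map standard normal draws z through x = μ + L z. The sketch below writes this in NumPy so it runs without a GPU; in TF the same two steps would use tf.linalg.cholesky and tf.random.normal. The function name `sample_target` and its arguments are hypothetical, not taken from the TSPG code.

```python
import numpy as np

def sample_target(mean, cov, n_samples, rng=None):
    """Draw n_samples rows from N(mean, cov) using the Cholesky factor of cov.

    Hypothetical helper: x_i = mean + L @ z_i, where L @ L.T == cov and z_i ~ N(0, I).
    """
    rng = np.random.default_rng() if rng is None else rng
    # Lower-triangular factor; raises LinAlgError if cov is not positive definite.
    L = np.linalg.cholesky(cov)
    # One row of standard normal noise per sample.
    z = rng.standard_normal((n_samples, mean.shape[0]))
    # Row-wise form of mean + L @ z_i.
    return mean + z @ L.T
```

Computing `L` once and reusing it for every batch of `z` matches the plan of computing target samples once rather than once per epoch.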

Some caveats:

  • For large datasets, the covariance matrix is very large, and its matrix square root is the same size. Some GPUs might not be able to hold these matrices and the AdvGAN model at the same time. I am refactoring the AdvGAN code to compute target samples once instead of once per epoch, so maybe TF will be smart enough to discard the covariance matrix and square-root matrix immediately after using them.

  • It looks like TF's version of multivariate_normal uses Cholesky decomposition, which is ideal for covariance matrices. However, Cholesky requires the matrix to be positive definite, whereas covariance matrices are only guaranteed to be positive semi-definite. In other words, if the input GEM contains genes that are constant or nearly constant, Cholesky will fail. I think you can get around this problem by filtering out genes with stddev below some threshold, but I'm not sure what that threshold should be. I would look at how SciPy's cholesky function handles this check.
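A minimal sketch of the filtering workaround from the second caveat, assuming the GEM is a samples × genes array: drop genes whose standard deviation falls below a threshold, then check whether Cholesky succeeds. `filter_low_variance`, `is_positive_definite`, and the `min_std` value are hypothetical; the issue leaves the right threshold open.

```python
import numpy as np

def filter_low_variance(X, min_std):
    """Drop gene columns of X (samples x genes) whose std is below min_std.

    min_std is a hypothetical threshold. Returns the filtered matrix and the
    boolean mask of kept genes.
    """
    keep = X.std(axis=0) >= min_std
    return X[:, keep], keep

def is_positive_definite(cov):
    """Return True if Cholesky succeeds, i.e. cov is numerically positive definite."""
    try:
        np.linalg.cholesky(cov)
        return True
    except np.linalg.LinAlgError:
        return False

# Example: one exactly constant gene makes the covariance singular.
rng = np.random.default_rng(0)
X = np.column_stack([np.full(50, 5.0), rng.standard_normal((50, 3))])
print(is_positive_definite(np.cov(X, rowvar=False)))   # False
X_f, keep = filter_low_variance(X, min_std=1e-8)
print(is_positive_definite(np.cov(X_f, rowvar=False)))  # True
```

Filtering alone may not be enough: a sample covariance built from n samples has rank at most n − 1, so when there are more genes than samples it stays singular even after low-variance genes are removed.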

Done in recent commits. The target mean and covariance are computed via NumPy; the rest is handled by TF.