ldeecke / gmm-torch

Gaussian mixture models in PyTorch.

cholesky factorization error on cpu

pumplerod opened this issue.

There seems to be some limit involving n_components and n_features. If I try to create a model with

n_components=1
n_features=99

it will fail with _LinAlgError: linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 22 is not positive-definite).

Reducing to n_features=98 works, but if I then raise n_components to 2, the error returns.
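The error message reports the order of the first leading minor that fails. A minimal, unblocked textbook Cholesky in plain Python shows where that number comes from: the factorization proceeds one column at a time, and the pivot at column k equals det(A_k)/det(A_{k-1}), so the first non-positive pivot identifies the first leading k×k submatrix that is not positive-definite. This is an illustrative sketch, not PyTorch's implementation (which calls optimized LAPACK/cuSOLVER routines), and `cholesky_or_report` is a hypothetical helper name. PyTorch's `torch.linalg.cholesky_ex` returns this same order as `info` instead of raising.

```python
import math


def cholesky_or_report(a):
    """Textbook Cholesky A = L L^T for a symmetric matrix given as a list of lists.

    Hypothetical helper. Returns (L, 0) on success, or (None, k) where k is the
    1-based order of the first leading minor that is not positive-definite.
    """
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for j in range(n):
        # Pivot for column j: A[j][j] minus the squares already placed in row j.
        # It equals det(A_{j+1}) / det(A_j) for the leading submatrices.
        pivot = a[j][j] - sum(L[j][p] ** 2 for p in range(j))
        if pivot <= 0.0:
            return None, j + 1
        L[j][j] = math.sqrt(pivot)
        # Fill the rest of column j below the diagonal.
        for i in range(j + 1, n):
            L[i][j] = (a[i][j] - sum(L[i][p] * L[j][p] for p in range(j))) / L[j][j]
    return L, 0


# Rank-1 (singular) matrix v v^T with v = (1, 2, 3).
v = [1.0, 2.0, 3.0]
singular = [[vi * vj for vj in v] for vi in v]
print(cholesky_or_report(singular)[1])  # 2: the leading 2x2 minor fails
```

For v vᵀ with v = (1, 2, 3), the second pivot is 4 - 2² = 0, so the sketch reports order 2, analogous to the "order 22" in the error above.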

I am trying to work with many more features and components: potentially 1000+ features and an unknown, but likely high, number of components. Is there any workaround for this?

It also appears to depend on the number of samples. In my example I was using 100 samples, but if I increase that, the error goes away. I guess I need to keep n_samples higher than n_components + n_features, or something like that. Still tricky to work around.
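The dependence on the number of samples has a simple cause: after subtracting the mean, n samples span at most n - 1 dimensions, so an n_features × n_features sample covariance has rank at most n_samples - 1 and is exactly singular once n_features ≥ n_samples. Just below that boundary it is merely ill-conditioned, which float32 may fail to resolve. In a mixture, each component's covariance is estimated from roughly its share of the samples, which plausibly explains why raising n_components makes things worse. The sketch below uses a plain unweighted covariance as a simplifying assumption (not gmm-torch's responsibility-weighted update); `sample_covariance` is a hypothetical helper.

```python
import torch


def sample_covariance(x):
    # x: (n_samples, n_features). Unweighted MLE covariance; a simplification,
    # since a GMM's M-step weights each sample by its responsibility.
    centered = x - x.mean(dim=0, keepdim=True)
    return centered.T @ centered / x.shape[0]


torch.manual_seed(0)
n_samples, n_features = 10, 20
x = torch.randn(n_samples, n_features, dtype=torch.float64)
cov = sample_covariance(x)
# Centered rows sum to zero, so the rank is at most n_samples - 1 = 9 < 20:
# cov is singular and has no Cholesky factorization in exact arithmetic.
print(torch.linalg.matrix_rank(cov).item())

# With many more samples than features, the covariance is full rank.
x_more = torch.randn(200, n_features, dtype=torch.float64)
print(torch.linalg.matrix_rank(sample_covariance(x_more)).item())  # 20
```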

I have the same issue, on both CPU and CUDA.
It seems to occur when n_components or n_features is high relative to the number of samples.

Is there a way to fix it, or a rule for avoiding it?

Hi there,
It has been several months, but changing the dtype from PyTorch's default float32 to float64 (double) may help in some cases.
I think the issue comes from some eigenvalues of var being close to zero, which causes numerical problems in the Cholesky factorization here:

log_det[k] = 2 * torch.log(torch.diagonal(torch.linalg.cholesky(var[0,k]))).sum()
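The line above uses the identity log det Σ = 2 Σᵢ log Lᵢᵢ, where Σ = L Lᵀ is the Cholesky factorization. A small float64 check against `torch.logdet` illustrates it (the helper name `logdet_via_cholesky` is hypothetical). As an eigenvalue of Σ approaches zero, det Σ approaches zero, so at least one diagonal entry of L approaches zero and its log diverges; rounding can then push a pivot to zero or below, at which point `torch.linalg.cholesky` raises.

```python
import torch


def logdet_via_cholesky(cov):
    # cov = L L^T  =>  det(cov) = prod(diag(L))**2  =>  log det(cov) = 2 * sum(log diag(L))
    L = torch.linalg.cholesky(cov)
    return 2 * torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).sum(dim=-1)


torch.manual_seed(0)
a = torch.randn(5, 5, dtype=torch.float64)
cov = a @ a.T + 5 * torch.eye(5, dtype=torch.float64)  # symmetric positive-definite
print(logdet_via_cholesky(cov).item(), torch.logdet(cov).item())  # the two agree
```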

So switching to `double` helps alleviate these numerical issues.
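Why double precision helps: machine epsilon, the relative spacing of floating-point numbers near 1.0, is about 1.19e-7 in float32 versus about 2.22e-16 in float64. As a rough rule, eigenvalues much smaller than eps times the largest eigenvalue cannot be resolved reliably, so a covariance that is safely positive-definite in float64 can look singular in float32. A small demonstration using only standard PyTorch calls:

```python
import torch

# Machine epsilon: the gap between 1.0 and the next representable number.
print(torch.finfo(torch.float32).eps)  # 2**-23, about 1.19e-07
print(torch.finfo(torch.float64).eps)  # 2**-52, about 2.22e-16

# A relative contribution of 1e-8 is rounded away in float32 but survives in float64.
one32 = torch.tensor(1.0, dtype=torch.float32)
one64 = torch.tensor(1.0, dtype=torch.float64)
print((one32 + 1e-8 == one32).item())  # True: lost in float32
print((one64 + 1e-8 == one64).item())  # False: still distinguishable in float64
```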

I changed self.mu, self.var and self.pi to double in _init_params by appending .double() to their initializations, e.g.,
self.mu = torch.nn.Parameter(torch.randn(1, self.n_components, self.n_features), requires_grad=False) ==>
self.mu = torch.nn.Parameter(torch.randn(1, self.n_components, self.n_features).double(), requires_grad=False)

and made sure the input to fit() is also double.
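A minimal sketch of this change, using a hypothetical stand-in module `TinyGMMParams` rather than gmm-torch's actual class. Only the mu line comes from the comment; the shapes chosen for var, (1, n_components, n_features, n_features), and pi, (1, n_components, 1), are assumptions picked to be consistent with the `var[0,k]` indexing in the log-det line.

```python
import torch


class TinyGMMParams(torch.nn.Module):
    """Hypothetical stand-in for the parameter setup described above; not gmm-torch's class."""

    def __init__(self, n_components, n_features):
        super().__init__()
        self.n_components = n_components
        self.n_features = n_features
        self._init_params()

    def _init_params(self):
        # Each tensor is created in float32 (PyTorch's default), then cast with .double().
        self.mu = torch.nn.Parameter(
            torch.randn(1, self.n_components, self.n_features).double(), requires_grad=False)
        # Assumed layout: one identity covariance per component.
        self.var = torch.nn.Parameter(
            torch.eye(self.n_features).reshape(1, 1, self.n_features, self.n_features)
            .repeat(1, self.n_components, 1, 1).double(), requires_grad=False)
        # Assumed layout: uniform mixture weights.
        self.pi = torch.nn.Parameter(
            torch.full((1, self.n_components, 1), 1.0 / self.n_components).double(),
            requires_grad=False)


params = TinyGMMParams(n_components=2, n_features=3)
x = torch.randn(100, 3).double()  # the data passed to fit() should match the parameters' dtype
print(params.mu.dtype, params.var.dtype, params.pi.dtype, x.dtype)  # all torch.float64
```

Keeping the input in the same dtype matters because matrix products such as `torch.matmul` do not promote mixed float32/float64 operands and raise an error instead.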

However, this is only a temporary workaround, and it does increase running time and memory usage.
Hope this helps.
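To put the memory cost in numbers: a full-covariance tensor stores n_components × n_features² entries, so moving from 4-byte float32 to 8-byte float64 doubles its size. The figures below assume the (1, n_components, n_features, n_features) layout for var and an illustrative 10 components at the 1000-feature scale mentioned in the issue; `full_covariance_bytes` is a hypothetical helper.

```python
import torch


def full_covariance_bytes(n_components, n_features, dtype):
    # Assumed layout: (1, n_components, n_features, n_features).
    element_size = torch.empty(0, dtype=dtype).element_size()  # 4 for float32, 8 for float64
    return n_components * n_features * n_features * element_size


for dtype in (torch.float32, torch.float64):
    mb = full_covariance_bytes(10, 1000, dtype) / 1e6
    print(dtype, f"{mb:.0f} MB")  # torch.float32 40 MB, then torch.float64 80 MB
```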