Cholesky factorization error on CPU
pumplerod opened this issue
There seems to be some limit on `n_components` relative to `n_features`. If I try to create a model with `n_components=1` and `n_features=99`, it fails with:
_LinAlgError: linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 22 is not positive-definite).
Reducing to `n_features=98` works, but if I then raise to `n_components=2`, the error returns.
I am trying to work with many more features and components: potentially 1000+ features and an unknown, but likely large, number of components. Is there any workaround for this?
It also seems to depend on the number of samples. My example used 100 samples, and increasing that makes the error go away. I guess I need to keep `n_samples` higher than `n_components + n_features`, or something like that, which is still tricky to work around.
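A rough way to see why the number of samples matters: a sample covariance built from N points in d dimensions has rank at most N - 1, so once d >= N it is exactly singular, and near that limit it is badly conditioned. The sketch below checks this in exact rational arithmetic (`fractions.Fraction`) with an elimination-based pivot test that fails at the same leading minor a Cholesky factorization would. The helper names and the data points are made up for illustration; this is not code from the library.

```python
from fractions import Fraction


def sample_covariance(points):
    """Exact (divide-by-N) sample covariance of a list of d-dimensional points."""
    n, d = len(points), len(points[0])
    mean = [sum((Fraction(p[j]) for p in points), Fraction(0)) / n for j in range(d)]
    cov = [[Fraction(0)] * d for _ in range(d)]
    for p in points:
        c = [p[j] - mean[j] for j in range(d)]  # centered point
        for i in range(d):
            for j in range(d):
                cov[i][j] += c[i] * c[j] / n
    return cov


def first_non_pd_minor(a):
    """Order of the first leading minor that is not positive-definite, or None.

    Gaussian elimination without pivoting: the k-th pivot equals
    det(A_k) / det(A_{k-1}), so it is <= 0 exactly where Cholesky would fail.
    """
    a = [row[:] for row in a]
    n = len(a)
    for k in range(n):
        if a[k][k] <= 0:
            return k + 1
        for i in range(k + 1, n):
            f = a[i][k] / a[k][k]
            for j in range(k, n):
                a[i][j] -= f * a[k][j]
    return None


pts = [(1, 2, 3), (4, 0, 1), (2, 5, 7), (0, 1, 0), (3, 3, 1)]
print(first_non_pd_minor(sample_covariance(pts[:3])))  # 3 samples, 3 features -> 3
print(first_non_pd_minor(sample_covariance(pts)))      # 5 samples, 3 features -> None
```

With several components, each component's covariance is effectively estimated from only its share of the samples, which is plausibly why raising `n_components` brings the error back. Treat that as a rule of thumb, not something verified against the library's code.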
I have the same issue, on both CPU and CUDA.
It seems to appear when `n_components` or `n_features` is high compared to the number of samples.
Is there a way to fix it, or some rule for avoiding it?
Hi there,
It has been several months, but changing the dtype from PyTorch's default float32 to float64 (double) may help in some cases.
I think the issue is that some eigenvalues of `var` here are close to zero, so the Cholesky factorization runs into numerical issues:
Line 299 in 23eaf64
So switching to `double` should help alleviate these numerical issues.
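To make the eigenvalue argument concrete, here is a minimal pure-Python sketch, assuming a hand-rolled Cholesky rather than the LAPACK routine PyTorch actually calls. Float32 is emulated by rounding after every operation with `struct`; the helper names `to_float32` and `cholesky` and the 2x2 test matrix are hypothetical, chosen so the smallest eigenvalue is about 5e-11.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (IEEE double) to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def cholesky(a, rnd=lambda x: x):
    """Cholesky-Banachiewicz factorization A = L @ L.T (lower-triangular L).

    `rnd` is applied after every arithmetic step; passing `to_float32`
    emulates doing the factorization in float32. Raises ValueError naming
    the first leading minor that is not positive-definite.
    """
    n = len(a)
    a = [[rnd(v) for v in row] for row in a]  # store the input at working precision
    low = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j]
            for k in range(j):
                s = rnd(s - rnd(low[i][k] * low[j][k]))
            if i == j:
                if s <= 0.0:
                    raise ValueError(f"leading minor of order {i + 1} is not positive-definite")
                low[i][i] = rnd(math.sqrt(s))
            else:
                low[i][j] = rnd(s / low[j][j])
    return low


cov = [[1.0, 1.0], [1.0, 1.0 + 1e-10]]  # smallest eigenvalue is about 5e-11
print(cholesky(cov)[1][1] > 0)  # float64: True (tiny but positive pivot)
try:
    cholesky(cov, rnd=to_float32)
except ValueError as err:
    print("float32:", err)  # float32: leading minor of order 2 is not positive-definite
```

In float32, `1 + 1e-10` rounds to exactly `1.0`, so the second pivot becomes zero; float64 still resolves the difference. Double precision only moves this threshold, so a covariance that is truly singular (as in the too-few-samples case) will still fail.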
I changed `self.mu`, `self.var`, and `self.pi` to double in `_init_params` by appending `.double()` to their initializations, e.g.,
self.mu = torch.nn.Parameter(torch.randn(1, self.n_components, self.n_features), requires_grad=False)
==>
self.mu = torch.nn.Parameter(torch.randn(1, self.n_components, self.n_features).double(), requires_grad=False)
and make sure the input to `fit()` is also double.
However, this is more of a temporary workaround, and it does increase running time and memory usage.
Hope this helps.
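For a sense of the memory cost at the scale mentioned at the top of the thread, here is a back-of-the-envelope sketch. It assumes a full covariance tensor of shape `(1, n_components, n_features, n_features)`, which is an assumption about the library's layout, and the sizes `n_components=10`, `n_features=1000` are illustrative only. It ignores every other tensor and intermediate buffer.

```python
def covariance_bytes(n_components: int, n_features: int, bytes_per_elem: int) -> int:
    """Bytes for one full-covariance tensor of (assumed) shape (1, K, d, d)."""
    return n_components * n_features * n_features * bytes_per_elem


for name, size in (("float32", 4), ("float64", 8)):
    mib = covariance_bytes(10, 1000, size) / 2**20
    print(f"{name}: {mib:.1f} MiB")  # float32: 38.1 MiB, then float64: 76.3 MiB
```

Storage simply doubles with the element size; the quadratic growth in `n_features` is what dominates as the feature count grows.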