cornellius-gp / linear_operator

A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch

Missing potential `unsqueeze` for `initial_guess` in `linear_cg`

tvercaut opened this issue

If `initial_guess` is a vector (a 1-D tensor), `linear_cg` doesn't do the right thing, which leads to CUDA running out of memory. `rhs` already gets the appropriate treatment:

is_vector = rhs.ndimension() == 1
if is_vector:
    rhs = rhs.unsqueeze(-1)
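
For intuition on the failure mode (an assumption about the internals: the residual is computed from `rhs` and the guess), a 1-D guess broadcasts against the unsqueezed `(n, 1)`-shaped `rhs` and silently materializes an `n x n` tensor, which would explain the CUDA out-of-memory at `n = 30000`:

import torch

rhs = torch.randn(5).unsqueeze(-1)  # (5, 1): shape of rhs after the unsqueeze above
x0 = torch.zeros(5)                 # (5,): a vector initial guess, left as-is
print((rhs - x0).shape)             # torch.Size([5, 5]) -- broadcasts to n x n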

A similar check for the need to unsqueeze `initial_guess` should probably be performed here:
if initial_guess is None:
    initial_guess = torch.zeros_like(rhs)
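
For concreteness, a minimal sketch of what such a check could look like (the `elif` branch is an assumption mirroring the `rhs` handling above, not the actual fix):

if initial_guess is None:
    initial_guess = torch.zeros_like(rhs)
elif initial_guess.ndimension() == 1:
    # Assumed fix: promote a 1-D guess to a column vector, matching the unsqueezed rhs
    initial_guess = initial_guess.unsqueeze(-1)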

Steps to reproduce the bug on Colab:

import sys
print("Python version:", sys.version)

import torch
print('PyTorch version:', torch.__version__)

torchdevice = torch.device('cpu')
if torch.cuda.is_available():
    torchdevice = torch.device('cuda')
    print('Default GPU is ' + torch.cuda.get_device_name(torch.device('cuda')))
print('Running on ' + str(torchdevice))

# see https://github.com/cornellius-gp/linear_operator/issues/20
!pip install gpytorch
import gpytorch
print('gpytorch version:', gpytorch.__version__)

# Dimension of the square sparse matrix
n = 30000
# Number of non-zero elements (up to duplicates)
nnz = 10000

# Create a sparse SPD matrix: A^T A is positive semidefinite,
# so adding the identity makes it symmetric positive definite
rowidx = torch.randint(low=0, high=n, size=(nnz,), device=torchdevice)
colidx = torch.randint(low=0, high=n, size=(nnz,), device=torchdevice)
itemidx = torch.vstack((rowidx, colidx))
values = torch.randn(nnz, device=torchdevice)
sparse_mat = torch.sparse_coo_tensor(itemidx, values, size=(n, n)).coalesce()
sparse_mat = sparse_mat.t() @ sparse_mat
diag_idx = torch.arange(n, device=torchdevice)
id_mat = torch.sparse_coo_tensor(torch.stack((diag_idx, diag_idx)), torch.ones(n, device=torchdevice), (n, n))
sparse_mat = id_mat + sparse_mat
print('sparse_mat:', sparse_mat)

b = torch.randn(n, device=torchdevice)

# x0 is 1-D here, which triggers the bug; unsqueezing it fixes the issue
x0 = torch.zeros_like(b)

closure = lambda xtest: sparse_mat @ xtest
x = gpytorch.utils.linear_cg(closure, b, initial_guess=x0)
print('cg:',x)
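
As the comment above `x0` notes, unsqueezing the initial guess works around the bug until a fix lands; a sketch:

x0 = torch.zeros_like(b).unsqueeze(-1)  # shape (n, 1) instead of (n,)
x = gpytorch.utils.linear_cg(closure, b, initial_guess=x0)
print('cg:', x)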

Thanks for the catch. Would you be able to submit a PR to fix this?

Closed by #22