Cases where `neq==0` with pytorch 0.4.1
caseus-viridis opened this issue
Xin Wang commented
`QPFunction` fails on problems without equality constraints (`neq == 0`) under pytorch 0.4.1, and needs a fix.
For example:
import torch
from qpth.qp import QPFunction, QPSolvers

N, D = 4, 16
v = torch.randn(D)
L = torch.randn(N, D)
x = QPFunction(solver=QPSolvers.PDIPM_BATCHED)(
    torch.matmul(L, L.t()),
    torch.matmul(v, L.t()),
    torch.eye(N),
    torch.zeros(N),
    torch.Tensor(),
    torch.Tensor()
)
Stack trace:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-1-c5f8050018cf> in <module>()
11 torch.zeros(N),
12 torch.Tensor(),
---> 13 torch.Tensor()
14 )
___/lib/python3.6/site-packages/qpth/qp.py in forward(self, Q_, p_, G_, h_, A_, b_)
92 zhats, self.nus, self.lams, self.slacks = pdipm_b.forward(
93 Q, p, G, h, A, b, self.Q_LU, self.S_LU, self.R,
---> 94 self.eps, self.verbose, self.notImprovedLim, self.maxIter)
95 elif self.solver == QPSolvers.CVXPY:
96 vals = torch.Tensor(nBatch).type_as(Q)
___/lib/python3.6/site-packages/qpth/solvers/pdipm/batch.py in forward(Q, p, G, h, A, b, Q_LU, S_LU, R, eps, verbose, notImprovedLim, maxIter, solver)
176 rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
177 rz = torch.zeros(nBatch, nineq).type_as(Q)
--> 178 ry = torch.zeros(nBatch, neq).type_as(Q)
179
180 if solver == KKTSolvers.LU_FULL:
RuntimeError: sizes must be non-negative
Suggestions?
Brandon Amos commented
Thanks, I've fixed this and uploaded a new version to pypi.
Xin Wang commented
Awesome! Same fix for the CVXPY solver as well, please:
`qpth/qp.py`, line 99: `nus = torch.Tensor(nBatch, self.neq).type_as(Q)`
Brandon Amos commented
Fixed there too and uploaded to PyPI (but I didn't test it, so let me know if I need to change anything else).
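For anyone who hits the same error on an older install: the trace fails while allocating `torch.zeros(nBatch, neq)`, which suggests that `neq` was not a valid size when `A` was an empty tensor. That is an inference from the error message, since the thread does not show the actual patch. The sketch below is not the fix that was uploaded. It only illustrates one defensive way to derive `neq` so that any empty `A` counts as zero equality constraints. The helper name `count_equality_constraints` is hypothetical, and it works on plain shape tuples (for example `tuple(A.shape)`) so that it runs without torch installed.

```python
def count_equality_constraints(a_shape):
    """Return neq for an equality-constraint matrix A with the given shape.

    Hypothetical helper, not part of qpth. Assumes A is either
    (neq, nz) or batched as (nBatch, neq, nz). Any shape holding
    zero elements, or no dimensions at all, is treated as
    "no equality constraints".
    """
    # Count the elements the shape describes.
    numel = 1
    for dim in a_shape:
        numel *= dim

    # Empty placeholder such as (), (0,) or (nBatch, 0, nz).
    if len(a_shape) == 0 or numel == 0:
        return 0

    if len(a_shape) == 2:   # unbatched: (neq, nz)
        return a_shape[0]
    if len(a_shape) == 3:   # batched: (nBatch, neq, nz)
        return a_shape[1]

    raise ValueError("expected a 2-D or 3-D shape for A, got %r" % (a_shape,))


if __name__ == "__main__":
    # Shapes an empty placeholder might have, then two regular cases.
    for shape in [(), (0,), (1, 0, 16), (3, 16), (5, 3, 16)]:
        print(shape, "->", count_equality_constraints(shape))
```

With a value computed this way, an allocation such as `torch.zeros(nBatch, neq)` would always receive a non-negative size. Whether a zero-sized second dimension is accepted there depends on the installed pytorch version, so treat this as a starting point rather than a verified workaround.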