locuslab / qpth

A fast and differentiable QP solver for PyTorch.

Home Page: https://locuslab.github.io/qpth/

Cases where `neq==0` with pytorch 0.4.1

caseus-viridis opened this issue

qpth needs a fix for the case without equality constraints (`neq == 0`) under PyTorch 0.4.1.
For example:

import torch
from qpth.qp import QPFunction, QPSolvers

# QP over x in R^N: minimize 0.5 * x^T (L L^T) x + (L v)^T x
# subject to x <= 0, with no equality constraints (empty A and b).
N, D = 4, 16
v = torch.randn(D)
L = torch.randn(N, D)
x = QPFunction(solver=QPSolvers.PDIPM_BATCHED)(
    torch.matmul(L, L.t()),  # Q
    torch.matmul(v, L.t()),  # p
    torch.eye(N),            # G
    torch.zeros(N),          # h
    torch.Tensor(),          # A (empty: no equality constraints)
    torch.Tensor()           # b (empty)
)

Stack trace:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-c5f8050018cf> in <module>()
     11     torch.zeros(N),
     12     torch.Tensor(),
---> 13     torch.Tensor()
     14 )

___/lib/python3.6/site-packages/qpth/qp.py in forward(self, Q_, p_, G_, h_, A_, b_)
     92             zhats, self.nus, self.lams, self.slacks = pdipm_b.forward(
     93                 Q, p, G, h, A, b, self.Q_LU, self.S_LU, self.R,
---> 94                 self.eps, self.verbose, self.notImprovedLim, self.maxIter)
     95         elif self.solver == QPSolvers.CVXPY:
     96             vals = torch.Tensor(nBatch).type_as(Q)

___/lib/python3.6/site-packages/qpth/solvers/pdipm/batch.py in forward(Q, p, G, h, A, b, Q_LU, S_LU, R, eps, verbose, notImprovedLim, maxIter, solver)
    176         rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
    177         rz = torch.zeros(nBatch, nineq).type_as(Q)
--> 178         ry = torch.zeros(nBatch, neq).type_as(Q)
    179
    180         if solver == KKTSolvers.LU_FULL:

RuntimeError: sizes must be non-negative

The failure is in allocating the equality-constraint residual (`ry = torch.zeros(nBatch, neq)`) when `A` and `b` are empty. Suggestions?

Thanks, I've fixed this and uploaded a new version to PyPI.
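
For anyone pinned to an older release, here is a minimal sketch of the kind of guard that avoids the negative size, assuming `neq` is derived from the (possibly empty) `A` tensor; it is illustrative only, not necessarily the change that shipped:

import torch

def eq_residual_placeholder(nBatch, A, Q):
    # Illustrative sketch, not the actual qpth patch: treat an empty A as
    # neq == 0 and allocate an explicitly zero-width residual so that
    # torch.zeros never sees a negative size.
    neq = A.size(1) if A.dim() == 3 else 0
    return torch.zeros(nBatch, max(neq, 0)).type_as(Q)

# With no equality constraints the placeholder has shape (nBatch, 0).
print(eq_residual_placeholder(1, torch.Tensor(), torch.eye(4)).shape)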

Awesome! The same fix is needed for the CVXPY solver path as well, please:
qpth/qp.py, line 99: `nus = torch.Tensor(nBatch, self.neq).type_as(Q)`

Fixed there too and uploaded to PyPI (but I didn't test it, so let me know if anything else needs to change).
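
For anyone landing here later: after upgrading, a quick sanity check is to re-run the repro above against both backends (the CVXPY one assumes cvxpy is installed). This is just a sketch reusing the dimensions from the original snippet, and a clean run only confirms that the `neq == 0` call no longer raises:

import torch
from qpth.qp import QPFunction, QPSolvers

# Re-run the original repro against both solver backends.
torch.manual_seed(0)
N, D = 4, 16
v = torch.randn(D)
L = torch.randn(N, D)
Q = torch.matmul(L, L.t())
p = torch.matmul(v, L.t())
G, h = torch.eye(N), torch.zeros(N)
e = torch.Tensor()  # empty A and b: no equality constraints

for solver in (QPSolvers.PDIPM_BATCHED, QPSolvers.CVXPY):
    x = QPFunction(solver=solver)(Q, p, G, h, e, e)
    print(solver, tuple(x.shape))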