Implementation of the proximal operator
AnchorBlues opened this issue
In the source code, the proximal operator is defined as follows:
```python
def _prox(self, beta, thresh):
    ...  # body not shown in the issue
```
The second argument is the threshold of the proximal operator.
As shown on this site, the threshold should be the product of the learning rate and the L1 regularization parameter.
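For reference, here is a minimal sketch of the soft-thresholding operator (the proximal operator of the L1 norm) in plain Python. The name `soft_threshold` is hypothetical and this is not pyglmnet's actual `_prox` body; it only illustrates what the `thresh` argument controls.

```python
def soft_threshold(beta, thresh):
    """Proximal operator of thresh * ||beta||_1, applied elementwise.

    Hypothetical helper for illustration; not pyglmnet's _prox.
    """
    out = []
    for b in beta:
        if b > thresh:
            out.append(b - thresh)  # shrink positive entries toward zero
        elif b < -thresh:
            out.append(b + thresh)  # shrink negative entries toward zero
        else:
            out.append(0.0)  # entries in [-thresh, thresh] become exactly zero
    return out


# With learning_rate = 0.2 and reg_lambda * alpha = 0.1, the threshold would be 0.2 * 0.1.
print(soft_threshold([0.5, -0.01, -0.3], 0.2 * 0.1))
```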
However, this method is called with only the L1 regularization parameter (`reg_lambda * alpha`) as its second argument:
```python
# Apply proximal operator
if self.fit_intercept:
    beta[1:] = self._prox(beta[1:], reg_lambda * alpha)
else:
    beta = self._prox(beta, reg_lambda * alpha)
```
I think `reg_lambda * alpha` should be replaced with `learning_rate * reg_lambda * alpha`. Otherwise, a model with L1 regularization will not be trained correctly when `learning_rate` is not `1`.
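A minimal sketch of the proposed correction as a standalone function, assuming a flat list of coefficients whose first entry is the unpenalized intercept. The names `proximal_step` and `soft_threshold` are hypothetical; this mirrors the call site above but is not pyglmnet's code.

```python
def soft_threshold(beta, thresh):
    # Elementwise sign(b) * max(|b| - thresh, 0)
    return [max(abs(b) - thresh, 0.0) * (1.0 if b > 0 else -1.0) for b in beta]


def proximal_step(beta, learning_rate, reg_lambda, alpha, fit_intercept=True):
    """Apply the L1 proximal operator after a gradient step (hypothetical helper).

    The threshold is learning_rate * reg_lambda * alpha, i.e. the step size
    times the L1 penalty weight. The intercept (beta[0]) is left unpenalized.
    """
    thresh = learning_rate * reg_lambda * alpha
    if fit_intercept:
        return [beta[0]] + soft_threshold(beta[1:], thresh)
    return soft_threshold(beta, thresh)
```

With `learning_rate=1` the proposed threshold equals `reg_lambda * alpha`, so the change only affects runs with other learning rates.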
In fact, the coefficients trained by `GLM` can be quite different from those of `sklearn` when the learning rate of `GLM` is NOT `1`:
```python
import pyglmnet
import sklearn
print(pyglmnet.__version__)  # 1.1
print(sklearn.__version__)  # 0.23.1

# Note: load_boston was removed in scikit-learn 1.2; this snippet targets 0.23.1.
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Lasso

L1_reg = 1e-1
x_train, y_train = load_boston(return_X_y=True)
x_train = StandardScaler().fit_transform(x_train) / 2

# GLM with `learning_rate=0.2`
model_glm_lr02 = pyglmnet.GLM(distr='gaussian', alpha=1.0, reg_lambda=L1_reg, learning_rate=0.2, max_iter=10000)
model_glm_lr02.fit(x_train, y_train)

# GLM with `learning_rate=1`
model_glm_lr1 = pyglmnet.GLM(distr='gaussian', alpha=1.0, reg_lambda=L1_reg, learning_rate=1, max_iter=10000)
model_glm_lr1.fit(x_train, y_train)

# sklearn
model_sklearn = Lasso(alpha=L1_reg)
model_sklearn.fit(x_train, y_train)

print(model_glm_lr02.beta_)
# [ 0. -0. 0. -0. 0. 5.42708383
# 0. 0. 0. 0. -2.68700433 0.36198452
# -7.0861965 ]
print(model_glm_lr1.beta_)
# [-0.67815431 0.75967646 -0.04907772 1.2337921 -2.15730743 5.92672067
# 0. -3.46014723 0.01498839 0. -3.5466964 1.34164976
# -7.43306087]
print(model_sklearn.coef_)
# [-0.67808937 0.75891847 -0.048276 1.23386678 -2.15994748 5.92750325
# -0. -3.46026391 0.0155437 -0. -3.54744689 1.341734
# -7.43185641]
```
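One way to see why the `learning_rate=0.2` fit is so much sparser uses standard proximal-gradient reasoning (not stated in the thread). Let g be the smooth, unpenalized loss, t the learning rate, λ the L1 weight, S the soft-thresholding operator, and ∂ the subdifferential. If the threshold is λ instead of tλ, any fixed point of the iteration satisfies the optimality condition of a Lasso problem whose penalty is λ/t:

```latex
\beta^\ast = S_{\lambda}\!\left(\beta^\ast - t\,\nabla g(\beta^\ast)\right)
\iff 0 \in t\,\nabla g(\beta^\ast) + \lambda\,\partial\lVert\beta^\ast\rVert_1
\iff 0 \in \nabla g(\beta^\ast) + \frac{\lambda}{t}\,\partial\lVert\beta^\ast\rVert_1
```

With t = 0.2 and λ = `L1_reg` = 0.1, the iteration therefore targets an effective L1 penalty of 0.5, five times the intended value. With t = 1 the two coincide, which is consistent with the `learning_rate=1` fit matching scikit-learn.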
@AnchorBlues it's been a while since I dug into the math. Could you share a screenshot where it shows that the learning rate matters in the link you shared? Let me also solicit opinions from @pavanramkumar and @titipata.
I think what is happening is that `scikit-learn` uses an automatically determined learning rate for 'gaussian' that comes from the maximum eigenvalue of X^T X. Can you verify that with learning_rate=0.2, how does our convergence plot look? I suspect the convergence is too slow.
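For ISTA on the Gaussian loss, assumed here to be g(β) = ||y − Xβ||² / (2n), the usual fixed step size is 1/L with L = λ_max(XᵀX)/n, the Lipschitz constant of ∇g. Below is a minimal pure-Python sketch that estimates λ_max by power iteration; the name `ista_step_size` is hypothetical. Note that scikit-learn's `Lasso` is fit by coordinate descent, so this illustrates the step-size rule for proximal-gradient solvers in general rather than scikit-learn's internals.

```python
def ista_step_size(X, n_iter=200):
    """Return 1 / L with L = lambda_max(X^T X) / n (hypothetical helper).

    L is the Lipschitz constant of the gradient of ||y - X b||^2 / (2n);
    lambda_max is estimated with power iteration on the Gram matrix.
    """
    n, p = len(X), len(X[0])
    # Gram matrix G = X^T X (p x p)
    G = [[sum(X[i][a] * X[i][b] for i in range(n)) for b in range(p)] for a in range(p)]
    v = [1.0] * p
    eig = 0.0
    for _ in range(n_iter):
        w = [sum(G[a][b] * v[b] for b in range(p)) for a in range(p)]
        norm = sum(x * x for x in w) ** 0.5
        if norm == 0.0:
            return float("inf")  # X is all zeros, so the gradient is constant
        v = [x / norm for x in w]
        eig = norm  # ||G v|| approaches lambda_max as the unit vector v converges
    return 1.0 / (eig / n)


print(ista_step_size([[3.0, 0.0], [0.0, 1.0]]))
```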
> Could you share a screenshot where it shows that the learning rate matters in the link you shared?
The following site may be easier to understand.
http://www.stat.cmu.edu/~ryantibs/convexopt-S15/scribes/08-prox-grad-scribed.pdf
The relevant section is 8.1.3 (Iterative soft-thresholding algorithm (ISTA)). Here, `t` is the learning rate and `lambda` is the L1 regularization parameter.
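For reference, one ISTA iteration written in standard notation (not copied from the notes), with g the smooth least-squares part of the objective and λ playing the role of pyglmnet's `reg_lambda * alpha`:

```latex
\beta^{(k+1)} = S_{t\lambda}\!\left(\beta^{(k)} - t\,\nabla g\big(\beta^{(k)}\big)\right),
\qquad
\big[S_{\theta}(z)\big]_j = \operatorname{sign}(z_j)\,\max\!\big(\lvert z_j\rvert - \theta,\, 0\big)
```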
As formulas (8.11) and (8.12) show, the Lasso parameters are optimized as follows (the same method applies to the Elastic-Net):
- Update the parameters by gradient descent, ignoring the L1 regularization term.
- Feed the updated parameters to the soft-thresholding operator.
  - The threshold of the soft-thresholding operator is the product of the learning rate and the L1 regularization parameter.
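The two steps above can be sketched as a toy ISTA solver in plain Python, assuming the objective ||y − Xb||² / (2n) + lam·||b||₁. The function `ista_lasso` and its `scale_threshold` flag are hypothetical: `scale_threshold=False` mimics thresholding at `lam` alone, as in the reported call site. The design matrix is chosen so that XᵀX/n is the identity, which makes the answers easy to check by hand.

```python
def soft_threshold(z, thresh):
    # Elementwise sign(z_j) * max(|z_j| - thresh, 0)
    return [max(abs(v) - thresh, 0.0) * (1.0 if v > 0 else -1.0) for v in z]


def ista_lasso(X, y, lam, t, n_iter=2000, scale_threshold=True):
    """Minimize ||y - X b||^2 / (2n) + lam * ||b||_1 with ISTA (hypothetical helper)."""
    n, p = len(X), len(X[0])
    b = [0.0] * p
    thresh = t * lam if scale_threshold else lam
    for _ in range(n_iter):
        # Residual r = y - X b
        r = [y[i] - sum(X[i][j] * b[j] for j in range(p)) for i in range(n)]
        # Gradient of the smooth part: -(1/n) X^T r
        grad = [-sum(X[i][j] * r[i] for i in range(n)) / n for j in range(p)]
        # Step 1: gradient step ignoring the L1 term
        z = [b[j] - t * grad[j] for j in range(p)]
        # Step 2: soft-thresholding
        b = soft_threshold(z, thresh)
    return b


X = [[1.0, 1.0], [1.0, -1.0]]  # X^T X / n is the identity
y = [3.0, 1.0]                 # X^T y / n = [2, 1]
print(ista_lasso(X, y, lam=0.1, t=0.2))                         # approx [1.9, 0.9]
print(ista_lasso(X, y, lam=0.1, t=0.2, scale_threshold=False))  # approx [1.5, 0.5]
```

With threshold `t * lam` the result is the Lasso solution [2, 1] shrunk by 0.1, independent of `t`; with threshold `lam` alone it is shrunk by `lam / t` = 0.5.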
> Can you verify that with learning_rate=0.2, how does our convergence plot look?
I updated `pyglmnet` to `1.2.dev0` to use the `plot_convergence` method (the trained results did not change) and plotted the convergence for `learning_rate=0.2`.
Converged in 443 iterations.
Hmm... indeed, you are right. Have you verified that the fix solves the comparison problem with `scikit-learn`? Would be great if you can make a pull request! Thank you so much.
> Have you verified that the fix solves the comparison problem with scikit-learn?
Yes, I have verified it.
> Would be great if you can make a pull request!
I made a pull request. Please review it.
Thanks @AnchorBlues for the careful verification and fix!
closed by #384