mlr3superlearner
A modern implementation of the Super Learner prediction algorithm using the mlr3 framework, adhering to the recommendations of Phillips, van der Laan, Lee, and Gruber (2023).
Installation
You can install the development version of mlr3superlearner from GitHub with:
# install.packages("devtools")
devtools::install_github("nt-williams/mlr3superlearner")
Example
library(mlr3superlearner)
#> Loading required package: mlr3learners
#> Loading required package: mlr3
library(mlr3extralearners)
# No hyperparameters
fit <- mlr3superlearner(mtcars, "mpg", c("mean", "glm", "svm", "ranger"), "continuous")
#> ℹ Setting cross-validation folds as 20
# With hyperparameters
fit <- mlr3superlearner(mtcars, "mpg",
list("mean", "glm", "xgboost", "svm", "earth",
list("nnet", trace = FALSE),
list("ranger", num.trees = 500, id = "ranger1"),
list("ranger", num.trees = 1000, id = "ranger2")),
"continuous")
#> ℹ Setting cross-validation folds as 20
fit
#> ══ `mlr3superlearner()` ════════════════════════════════════════════════════════
#> Risk Coefficients
#> regr.earth 10.578953 0
#> regr.glm 11.543067 0
#> regr.mean 38.831338 0
#> regr.nnet_and_trace_FALSE 37.493411 0
#> regr.ranger1 5.725750 0
#> regr.ranger2 5.641004 1
#> regr.svm 11.535177 0
#> regr.xgboost 226.947489 0
head(data.frame(pred = predict(fit, mtcars), truth = mtcars$mpg))
#> pred truth
#> 1 20.78731 21.0
#> 2 20.75069 21.0
#> 3 24.09735 22.8
#> 4 20.29066 21.4
#> 5 17.64259 18.7
#> 6 19.01303 18.1
Available learners
knitr::kable(available_learners("binomial"))
learner | mlr3_learner | mlr3_package | learner_package |
---|---|---|---|
mean | classif.featureless | mlr3 | stats |
glm | classif.log_reg | mlr3learners | stats |
glmnet | classif.glmnet | mlr3learners | glmnet |
cv_glmnet | classif.cv_glmnet | mlr3learners | glmnet |
knn | classif.kknn | mlr3learners | kknn |
nnet | classif.nnet | mlr3learners | nnet |
lda | classif.lda | mlr3learners | MASS |
naivebayes | classif.naive_bayes | mlr3learners | e1071 |
qda | classif.qda | mlr3learners | MASS |
ranger | classif.ranger | mlr3learners | ranger |
svm | classif.svm | mlr3learners | e1071 |
xgboost | classif.xgboost | mlr3learners | xgboost |
earth | classif.earth | mlr3extralearners | earth |
lightgbm | classif.lightgbm | mlr3extralearners | lightgbm |
randomforest | classif.randomForest | mlr3extralearners | randomForest |
bart | classif.bart | mlr3extralearners | dbarts |
c50 | classif.C50 | mlr3extralearners | C50 |
gam | classif.gam | mlr3extralearners | mgcv |
gaussianprocess | classif.gausspr | mlr3extralearners | kernlab |
glmboost | classif.glmboost | mlr3extralearners | mboost |
nloptr | classif.avg | mlr3pipelines | nloptr |
knitr::kable(available_learners("continuous"))
learner | mlr3_learner | mlr3_package | learner_package |
---|---|---|---|
mean | regr.featureless | mlr3 | stats |
glm | regr.lm | mlr3learners | stats |
glmnet | regr.glmnet | mlr3learners | glmnet |
cv_glmnet | regr.cv_glmnet | mlr3learners | glmnet |
knn | regr.kknn | mlr3learners | kknn |
nnet | regr.nnet | mlr3learners | nnet |
ranger | regr.ranger | mlr3learners | ranger |
svm | regr.svm | mlr3learners | e1071 |
xgboost | regr.xgboost | mlr3learners | xgboost |
earth | regr.earth | mlr3extralearners | earth |
lightgbm | regr.lightgbm | mlr3extralearners | lightgbm |
randomforest | regr.randomForest | mlr3extralearners | randomForest |
bart | regr.bart | mlr3extralearners | dbarts |
gam | regr.gam | mlr3extralearners | mgcv |
gaussianprocess | regr.gausspr | mlr3extralearners | kernlab |
glmboost | regr.glmboost | mlr3extralearners | mboost |