nt-williams / mlr3superlearner

Super learner fitting and prediction using mlr3


mlr3superlearner

Lifecycle: experimental

A modern implementation of the Super Learner prediction algorithm built on the mlr3 framework that follows the recommendations of Phillips, van der Laan, Lee, and Gruber (2023).
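One of those recommendations is to choose the number of cross-validation folds from the effective sample size. The "Setting cross-validation folds as 20" message in the example below is consistent with that rule (mtcars has 32 rows). The sketch below shows the rule as we understand it from the paper. `effective_n()` and `choose_folds()` are hypothetical helpers, not part of the mlr3superlearner API, and the exact cut-offs may differ from the package's internals.

```r
# Hypothetical helpers (not part of the mlr3superlearner API) sketching the
# fold-count rule of Phillips et al. (2023) as we understand it.

# Effective sample size: n for a continuous outcome; for a binary outcome,
# min(n, 5 * size of the rarer class). Assumes both classes are present.
effective_n <- function(y, outcome_type = c("continuous", "binomial")) {
  outcome_type <- match.arg(outcome_type)
  n <- length(y)
  if (outcome_type == "continuous") {
    return(n)
  }
  n_rare <- min(table(y))
  min(n, 5 * n_rare)
}

# Number of cross-validation folds as a function of the effective sample size
choose_folds <- function(n_eff) {
  if (n_eff <= 30) return(n_eff)   # use V = n_eff folds for very small samples
  if (n_eff <= 500) return(20)
  if (n_eff <= 5000) return(10)
  if (n_eff <= 10000) return(5)
  2
}

# mtcars has 32 rows, so a continuous outcome gets 20 folds
choose_folds(effective_n(mtcars$mpg))
```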

Installation

You can install the development version of mlr3superlearner from GitHub with:

# install.packages("devtools")
devtools::install_github("nt-williams/mlr3superlearner")

Example

library(mlr3superlearner)
#> Loading required package: mlr3learners
#> Loading required package: mlr3
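# Provides additional learners such as earth, lightgbm and bart
library(mlr3extralearners)

# Library of learners, each using its default hyperparameters
fit <- mlr3superlearner(mtcars, "mpg", c("mean", "glm", "svm", "ranger"), "continuous")
#> ℹ Setting cross-validation folds as 20

# Learners with user-specified hyperparameters; `id` gives each ranger
# configuration its own name. Without an `id`, a name is generated from the
# hyperparameters (e.g. regr.nnet_and_trace_FALSE below)
fit <- mlr3superlearner(mtcars, "mpg",
                        list("mean", "glm", "xgboost", "svm", "earth",
                             list("nnet", trace = FALSE),
                             list("ranger", num.trees = 500, id = "ranger1"),
                             list("ranger", num.trees = 1000, id = "ranger2")),
                        "continuous")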
#> ℹ Setting cross-validation folds as 20

fit
#> ══ `mlr3superlearner()` ════════════════════════════════════════════════════════
#>                                 Risk Coefficients
#> regr.earth                 10.578953            0
#> regr.glm                   11.543067            0
#> regr.mean                  38.831338            0
#> regr.nnet_and_trace_FALSE  37.493411            0
#> regr.ranger1                5.725750            0
#> regr.ranger2                5.641004            1
#> regr.svm                   11.535177            0
#> regr.xgboost              226.947489            0

head(data.frame(pred = predict(fit, mtcars), truth = mtcars$mpg))
#>       pred truth
#> 1 20.78731  21.0
#> 2 20.75069  21.0
#> 3 24.09735  22.8
#> 4 20.29066  21.4
#> 5 17.64259  18.7
#> 6 19.01303  18.1
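In the fitted object above, every coefficient is either 0 or 1. That pattern is what a discrete super learner produces: all weight goes to the single learner with the lowest cross-validated risk. The base-R sketch below illustrates that idea. It is not the package's implementation. It assumes squared-error loss, so "risk" here means cross-validated mean squared error. It uses only two candidates, named `mean` and `glm` after the package's learner names (`glm` maps to `lm` for a continuous outcome, as in the table below). The variable names are illustrative.

```r
# Minimal discrete super learner in base R (illustrative sketch only).
# Assumes squared-error loss: risk = cross-validated mean squared error.
set.seed(1)
data <- mtcars
V <- 20
folds <- sample(rep(seq_len(V), length.out = nrow(data)))

# Each candidate takes training data and returns a prediction function
learners <- list(
  mean = function(train) {
    mu <- mean(train$mpg)
    function(newdata) rep(mu, nrow(newdata))
  },
  glm = function(train) {
    fit <- lm(mpg ~ ., data = train)
    function(newdata) predict(fit, newdata)
  }
)

# Cross-validated risk of each candidate
cv_risk <- sapply(learners, function(learner) {
  sq_err <- numeric(nrow(data))
  for (v in seq_len(V)) {
    test_idx <- which(folds == v)
    predict_fn <- learner(data[-test_idx, ])
    sq_err[test_idx] <- (data$mpg[test_idx] - predict_fn(data[test_idx, ]))^2
  }
  mean(sq_err)
})

# Discrete super learner: coefficient 1 for the lowest-risk learner, 0 otherwise
sl_coef <- as.numeric(cv_risk == min(cv_risk))
names(sl_coef) <- names(cv_risk)

# Refit the selected learner on the full data and predict
final_fit <- learners[[which.min(cv_risk)]](data)
print(cv_risk)
head(final_fit(data))
```

A weighted ensemble would instead give non-negative weights to several learners.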

Available learners

knitr::kable(available_learners("binomial"))

learner          mlr3_learner          mlr3_package       learner_package
mean             classif.featureless   mlr3               stats
glm              classif.log_reg       mlr3learners       stats
glmnet           classif.glmnet        mlr3learners       glmnet
cv_glmnet        classif.cv_glmnet     mlr3learners       glmnet
knn              classif.kknn          mlr3learners       kknn
nnet             classif.nnet          mlr3learners       nnet
lda              classif.lda           mlr3learners       MASS
naivebayes       classif.naive_bayes   mlr3learners       e1071
qda              classif.qda           mlr3learners       MASS
ranger           classif.ranger        mlr3learners       ranger
svm              classif.svm           mlr3learners       e1071
xgboost          classif.xgboost       mlr3learners       xgboost
earth            classif.earth         mlr3extralearners  earth
lightgbm         classif.lightgbm      mlr3extralearners  lightgbm
randomforest     classif.randomForest  mlr3extralearners  randomForest
bart             classif.bart          mlr3extralearners  dbarts
c50              classif.C50           mlr3extralearners  C50
gam              classif.gam           mlr3extralearners  mgcv
gaussianprocess  classif.gausspr       mlr3extralearners  kernlab
glmboost         classif.glmboost      mlr3extralearners  mboost
nloptr           classif.avg           mlr3pipelines      nloptr

knitr::kable(available_learners("continuous"))

learner          mlr3_learner       mlr3_package       learner_package
mean             regr.featureless   mlr3               stats
glm              regr.lm            mlr3learners       stats
glmnet           regr.glmnet        mlr3learners       glmnet
cv_glmnet        regr.cv_glmnet     mlr3learners       glmnet
knn              regr.kknn          mlr3learners       kknn
nnet             regr.nnet          mlr3learners       nnet
ranger           regr.ranger        mlr3learners       ranger
svm              regr.svm           mlr3learners       e1071
xgboost          regr.xgboost       mlr3learners       xgboost
earth            regr.earth         mlr3extralearners  earth
lightgbm         regr.lightgbm      mlr3extralearners  lightgbm
randomforest     regr.randomForest  mlr3extralearners  randomForest
bart             regr.bart          mlr3extralearners  dbarts
gam              regr.gam           mlr3extralearners  mgcv
gaussianprocess  regr.gausspr       mlr3extralearners  kernlab
glmboost         regr.glmboost      mlr3extralearners  mboost


License: GNU General Public License v3.0
