from torch.utils.data import DataLoader
from torchensemble import ensemble_method  # import an ensemble method (e.g., VotingClassifier)

# base_estimator, learning_rate, weight_decay and epochs are placeholders you define yourself

# Define the ensemble
model = ensemble_method(
    estimator=base_estimator,  # class of your base estimator
    n_estimators=10,           # the number of base estimators
    cuda=True,                 # whether to use GPU
)

# Load data
train_loader = DataLoader(...)
test_loader = DataLoader(...)

# Set the optimizer
model.set_optimizer(
    "Adam",                     # type of the parameter optimizer
    lr=learning_rate,           # learning rate of the parameter optimizer
    weight_decay=weight_decay,  # weight decay of the parameter optimizer
)

# Train
model.fit(
    train_loader,   # training data
    epochs=epochs,  # the number of training epochs
)

# Evaluate
accuracy = model.predict(test_loader)
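The template above leaves the aggregation step inside the ensemble implicit. As a rough illustration of what a voting ensemble does at prediction time, the pure-Python sketch below averages the softmax probabilities of several base estimators (soft voting) and returns the most probable class. This is an assumption about the general technique, not torchensemble's code; the function names and logits are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one sample's class scores."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def soft_vote(per_estimator_logits: Sequence[Sequence[float]]) -> List[float]:
    """Average the class probabilities produced by each base estimator (hypothetical helper)."""
    probas = [softmax(logits) for logits in per_estimator_logits]
    n_classes = len(probas[0])
    return [sum(p[c] for p in probas) / len(probas) for c in range(n_classes)]


def predict_class(per_estimator_logits: Sequence[Sequence[float]]) -> int:
    """Index of the class with the highest averaged probability."""
    proba = soft_vote(per_estimator_logits)
    return max(range(len(proba)), key=proba.__getitem__)


# Hypothetical raw scores from three base estimators for one sample, 3 classes.
logits = [
    [2.0, 1.0, 0.1],
    [0.5, 2.5, 0.2],
    [1.8, 1.2, 0.3],
]
print(soft_vote(logits))
print(predict_class(logits))
```

In this toy input two of the three estimators individually favor class 0, yet the averaged probabilities favor class 1, which shows how soft voting can disagree with a simple majority vote.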
Benchmarks
Classification on CIFAR-10
The table below presents the classification accuracy of different ensemble classifiers on the CIFAR-10 test set.
Each classifier uses 10 LeNet-5 models (with ReLU activation and Dropout) as base estimators.
Each base estimator is trained for 100 epochs with a batch size of 128, a learning rate of 1e-3, and a weight decay of 5e-4.
Experiment results can be reproduced by running ./examples/classification_cifar10_cnn.py.
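| Model Name                 | Params (MB) | Testing Acc (%) | Improvement (%) |
|----------------------------|-------------|-----------------|-----------------|
| Single LeNet-5             | 0.32        | 73.04           | 0               |
| FusionClassifier           | 3.17        | 78.75           | +5.71           |
| VotingClassifier           | 3.17        | 80.08           | +7.04           |
| BaggingClassifier          | 3.17        | 78.75           | +5.71           |
| GradientBoostingClassifier | 3.17        | 80.82           | +7.78           |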
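The Improvement column can be recomputed from the table's own numbers: each value equals the ensemble's testing accuracy minus the single LeNet-5's 73.04, so it is an absolute gain in percentage points rather than a relative percentage. A minimal check, with the accuracies copied from the table:

```python
# Testing accuracy (%) copied from the CIFAR-10 table.
baseline = 73.04  # single LeNet-5
ensembles = {
    "FusionClassifier": 78.75,
    "VotingClassifier": 80.08,
    "BaggingClassifier": 78.75,
    "GradientBoostingClassifier": 80.82,
}


def improvement(acc: float, base: float = baseline) -> float:
    """Absolute gain over the single model, in percentage points."""
    return round(acc - base, 2)


for name, acc in ensembles.items():
    print(f"{name}: {improvement(acc):+.2f}")
```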
Regression on YearPredictionMSD
The table below presents the mean squared error (MSE) of different ensemble regressors on the YearPredictionMSD test set.
Each regressor uses 10 multilayer perceptron (MLP) models (with ReLU activation and Dropout) as base estimators, and the network architecture is fixed as Input-128-128-Output.
Each base estimator is trained for 50 epochs with a batch size of 256, a learning rate of 1e-3, and a weight decay of 5e-4.
Experiment results can be reproduced by running ./examples/regression_YearPredictionMSD_mlp.py.
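The parameter size of the Input-128-128-Output MLP can be estimated by hand. This sketch assumes 90 input features (the YearPredictionMSD attribute count), a single regression output, a bias in every linear layer, and sizes measured as float32 (4 bytes) in MiB (2**20 bytes). How the Params (MB) column was actually computed is not stated in the source, so treat the interpretation as an assumption; under it the counts come out consistent with the table's 0.11 and 1.08.

```python
from typing import Sequence


def mlp_param_count(layer_sizes: Sequence[int]) -> int:
    """Weights plus biases of a fully connected network; ReLU and Dropout add no parameters."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


# Assumption: 90 input features, two hidden layers of 128, one output.
sizes = [90, 128, 128, 1]
n_params = mlp_param_count(sizes)
mib_single = n_params * 4 / 2**20  # float32 = 4 bytes, 1 MiB = 2**20 bytes

print(n_params)                   # 28289
print(round(mib_single, 2))       # 0.11
print(round(10 * mib_single, 2))  # 1.08 for 10 base estimators
```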
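| Model Name                | Params (MB) | Testing MSE | Improvement |
|---------------------------|-------------|-------------|-------------|
| Single MLP                | 0.11        | 0.83        |             |
| FusionRegressor           | 1.08        | 0.73        | -0.10       |
| VotingRegressor           | 1.08        | 0.69        | -0.14       |
| BaggingRegressor          | 1.08        | 0.70        | -0.13       |
| GradientBoostingRegressor | 1.08        | 0.71        | -0.12       |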
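For reference, the metric and the sign convention of the regression Improvement column can be sketched in plain Python: MSE is the average squared residual, and each Improvement value equals the ensemble's MSE minus the single MLP's 0.83, so negative numbers mean lower error. The toy inputs to `mse` are hypothetical, and this is not the code used by the example script.

```python
from typing import Sequence


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error: the average of squared residuals."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Toy example: residuals are 0.5, 0.0 and 1.0, so MSE = (0.25 + 0 + 1) / 3.
print(mse([1.0, 2.0, 3.0], [1.5, 2.0, 2.0]))

# Testing MSE copied from the YearPredictionMSD table.
baseline_mse = 0.83  # single MLP
ensemble_mse = {
    "FusionRegressor": 0.73,
    "VotingRegressor": 0.69,
    "BaggingRegressor": 0.70,
    "GradientBoostingRegressor": 0.71,
}
for name, value in ensemble_mse.items():
    print(f"{name}: {value - baseline_mse:+.2f}")  # negative = lower error than baseline
```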
Package dependencies
joblib>=0.11
scikit-learn>=0.23.0
torch>=0.4.1
torchvision>=0.2.2
About
Implementation of ensemble methods in PyTorch to boost the performance of your model.