Error when using MLR3 for LIME
DariusLR opened this issue
Hello,
I am using an mlr3 model and want to explain its predictions with LIME. I have read that a `WrappedModel` (the model object of mlr, the predecessor of mlr3) should be compatible with LIME, as long as you pass separate data tables to the `lime()` and `explain()` functions.
This is what I did:
```r
## Representative sample
library(mlr3)
library(mlr3learners) # provides the regr.ranger learner
library(caTools)      # provides sample.split()
library(lime)

set.seed(123)
# Create regression task
regression_task = TaskRegr$new(id = "Reviews", backend = Yelp_Complete_ML, target = "useful_votes")
print(regression_task)
# Create RF regression learner
Random_forest = mlr_learners$get("regr.ranger")
print(Random_forest)
# Split row ids: 75% for training, the rest for testing
data_train = sample(regression_task$nrow, 0.75 * regression_task$nrow)
data_test = setdiff(seq_len(regression_task$nrow), data_train)
# Train Random Forest ($train() trains the learner in place and returns it,
# so Model is the same R6 learner object as Random_forest)
Model <- Random_forest$train(regression_task, row_ids = data_train)
print(Random_forest$model)
# Test Random Forest
prediction_RF <- Random_forest$predict(regression_task, row_ids = data_test)
print(prediction_RF)
head(as.data.table(prediction_RF))
# Measuring performance
measure_RF = msr("regr.mse")
prediction_RF$score(measure_RF)
# Data frame for LIME
sample <- sample.split(Yelp_Complete_ML, SplitRatio = .75)
data_train2 = subset(Yelp_Complete_ML, sample == TRUE)
data_test2 = subset(Yelp_Complete_ML, sample == FALSE)
# Using LIME for Random Forest
explainer_RF <- lime(data_train2, Model)
explanation <- explain(data_test2, explainer_RF, n_features = 1)
```
When running `explain()`, I got the following error message:

```
Error: The class of model must have a model_type method. See ?model_type to get an overview of models supported out of the box
```
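For reference, one way to see why `model_type()` finds no method is to inspect the class vector of a trained mlr3 learner, since that is what lime dispatches on. This is a minimal sketch, not the original setup: it assumes mlr3 is installed and uses the built-in `mtcars` task and the `regr.rpart` learner as stand-ins for the Yelp data and `regr.ranger`.

```r
library(mlr3)

# Stand-ins for the Yelp setup: a built-in regression task and learner
task <- tsk("mtcars")
learner <- lrn("regr.rpart")
learner$train(task)  # trains in place; the learner itself is the model object

# lime dispatches model_type() on this class vector
print(class(learner))

# An mlr3 learner is an R6 object, not an mlr WrappedModel
print(inherits(learner, "WrappedModel"))
```

lime's documentation lists mlr among its supported model sources, but none of lime's built-in `model_type()` methods appear to target mlr3's R6 classes such as `LearnerRegr`.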
I then wrapped `Model` in `as_regressor()` to mark it as a regression model and ran the following code:

```r
# Using LIME for Random Forest
explainer_RF <- lime(data_train2, as_regressor(Model))
explanation <- explain(data_test2, explainer_RF, n_features = 1)
```
I got the following error:
```
Error in UseMethod("predict") :
  no applicable method for 'predict' applied to an object of class "lime_regressor"
```
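A possible workaround, sketched under assumptions: `as_regressor()` only tags the object as a regressor, and prediction then still goes through R's S3 `predict()` generic, which has no method for an R6 mlr3 learner. lime's `model_type()` and `predict_model()` are S3 generics that can be extended (see `?model_type`), so the sketch below defines hypothetical methods for mlr3's `LearnerRegr` class that forward to the learner's `$predict_newdata()`. The method bodies, the `mtcars` task and the `regr.rpart` learner are illustrative stand-ins, not part of lime, mlr3 or the original code, and only the feature columns are passed to `lime()` so the target is not explained as a feature.

```r
library(mlr3)
library(lime)

# Hypothetical glue method (not shipped by lime or mlr3):
# tell lime that any mlr3 regression learner is a regression model
model_type.LearnerRegr <- function(x, ...) {
  "regression"
}

# Hypothetical glue method: route lime's predict_model() to the R6 method
# $predict_newdata(); for regression lime expects a one-column data.frame
predict_model.LearnerRegr <- function(x, newdata, type, ...) {
  pred <- x$predict_newdata(newdata)
  data.frame(Response = pred$response)
}

# Stand-in data: built-in mtcars task instead of the Yelp data
task <- tsk("mtcars")
learner <- lrn("regr.rpart")
learner$train(task)

# Feature columns only, without the target
features <- as.data.frame(task$data(cols = task$feature_names))

# No as_regressor() needed once the methods above exist
explainer <- lime(features, learner)
explanation <- explain(features[1:2, ], explainer, n_features = 2)
print(explanation)
```

With this approach the learner is passed to `lime()` directly; for the original code, the same two methods would apply to the trained `regr.ranger` learner, since it also inherits from `LearnerRegr`.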
Could you take a look at this and help me out?
Kind regards,
Darius