rikhuijzer / SIRUS.jl

Interpretable Machine Learning via Rule Extraction

Home page: https://sirus.jl.huijzer.xyz/

Models do not handle missing classes on subsampling

ablaom opened this issue.

When running the MLJ integration tests:

```julia
using MLJTestIntegration
import SIRUS

X, y = MLJTestIntegration.make_binary()
MLJTestIntegration.test(
    [SIRUS.StableForestClassifier],
    X,
    y,
    mod=@__MODULE__,
    level=4,
    throw=true,
)
```
```
[ Info: Converting outcome classes ["B", "O"] to [0.0, 1.0].
[ Info: Converting outcome classes ["B"] to [0.0].
[ Info: Converting outcome classes ["B", "O"] to [0.0, 1.0].
[ Info: Converting outcome classes ["B", "O"] to [0.0, 1.0].
[ Info: Converting outcome classes ["B", "O"] to [0.0, 1.0].
ERROR: DomainError with Value B not in pool. :

Stacktrace:
 [1] attempt(f::MLJTestIntegration.var"#5#8"{MLJBase.LogLoss{Float64}, SIRUS.MLJImplementation.StableForestClassifier, Vector{ComputationalResources.CPU1{Nothing}}, Tuple{NamedTuple{(:FL, :RW), Tuple{Vector{Float64}, Vector{Float64}}}, CategoricalArrays.CategoricalVector{String, UInt32, String, CategoricalArrays.CategoricalValue{String, UInt32}, Union{}}}}, message::String; throw::Bool)
   @ MLJTestInterface ~/.julia/packages/MLJTestInterface/K1YSy/src/attemptors.jl:17
 [2] evaluation(::MLJBase.LogLoss{Float64}, ::SIRUS.MLJImplementation.StableForestClassifier, ::Vector{ComputationalResources.CPU1{Nothing}}, ::NamedTuple{(:FL, :RW), Tuple{Vector{Float64}, Vector{Float64}}}, ::Vararg{Any}; throw::Bool, verbosity::Int64)
   @ MLJTestIntegration ~/.julia/packages/MLJTestIntegration/J5lEw/src/attemptors.jl:24
 [3] test(::Vector{DataType}, ::NamedTuple{(:FL, :RW), Tuple{Vector{Float64}, Vector{Float64}}}, ::Vararg{Any}; mod::Module, level::Int64, throw::Bool, verbosity::Int64)
   @ MLJTestIntegration ~/.julia/packages/MLJTestIntegration/J5lEw/src/test.jl:312
 [4] top-level scope
   @ REPL[43]:1

caused by: DomainError with Value B not in pool. :
```

The failing test is a cross-validation test in which some training folds are presumably missing one of the two classes present in the full dataset, which is very small (the log above shows one fold converting only ["B"]). Ideally, a model should handle this eventuality. For example, in the binary case tested here, the model could always predict the single class present in the training target.
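A minimal sketch in plain Julia of how this corner case arises and what the suggested fallback could look like. It assumes an unshuffled 3-fold split of a hypothetical six-element target; `kfold_indices` and `constant_probs` are hypothetical helpers, not SIRUS or MLJ functions. The fallback returns probabilities over the full class list, with probability 1 for the only class seen in training.

```julia
# Hypothetical sketch (not SIRUS code): a fallback that stays valid when a
# training fold contains only one class.

# Split indices 1:n into k contiguous folds (simplified, unshuffled k-fold).
function kfold_indices(n::Int, k::Int)
    sizes = fill(div(n, k), k)
    sizes[1:rem(n, k)] .+= 1          # spread the remainder over the first folds
    stops = cumsum(sizes)
    starts = stops .- sizes .+ 1
    return [starts[i]:stops[i] for i in 1:k]
end

# Probabilities over *all* classes when the training target has one class:
# 1.0 for the observed class, 0.0 for every other class.
function constant_probs(ytrain::AbstractVector, all_classes::AbstractVector)
    observed = unique(ytrain)
    length(observed) == 1 || error("expected a single training class")
    return [c == observed[1] ? 1.0 : 0.0 for c in all_classes]
end

y = ["B", "B", "B", "B", "O", "O"]
all_classes = sort(unique(y))         # ["B", "O"], taken from the full target
folds = kfold_indices(length(y), 3)   # [1:2, 3:4, 5:6]

for (i, test_idx) in enumerate(folds)
    train_idx = setdiff(eachindex(y), test_idx)
    ytrain = y[train_idx]
    if length(unique(ytrain)) == 1
        println("fold $i: training target has only ", unique(ytrain),
                "; probs = ", constant_probs(ytrain, all_classes))
    end
end
```

With this split only the third fold trains on a single class, so the loop prints `fold 3: training target has only ["B"]; probs = [1.0, 0.0]`. Keeping the probability vector over the full class list, rather than over the fold's classes, is what lets a log-loss evaluation on a test fold containing "O" still work.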

The StableRulesClassifier has the same issue.

I admit this is a bit of a small-data corner case, but it would be great to address it. I will need to remove the SIRUS classifiers from the MLJTestIntegration tests pending resolution of this issue. All other MLJ classifiers handle this corner case (apart from our ScikitLearn models).

Thanks for opening the issue. It is indeed a weak point, which I haven't figured out how to solve yet.

> Ideally, a model should handle this eventuality.

So yes, for the binary case, I could switch to returning a single class. Do you also know how I can figure out which classes to use across multiple folds? I'll look around in other packages and add another comment here if I find a solution.
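One possible answer, sketched under the assumption that the target arrives as a `CategoricalVector` (as it does in the MLJ test above): take the classes from the pool shared by all folds rather than from the values observed in a particular fold. With CategoricalArrays.jl, indexing a categorical vector keeps the parent's pool, so `levels` still lists classes that are absent from the subset, while `unique` only returns the observed ones.

```julia
using CategoricalArrays

# Full target: its pool records every class ("B" and "O").
y = categorical(["B", "B", "O", "B"])

# A training fold that happens to contain only "B".
ytrain = y[[1, 2, 4]]

observed = unique(ytrain)   # only the classes seen in this fold: "B"
all_classes = levels(ytrain) # every class in the shared pool: ["B", "O"]
```

A model that builds its class list from `levels` (or from the pool directly) therefore sees the same classes in every fold, even when a fold is missing one of them.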

Related to #25.

EDIT: This problem is likely caused by how the UnivariateFinite is constructed in src/forest.jl. It should (re)use the correct pool; MLJXGBoostInterface, or a similar package, could serve as an example.
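A sketch of the direction described in the EDIT, in plain Julia, assuming that per-fold predictions are an n × m probability matrix over the m classes seen in training. `expand_probs` is a hypothetical helper, not part of SIRUS or MLJ. It places each column at the position of its class in the full class list and gives zero probability to classes the fold never saw. The resulting matrix could then be wrapped in a distribution built on the full pool, for instance through the `pool` keyword of MLJ's `UnivariateFinite` constructor (worth checking against the MLJBase documentation).

```julia
# Hypothetical helper (not SIRUS or MLJ code): move probabilities computed for
# the classes seen in a training fold into the column order of the full class
# list, with zero probability for classes absent from the fold.
function expand_probs(fold_classes::AbstractVector, probs::AbstractMatrix,
                      all_classes::AbstractVector)
    size(probs, 2) == length(fold_classes) ||
        throw(ArgumentError("expected one column of probs per fold class"))
    out = zeros(eltype(probs), size(probs, 1), length(all_classes))
    for (j, c) in enumerate(fold_classes)
        k = findfirst(==(c), all_classes)
        k === nothing && throw(ArgumentError("class $c not in full class list"))
        out[:, k] = probs[:, j]   # copy this class's column into its full-pool slot
    end
    return out
end

all_classes = ["B", "O"]
fold_classes = ["B"]      # this training fold only contained "B"
probs = ones(3, 1)        # three test rows, all predicted "B" with probability 1
full = expand_probs(fold_classes, probs, all_classes)
# full == [1.0 0.0; 1.0 0.0; 1.0 0.0]
```

Because every fold's predictions end up indexed by the same full class list, a test row labelled "O" can be scored even when its training fold contained no "O".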