jpmml / jpmml-lightgbm

Java library and command-line application for converting LightGBM models to PMML

support cross_entropy as objective for lightgbm

panlanfeng opened this issue · comments

Hello,

Thanks for the great work on this project.
I was wondering whether support for the cross_entropy objective is on your roadmap.
I have a use case where I need to use numeric probability labels in [0, 1]. When I convert such a model, I get the following error message. Could you take a look? Thanks!

Jun 30, 2021 3:56:41 AM org.jpmml.lightgbm.Main run
INFO: Loading GBDT..
Jun 30, 2021 3:56:41 AM org.jpmml.lightgbm.Main run
SEVERE: Failed to load GBDT
java.lang.IllegalArgumentException: cross_entropy
        at org.jpmml.lightgbm.GBDT.loadObjectiveFunction(GBDT.java:529)
        at org.jpmml.lightgbm.GBDT.load(GBDT.java:103)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:51)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:43)
        at org.jpmml.lightgbm.Main.run(Main.java:137)
        at org.jpmml.lightgbm.Main.main(Main.java:127)

Exception in thread "main" java.lang.IllegalArgumentException: cross_entropy
        at org.jpmml.lightgbm.GBDT.loadObjectiveFunction(GBDT.java:529)
        at org.jpmml.lightgbm.GBDT.load(GBDT.java:103)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:51)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:43)
        at org.jpmml.lightgbm.Main.run(Main.java:137)
        at org.jpmml.lightgbm.Main.main(Main.java:127)

Does the cross_entropy objective function (aka xentropy) also use the sigmoid function for calculating probabilities?

If it does, then try simply inserting cross_entropy here:
https://github.com/jpmml/jpmml-lightgbm/blob/1.3.9/src/main/java/org/jpmml/lightgbm/GBDT.java#L523

Something like this:

switch(objective){
  // BinaryLogloss
  case "binary":
  case "cross_entropy":
    return new BinomialLogisticRegression(average_output, config.getDouble("sigmoid"));
}

If you rebuild the project and redo the conversion, does the PMML model make correct predictions?
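One way to answer that question is to score the same rows with LightGBM and with the converted PMML model, then compare the two sets of predictions within a small tolerance. The following is a minimal sketch of such a comparison. The class `PredictionCheckSketch` and its methods are hypothetical helpers, not part of JPMML-LightGBM, and how the two prediction arrays are obtained is left open.

```java
// Sketch only: a hypothetical helper, not part of JPMML-LightGBM.
final class PredictionCheckSketch {

    // Returns the largest absolute difference between two prediction arrays.
    static double maxAbsDiff(double[] expected, double[] actual) {
        if (expected.length != actual.length) {
            throw new IllegalArgumentException("Prediction arrays differ in length");
        }
        double max = 0.0;
        for (int i = 0; i < expected.length; i++) {
            max = Math.max(max, Math.abs(expected[i] - actual[i]));
        }
        return max;
    }

    // True if every pair of predictions agrees within the given tolerance.
    static boolean matches(double[] expected, double[] actual, double tolerance) {
        return maxAbsDiff(expected, actual) <= tolerance;
    }
}
```

Here `expected` would hold the probabilities predicted by LightGBM and `actual` those produced by the PMML model. The tolerance is a judgment call that depends on the floating-point precision used on each side.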

It looks like cross_entropy does not use the sigmoid parameter. I made the change you suggested and got the following error during conversion:

Jul 01, 2021 12:41:01 AM org.jpmml.lightgbm.Main run
INFO: Loading GBDT..
Jul 01, 2021 12:41:01 AM org.jpmml.lightgbm.Main run
SEVERE: Failed to load GBDT
java.lang.IllegalArgumentException: sigmoid
        at org.jpmml.lightgbm.Section.get(Section.java:106)
        at org.jpmml.lightgbm.Section.get(Section.java:100)
        at org.jpmml.lightgbm.Section.getDouble(Section.java:74)
        at org.jpmml.lightgbm.GBDT.loadObjectiveFunction(GBDT.java:525)
        at org.jpmml.lightgbm.GBDT.load(GBDT.java:103)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:51)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:43)
        at org.jpmml.lightgbm.Main.run(Main.java:137)
        at org.jpmml.lightgbm.Main.main(Main.java:127)

Exception in thread "main" java.lang.IllegalArgumentException: sigmoid
        at org.jpmml.lightgbm.Section.get(Section.java:106)
        at org.jpmml.lightgbm.Section.get(Section.java:100)
        at org.jpmml.lightgbm.Section.getDouble(Section.java:74)
        at org.jpmml.lightgbm.GBDT.loadObjectiveFunction(GBDT.java:525)
        at org.jpmml.lightgbm.GBDT.load(GBDT.java:103)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:51)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:43)
        at org.jpmml.lightgbm.Main.run(Main.java:137)
        at org.jpmml.lightgbm.Main.main(Main.java:127)

According to this line, cross_entropy computes the probability directly instead of calling a sigmoid function, and it does not take a sigmoid parameter the way binary classification does.
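A minimal sketch of the difference, assuming the binary objective computes 1 / (1 + exp(-sigmoid * x)) on the raw score x, while cross_entropy computes 1 / (1 + exp(-x)) with no configurable scale. The class and method names are made up for illustration and do not exist in LightGBM or JPMML-LightGBM.

```java
// Sketch only: class and method names are hypothetical.
final class LinkFunctionSketch {

    // "binary" objective: the raw score is scaled by the model's sigmoid parameter.
    static double binaryProbability(double rawScore, double sigmoid) {
        return 1.0 / (1.0 + Math.exp(-sigmoid * rawScore));
    }

    // "cross_entropy" objective: same logistic transform, with the scale fixed at 1.
    static double crossEntropyProbability(double rawScore) {
        return 1.0 / (1.0 + Math.exp(-rawScore));
    }

    public static void main(String[] args) {
        double rawScore = 0.75; // arbitrary illustrative value
        // With sigmoid = 1.0 the two transforms coincide.
        System.out.println(binaryProbability(rawScore, 1.0) == crossEntropyProbability(rawScore)); // prints "true"
    }
}
```

Under this assumption, passing a scale of 1.0 to the existing logistic transform reproduces the cross_entropy output, so the model file does not need a sigmoid entry.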

I was able to make it generate the correct score after making the following change to
https://github.com/jpmml/jpmml-lightgbm/blob/1.3.9/src/main/java/org/jpmml/lightgbm/GBDT.java#L523

  case "cross_entropy":
    return new BinomialLogisticRegression(average_output, 1.0);

I can make a CR for this change if it looks OK to you.

> return new BinomialLogisticRegression(average_output, 1.0);

Yes, that appears to be the solution. There is no need for an explicit sigmoid parameter, because the coefficient is hard-coded as 1.

> I can make a CR for this change if it looks OK to you.

Not needed - I'll do a proper cross_entropy support with test cases for the next release myself.

In the meantime, you can keep using your patched codebase.

Thanks!
I was also wondering whether this cross_entropy support could be added to the older 1.2.* versions as well.
I ask because our team is still using version 1.2.*.
It is OK if there is no such plan.

> I was also wondering whether this cross_entropy support could be added to the older 1.2.* versions as well.

I'll check whether the 1.2.X development branch has the same API that is being "touched" here. If it does, I'll implement the change in 1.2.X and then merge it forward to 1.3.X.

The fix is available both in JPMML-LightGBM 1.2.15 and 1.3.10.