xiaoa6435 / spark-bart

An implementation of bayes additive regression tree model atop Apache Spark

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

spark-bart

A pure scala/spark implementation of the BART(bayes additive regressions trees model of Chipman et al and related model, like BACT(bayes additive classification tree), BCF(bayes causal forest) etc

example

import org.apache.spark.ml.linalg.Vectors
import math.{Pi, sin}
import util.Random

import org.apache.spark.ml.regression.{BayesAdditiveTreeRegressionModel, BayesAdditiveTreeRegressor}
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
import org.apache.spark.ml.evaluation.RegressionEvaluator

val sampleCnt = (1e3).toInt
val p = 10

// Friedman’s test function with a category 
val df = spark.createDataFrame(
  Range(0, sampleCnt).map{i => 
    val x = Array.tabulate(p){i => 
      if(i == 9) Random.nextInt(5).toDouble else Random.nextDouble
    }
    val features = Vectors.dense(x)
    val label = (10 * math.sin(math.Pi * x(0) * x(1)) + 
             20 * math.pow(x(2) - 0.5, 2) + 
             10 * x(3) + 5 * x(4) + Random.nextGaussian)
  (features, label)
}).toDF("features", "label")

val Array(trainingData, testData) = df.randomSplit(Array(0.7, 0.3))
val BART = new BayesAdditiveTreeRegressor().
  setNumBurn(1000).
  setNumSim(100).
  setParallelChainCnt(4).
  setNumThin(2).
  setCategoryFeatureIndexes(Array(9)).
  setCategoryFeatureArity(Array(5))
val bartModel = BART.fit(trainingData)
val bartPred = bartModel.transform(testData)

val evaluator = new RegressionEvaluator().
  setLabelCol("label").
  setPredictionCol("prediction").
  setMetricName("rmse")

val bartRMSE = evaluator.evaluate(bartPred)
println(s"bart: Root Mean Squared Error (RMSE) on test data = $bartRMSE")
//bart: Root Mean Squared Error (RMSE) on test data = 1.83

val GBT = new GBTRegressor()
val GBTModel = GBT.fit(trainingData)
val gbtPred = GBTModel.transform(testData)
val gbtRMSE = evaluator.evaluate(gbtPred)
println(s"gbt: Root Mean Squared Error (RMSE) on test data = $gbtRMSE")
//gbt: Root Mean Squared Error (RMSE) on test data = 2.31

bartModel.asInstanceOf[BayesAdditiveTreeRegressionModel].
  write.overwrite.save("bart-model")

Related projects

References

About

An implementation of bayes additive regression tree model atop Apache Spark

License:Apache License 2.0


Languages

Language:Scala 100.0%