Run pip install jax_trees[cpu] to use the package on CPU, or pip install jax_trees[cuda] to run on GPU.
This package provides simple implementations of Decision Trees, Random Forests and Gradient Boosted Trees using JAX.
The initial objective of this package was to learn JAX: I wanted to challenge myself to pick up the JAX syntax by reimplementing common ML algorithms, and Decision Trees and related ensemble models seemed like good candidates for this purpose.
JAX is a computing library with some specificities. Pure functions written in JAX can be compiled into efficient XLA code using the jit (just-in-time) function. The compilation itself can take a while, but once a function is compiled it runs many times faster than the equivalent Python code, especially if you call it many times. Compiled functions only accept statically shaped arrays as arguments: when you call a jitted function with inputs of different shapes, the function is recompiled, so we should try as much as possible to feed statically shaped arrays to such functions.
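For instance, a small illustrative snippet (not taken from the package) showing when recompilation happens:

import jax
import jax.numpy as jnp

@jax.jit
def normalize(x: jnp.ndarray) -> jnp.ndarray:
    # traced and compiled the first time it is called with a given shape
    return (x - x.mean()) / x.std()

normalize(jnp.ones(100))    # compiled for shape (100,)
normalize(jnp.zeros(100))   # same shape: reuses the cached executable
normalize(jnp.ones(200))    # new shape: triggers a recompilation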
JAX also has utilities such as vmap, which automatically vectorizes computations, and grad, which automatically computes the gradient of a scalar function. These should prove useful when ensembling the trees as well as when implementing gradient boosted trees.
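As a toy sketch of both utilities (again, not code from the package):

from jax import grad, vmap
import jax.numpy as jnp

def squared_error(w: jnp.ndarray, x: jnp.ndarray, y: float) -> float:
    return (jnp.dot(w, x) - y) ** 2

# grad differentiates the scalar loss with respect to its first argument
grad_fn = grad(squared_error)

# vmap maps the gradient over a batch of (x, y) pairs without a Python loop,
# e.g. batched_grads(w, X_batch, y_batch)
batched_grads = vmap(grad_fn, in_axes=(None, 0, 0))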
To test the constraints and advantages of using JAX, my objective was to write the ML models in the most JAX-idiomatic way possible. To that end, I set myself the constraint of having completely jitted fit and predict methods on all models.
Fitting a decision tree requires recursively splitting the data at each node into two subgroups until the specified max_depth is reached. In a typical implementation, we would have a TreeNode class with a split method: one would split the node in two and instantiate child nodes with the resulting subgroups, roughly as in the sketch below. In particular, the subgroup sizes would depend on the model and the dataset.
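A rough sketch of that standard, non-JAX approach (find_best_split is a hypothetical helper, not part of the package):

import numpy as np

class TreeNode:
    """Sketch of a standard recursive implementation (not the JAX version)."""

    def __init__(self, max_depth: int):
        self.max_depth = max_depth
        self.left = self.right = None

    def split(self, X: np.ndarray, y: np.ndarray, depth: int = 0) -> None:
        if depth >= self.max_depth:
            return
        column, value = find_best_split(X, y)  # hypothetical helper
        goes_left = X[:, column] <= value      # subgroup sizes depend on the data
        self.left = TreeNode(self.max_depth)
        self.right = TreeNode(self.max_depth)
        self.left.split(X[goes_left], y[goes_left], depth + 1)
        self.right.split(X[~goes_left], y[~goes_left], depth + 1)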
To select the best split for a node, many candidates are generated for each feature column, giving a matrix of shape (n_candidates, n_columns) of potential splits to consider. For each split, a score is computed (entropy or Gini impurity for classification, variance for regression). Computing the score for a single split can be done as follows, where score_fn is e.g. entropy for classifiers and variance for regressors.
def compute_split_score(
    feature_column: jnp.ndarray,
    y: jnp.ndarray,
    mask: jnp.ndarray,
    split_value: float,
) -> float:
    """Compute the weighted score of a single split."""
    left_mask, right_mask = split_mask(split_value, feature_column, mask)
    left_score = score_fn(y, left_mask)
    right_score = score_fn(y, right_mask)
    n_left = jnp.sum(left_mask)
    n_right = jnp.sum(right_mask)
    # weight each side's score by the number of samples it receives
    avg_score = (n_left * left_score + n_right * right_score) / (
        n_left + n_right
    )
    return avg_score
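For reference, split_mask and a score_fn could look roughly like this (a sketch under my own assumptions; the package's actual helpers may differ):

def split_mask(split_value, feature_column, mask):
    # keep the parent's weights and zero out points on the other side of the split
    left_mask = mask * (feature_column < split_value)
    right_mask = mask * (feature_column >= split_value)
    return left_mask, right_mask

def variance(y, mask):
    # masked variance, a possible score_fn for regression
    n = jnp.maximum(jnp.sum(mask), 1.0)  # avoid dividing by zero on empty masks
    mean = jnp.sum(y * mask) / n
    return jnp.sum(mask * (y - mean) ** 2) / n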
Applying the JAX vmap function twice, once to parallelize over the candidates within a column and a second time to parallelize over the columns, we can easily vectorize the split computations to get a matrix of scores for all split candidates and then select the best splitting point for a node. So far so good.
# parallelize across split points within a column
column_split_scores = vmap(
    compute_split_score, in_axes=[None, None, None, 0]
)
# parallelize across columns
all_split_scores = vmap(
    column_split_scores, in_axes=[1, None, None, 1], out_axes=1
)
all_scores = all_split_scores(X, y, mask, split_candidate_matrix)
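From this score matrix, the best split can then be picked, for instance like this (a sketch assuming lower scores are better, which holds for entropy, Gini impurity and variance):

# index of the best (candidate, column) pair in the score matrix
best_candidate, best_column = jnp.unravel_index(
    jnp.argmin(all_scores), all_scores.shape
)
best_value = split_candidate_matrix[best_candidate, best_column]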
A first issue arises, however: some nodes cannot be split even though we have not reached max_depth. Indeed, a node may have too few data points to be split (fewer than min_samples). In a normal Python function, we would just add a condition and only split the node if it has more than min_samples data points. However, this cannot be done in a jitted function, since the number and order of operations must be known at compile time and we don't know which nodes will not be splittable before actually splitting the data.
To jit our fit method, we therefore have to perform every split unconditionally and instead neutralize the splits that should not happen. For this purpose, we add a mask array and a boolean is_leaf to our TreeNode class. The mask has a static shape (n_samples,) and weights the node data: if the mask sums to zero, the node holds no data. If the is_leaf boolean is true, we set the masks of the child nodes to zero.
# leaf if the node sits at max_depth or holds at most min_samples points
is_leaf = jnp.maximum(depth + 1 - self.max_depth, 0) + jnp.maximum(
    self.min_samples + 1 - jnp.sum(mask), 0
)
is_leaf = jnp.minimum(is_leaf, 1).astype(jnp.int8)
# zero-out child masks if current node is a leaf
left_mask *= 1 - is_leaf
right_mask *= 1 - is_leaf
The decision tree itself is a collection of TreeNode objects. A TreeNode is simply a container holding information on how to split the data (is_leaf, split_column, split_value, ...). I chose to represent the tree as a nested structure where a node can be accessed with Tree.nodes[depth][rank], where rank is the position of the node at a given depth. With this simple structure, accessing a node's children is straightforward: they live at nodes[depth+1][2*rank] and nodes[depth+1][2*rank+1].
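For illustration, descending the tree for a single sample x with this layout could look like the following plain-Python loop (illustrative only; the attribute names follow the snippets above and this is not the package's jitted predict):

# illustrative, non-jitted descent for one sample x
rank = 0
for depth in range(tree.max_depth):
    node = tree.nodes[depth][rank]
    if node.is_leaf:
        break
    go_right = x[node.split_column] >= node.split_value
    rank = 2 * rank + int(go_right)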
Now, to be able to jit the fit method itself, the DecisionTree object should be a pytree so that self, which is passed as the first argument to the method, can be passed to and returned from a jitted function. To do that, we register the class using the jax.tree_util.register_pytree_node_class decorator and implement the tree_flatten and tree_unflatten methods on the class.
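A rough sketch of what these methods could look like (the attribute names are assumptions based on the snippets above, not the package's exact code):

from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class DecisionTree:
    ...

    def tree_flatten(self):
        # dynamic values (the fitted nodes) become pytree children,
        # static configuration goes into the auxiliary data
        children = (self.nodes,)
        aux_data = {"max_depth": self.max_depth, "min_samples": self.min_samples}
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = cls(**aux_data)
        obj.nodes = children[0]
        return obj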
Finally, since jitted functions are purely functional, the jitted fit can have no side effects. This means that we cannot update the model parameters in place in fit. Instead, we return the fitted model as the output of the fit method: fitted_model = model.fit(...).
The first implementation of the jitted fit function looked like this:
@register_pytree_node_class
class DecisionTree:
    ...
    @jit
    def fit(self, X: jnp.ndarray, y: jnp.ndarray) -> DecisionTree:
        nodes = defaultdict(list)
        n_samples = X.shape[0]
        next_masks = [jnp.ones((n_samples,))]
        for depth in range(self.max_depth):
            masks = next_masks
            next_masks = []
            for rank in range(2**depth):
                current_mask = masks[rank]
                left_mask, right_mask, current_node = split_node(current_mask)
                nodes[depth].append(current_node)
                next_masks.extend([left_mask, right_mask])
        self.nodes = nodes
        return self
This worked well for small trees, but the jit compilation time was too high to be practical even for moderate tree depths. This is because for loops are unrolled during the jit tracing process. As a result, in the above snippet the XLA computation graph grows exponentially with the depth of the tree: the computationally heavy split_node function is traced once per node, i.e. roughly 2**max_depth times.
Using the lax.scan primitive helps circumvent the problem. It lets us run a function over an iterable, carrying a value along from one step to the next, and get the stacked results back. The function is only traced once.
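As a minimal, standalone illustration of lax.scan (a toy example, unrelated to the tree code):

from jax import lax
import jax.numpy as jnp

def step(carry, x):
    carry = carry + x
    return carry, carry  # (new carry, per-step output)

final, cumulative = lax.scan(step, init=0.0, xs=jnp.arange(5.0))
# final == 10.0, cumulative == [0., 1., 3., 6., 10.]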
We can use lax.scan on the inner loop by passing the array of masks and getting back the masks for the next depth and the nodes for the current level. The updated code looks like this, and jit compilation time was drastically reduced (from ~2m to ~5s on a depth-4 tree).
@register_pytree_node_class
class DecisionTree:
    ...
    @jit
    def fit(self, X: jnp.ndarray, y: jnp.ndarray) -> DecisionTree:
        nodes = {}
        n_samples = X.shape[0]

        def split_nodes(_carry, mask):
            left_mask, right_mask, node = split_node(mask)
            child_mask = jnp.stack([left_mask, right_mask], axis=0)
            return _carry, (child_mask, node)

        masks = jnp.ones((1, n_samples))
        for depth in range(self.max_depth):
            _, (child_masks, level_nodes) = lax.scan(
                split_nodes,
                init=None,  # carry is not used
                xs=masks,
            )
            nodes[depth] = level_nodes
            masks = jnp.reshape(child_masks, (-1, n_samples))
        self.nodes = nodes
        return self
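Putting it all together, usage looks roughly like this (the constructor arguments are assumptions inferred from the attributes used above, not the package's documented API):

# hypothetical constructor arguments (max_depth, min_samples assumed)
model = DecisionTree(max_depth=4, min_samples=2)
fitted_model = model.fit(X, y)         # compiled on the first call, fast afterwards
predictions = fitted_model.predict(X)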