saglam / xlaclient-rs

An attempt at writing an XLA client in Rust

xlaclient-rs

In reinforcement learning (RL), one often designs a policy network for an agent, which outputs a probability distribution over the actions the agent can take, given the state of the world, hopefully enabling the agent to achieve certain goals.

Any such policy network induces a probability distribution over the "path space" of the world: starting from an initial state, the agent samples an action according to the network and takes this action, and keeps repeating this until the simulation is over.

Suppose now we are given a policy network \inline W, with the induced path space distribution \inline P. We may run many parallel simulations by letting the agent behave according to \inline W, thereby drawing many samples from \inline P, say obtaining

\displaystyle{\{p_1, p_2,\ldots, p_N\}}

where \inline p_i\in\mathrm{supp}(P) is the particular "path" the agent took in simulation number \inline i. Suppose further that we can assign a non-negative quality score \inline q(p_i) to each path, denoting how well the agent achieved the particular goal we want.

Now consider the "empirical distribution" on the path space obtained as follows: Sample \inline i\in\{1,2,\ldots, N\} with probability proportional to \inline q(p_i) and output \inline p_i. Call this distribution over the paths \inline Q. In words, \inline Q is a reweighting (of an empirical sample) of \inline P.
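As a minimal sketch of this reweighting, assuming the scores are non-negative and not all zero, the weights of \inline Q and a draw from it could be computed as below. The function names `reweight` and `sample_index` are hypothetical and not part of this repo; the uniform draw `u` is passed in by the caller so that the sketch needs no random number crate.

```rust
/// Normalizes quality scores q(p_i) into the weights of Q.
/// Returns None when Q is undefined: no scores, a negative or
/// non-finite score, or a total that is zero or not finite.
fn reweight(scores: &[f64]) -> Option<Vec<f64>> {
    if scores.is_empty() || scores.iter().any(|&s| !s.is_finite() || s < 0.0) {
        return None;
    }
    let total: f64 = scores.iter().sum();
    if !total.is_finite() || total <= 0.0 {
        return None;
    }
    Some(scores.iter().map(|&s| s / total).collect())
}

/// Picks a path index by inverse-CDF sampling.
/// `weights` must be non-empty and sum to 1; `u` is a uniform draw from [0, 1).
fn sample_index(weights: &[f64], u: f64) -> usize {
    let mut cumulative = 0.0;
    for (i, &w) in weights.iter().enumerate() {
        cumulative += w;
        if u < cumulative {
            return i;
        }
    }
    // Rounding in the cumulative sum can leave it slightly below 1;
    // fall back to the last index in that case.
    weights.len().saturating_sub(1)
}
```

With scores 1 and 3, for instance, the second path is drawn three times as often as the first.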

(One interesting thing is that while \inline P was Markovian with respect to world state, \inline Q is not necessarily so. It is mainly the sampling that breaks the Markov property; if we just re-weighted the true \inline P according to the final state of the world, the resulting distribution would still be Markovian.)

In a certain sense, the simulation we ran revealed to us \inline \mathbf{D}(Q\,\|\,P) bits of information about the shortcomings of our policy network \inline W, where \inline \mathbf{D}(\cdot\|\cdot) is the Kullback-Leibler divergence. For simplicity of notation, I will assume that in our world all paths have the same length, which I denote \inline |p_0|; this assumption can easily be removed. Writing

\displaystyle{ \mathbf{D}(Q\,\|\,P) =
    \sum_{i=1}^{|p_0|} \mathbf{D}(Q_i\,|\,Q_{\lt  i}\,\|\,P_i\,|\,P_{\lt  i}),}

this experiment gives us \inline N\times |p_0| data points with which we can improve \inline W. Given these sample paths \inline \{p_1,\ldots,p_N\}, our goal is to fit \inline W's output to \inline Q_i\,|\,Q_{\lt  i}, say via SGD, so that each term in the above summation approaches zero, and therefore the entire divergence on the path space approaches 0.
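One possible reading of this fitting step, offered as an assumption rather than as the repo's method: if the world's own transitions do not depend on \inline W, then minimizing the divergence over \inline W amounts to minimizing a score-weighted negative log-likelihood of the actions actually taken,

\displaystyle{ L(W) = -\sum_{j=1}^{N} \frac{q(p_j)}{\sum_{k=1}^{N} q(p_k)}
    \sum_{i=1}^{|p_0|} \log W(a_{j,i}\,|\,s_{j,i}),}

where \inline s_{j,i} and \inline a_{j,i} (my notation) are the state and the action at step \inline i of path \inline p_j. The sketch below evaluates this loss from precomputed log-probabilities; the function name `weighted_nll` and the nested-vector layout are hypothetical.

```rust
/// Score-weighted negative log-likelihood, in nats.
/// `log_probs[j][i]` holds log W(a | s) for the action taken at step i of
/// path j (hypothetical layout); `scores[j]` is the quality score q(p_j).
/// Returns None if the lengths differ or the scores do not sum to a
/// positive finite number.
fn weighted_nll(log_probs: &[Vec<f64>], scores: &[f64]) -> Option<f64> {
    if log_probs.len() != scores.len() {
        return None;
    }
    let total: f64 = scores.iter().sum();
    if !total.is_finite() || total <= 0.0 {
        return None;
    }
    let loss = log_probs
        .iter()
        .zip(scores)
        // Paths with zero score carry no weight in Q; skipping them also
        // avoids 0 * (-inf) when such a path has a zero-probability action.
        .filter(|&(_, &q)| q > 0.0)
        .map(|(steps, &q)| {
            let path_log_prob: f64 = steps.iter().sum();
            -(q / total) * path_log_prob
        })
        .sum::<f64>();
    Some(loss)
}
```

In an actual training loop this quantity would be computed and differentiated on the accelerator; the plain Rust version only shows what is being minimized.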

Alternating between simulation and SGD training, with each round of simulation using the improved network, gives us a powerful way to discover extremely competitive strategies.

Why an XLA client in Rust

In a certain sense, in RL we develop a better understanding of the world by running many simulations with an increasingly competent agent. This results in a self-reinforcing cycle, which leads to the name RL. Running many parallel simulations quickly involves two things:

  1. Feeding the world state to the network \inline W and getting back a distribution on actions

  2. Being able to run a time step of the "world" and our agent in the simulation.

While Python is extremely well suited for (1), it is the second item where Python makes things very slow, difficult, or in some cases impossible.

This repo is my attempt at writing an XLA client in Rust so that one can perform (2) efficiently while still being able to do (1) quickly by offloading the network inference to powerful accelerators such as Google Cloud GPUs or TPUs.
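To make the split between (1) and (2) concrete, here is a rough sketch of a lockstep simulation loop. Everything in it is hypothetical and none of it is this repo's API: the `World` and `Policy` traits, the `rollout` function, the toy `CountdownWorld` and `UniformPolicy` stand-ins, and the small `Lcg` generator used only to keep the example free of dependencies. In a real setup the `Policy` implementation would hand the batch of states to an XLA-compiled network on an accelerator.

```rust
/// Hypothetical interface for item (2): one time step of the world.
trait World {
    fn state(&self) -> Vec<f32>;
    fn step(&mut self, action: usize);
    fn is_over(&self) -> bool;
}

/// Hypothetical interface for item (1): batched inference, one
/// action distribution per input state.
trait Policy {
    fn action_probs(&self, states: &[Vec<f32>]) -> Vec<Vec<f32>>;
}

/// Tiny linear congruential generator, used only to keep the sketch
/// dependency-free; not suitable for serious sampling.
struct Lcg(u64);
impl Lcg {
    fn next_uniform(&mut self) -> f32 {
        self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        // The top 24 bits fit exactly in an f32, giving a value in [0, 1).
        ((self.0 >> 40) as f32) / ((1u64 << 24) as f32)
    }
}

/// Inverse-CDF sampling of an action index from `probs`.
fn sample_action(probs: &[f32], u: f32) -> usize {
    let mut cumulative = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return i;
        }
    }
    probs.len().saturating_sub(1)
}

/// Steps all worlds in lockstep until every one is over, issuing one
/// batched policy call per time step. Returns the actions taken per world.
fn rollout<W: World, P: Policy>(worlds: &mut [W], policy: &P, rng: &mut Lcg) -> Vec<Vec<usize>> {
    let mut paths: Vec<Vec<usize>> = vec![Vec::new(); worlds.len()];
    loop {
        // Indices of the worlds that are still running.
        let active: Vec<usize> = (0..worlds.len()).filter(|&i| !worlds[i].is_over()).collect();
        if active.is_empty() {
            break;
        }
        let states: Vec<Vec<f32>> = active.iter().map(|&i| worlds[i].state()).collect();
        let probs = policy.action_probs(&states);
        for (&i, p) in active.iter().zip(&probs) {
            let action = sample_action(p, rng.next_uniform());
            worlds[i].step(action);
            paths[i].push(action);
        }
    }
    paths
}

/// Toy world that ends after a fixed number of steps.
struct CountdownWorld {
    remaining: u32,
}
impl World for CountdownWorld {
    fn state(&self) -> Vec<f32> {
        vec![self.remaining as f32]
    }
    fn step(&mut self, _action: usize) {
        self.remaining = self.remaining.saturating_sub(1);
    }
    fn is_over(&self) -> bool {
        self.remaining == 0
    }
}

/// Toy policy that ignores the state and returns a uniform distribution.
struct UniformPolicy {
    num_actions: usize,
}
impl Policy for UniformPolicy {
    fn action_probs(&self, states: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let p = 1.0 / self.num_actions as f32;
        states.iter().map(|_| vec![p; self.num_actions]).collect()
    }
}
```

Batching the states of all active worlds into a single `action_probs` call is the design choice that matters here: the per-step world logic stays in fast native code, while the network sees one large batch per time step instead of many small requests.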

Note: DeepMind's open_spiel is an RL platform, and thus needs to solve the same fast inference problem I outlined. They solve this by bringing up a full TensorFlow session and running the XLA compilation through the @tf.function mechanism. For the purposes of this project, that route is not ideal, as a TensorFlow session is a ginormous dependency that would require a build cluster; further, the @tf.function interface currently does not seem to expose all the compilation options XLA provides.
