graphcore-research / tensorflow-jax-experimental

TensorFlow XLA backend of experimental JAX on Mk2

🔴 Non-official experimental 🔴 IPU XLA TensorFlow/JAX XLA backend

This repository is a non-official experimental fork of the Graphcore IPU TensorFlow repository (itself a modified version of TensorFlow supporting Graphcore IPUs).

The goal of this repository is to implement the additional PJRT layer on top of the Graphcore Poplar XLA backend. This layer is necessary to compile and run JAX on IPUs. This is NOT an additional non-official TensorFlow version for IPUs.

Experimental JAX on IPU pulls the XLA backend source code directly from this repository and compiles the corresponding jaxlib Python binary wheel. Building this repository on its own with Bazel is supported only for directly testing bug fixes or new features of the IPU XLA backend or PJRT client.

Compilation

The stable branch requires the following configuration: Ubuntu 20.04, Graphcore Poplar SDK 3.1 and Bazel 5.1.1.
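These version requirements can be checked before starting a long build. The following is a hedged sketch and is not part of the repository. The function names `check_bazel_version` and `check_ubuntu_version` are hypothetical. The sketch assumes that `bazel --version` prints a line of the form `bazel 5.1.1` and that the OS release is described in /etc/os-release. It does not check the Poplar SDK version.

```shell
#!/usr/bin/env bash
# Hypothetical prerequisite checks (not shipped with this repository).

# Succeeds only if `bazel --version` reports the expected version.
# Assumes output of the form "bazel 5.1.1".
check_bazel_version() {
  local expected="${1:-5.1.1}"
  local actual
  actual="$(bazel --version 2>/dev/null | awk '{print $2}')"
  if [[ "$actual" != "$expected" ]]; then
    echo "error: expected Bazel $expected, found '${actual:-none}'" >&2
    return 1
  fi
}

# Succeeds only if the os-release file describes Ubuntu 20.04.
# The file path can be overridden (e.g. for testing).
check_ubuntu_version() {
  local os_release="${1:-/etc/os-release}"
  [[ -r "$os_release" ]] || { echo "error: cannot read $os_release" >&2; return 1; }
  # Source the file in a subshell so the caller's environment is untouched.
  local id_version
  id_version="$(. "$os_release" && echo "${ID}:${VERSION_ID}")"
  if [[ "$id_version" != "ubuntu:20.04" ]]; then
    echo "error: expected Ubuntu 20.04, found '$id_version'" >&2
    return 1
  fi
}
```

For example, `check_ubuntu_version && check_bazel_version 5.1.1` returns a non-zero status and prints an error if either requirement is not met.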

For the development of jaxlib on IPU, the targets of interest are:

  • IPU Poplar XLA backend: //tensorflow/compiler/plugin:plugin
  • XLA Python client: //tensorflow/compiler/xla/python:xla_client
  • IPU PJRT client: //tensorflow/compiler/plugin/poplar/xla_client:ipu_xla_client

These targets can be compiled as follows (shown here for the IPU PJRT client):

export PATH=$HOME/bin:$PATH  # in case of local Bazel install
export TF_POPLAR_BASE=...    # Poplar install path, e.g. /opt/poplar/ or ${POPLAR_SDK_ENABLED}
python configure.py
bazel build --config=monolithic //tensorflow/compiler/plugin/poplar/xla_client:ipu_xla_client

The --config=monolithic option mirrors the compilation configuration of jaxlib, which generates a single monolithic shared library.
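The build steps above can be wrapped in a script that builds all three targets of interest in one Bazel invocation. The script stops early if the Poplar path is missing. This is a hedged sketch and is not part of the repository. `check_poplar_base` and `build_ipu_targets` are hypothetical names. The sketch assumes `python configure.py` has already been run in the source tree. That step is left out because it may prompt for input.

```shell
#!/usr/bin/env bash
# Hypothetical build helper (not shipped with this repository).
# Assumes `python configure.py` has already been run in the source tree.

# Targets of interest for jaxlib development, as listed in this README.
IPU_TARGETS=(
  "//tensorflow/compiler/plugin:plugin"
  "//tensorflow/compiler/xla/python:xla_client"
  "//tensorflow/compiler/plugin/poplar/xla_client:ipu_xla_client"
)

# Fails if TF_POPLAR_BASE is unset or does not point to a directory.
check_poplar_base() {
  if [[ -z "${TF_POPLAR_BASE:-}" ]]; then
    echo "error: TF_POPLAR_BASE is not set" >&2
    return 1
  fi
  if [[ ! -d "$TF_POPLAR_BASE" ]]; then
    echo "error: TF_POPLAR_BASE='$TF_POPLAR_BASE' is not a directory" >&2
    return 1
  fi
}

# Builds all targets with the monolithic configuration used by jaxlib.
# Extra arguments (e.g. --jobs=16) are forwarded to `bazel build`.
build_ipu_targets() {
  check_poplar_base || return 1
  bazel build --config=monolithic "$@" "${IPU_TARGETS[@]}"
}
```

For example, run `export TF_POPLAR_BASE=/opt/poplar/` and then `build_ipu_targets --jobs=16`. A single invocation lets Bazel share work between targets instead of rebuilding common dependencies in separate runs.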

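Additional useful Bazel tips:

  • --output_user_root: change the Bazel output directory (e.g. to a faster local disk). This is a startup option, so it goes before the Bazel command (bazel --output_user_root=... build ...);
  • ln -s /localdata/paulb/bazel/ /nethome/paulb/.cache/bazel: keep the Bazel cache on a fast disk via a symbolic link.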
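The following is a minimal wrapper sketch for the --output_user_root option. `bazel_fast` and `BAZEL_FAST_ROOT` are hypothetical names. The default path only illustrates a local fast disk and should be adapted to your machine.

```shell
#!/usr/bin/env bash
# Hypothetical wrapper: runs Bazel with its output root on a fast local disk.
# BAZEL_FAST_ROOT defaults to an example path; override it for your machine.
BAZEL_FAST_ROOT="${BAZEL_FAST_ROOT:-/localdata/${USER:-$(id -un)}/bazel}"

bazel_fast() {
  # Create the output root if needed, then pass it as a startup option,
  # i.e. before the Bazel command and its own flags.
  mkdir -p "$BAZEL_FAST_ROOT" || return 1
  bazel --output_user_root="$BAZEL_FAST_ROOT" "$@"
}
```

Usage: `bazel_fast build --config=monolithic //tensorflow/compiler/plugin/poplar/xla_client:ipu_xla_client`. The symbolic link approach achieves a similar effect without changing the command line.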

Running unit tests

For the purpose of supporting JAX, here are the test targets of interest:

  • Poplar XLA backend (IPU specific) unit tests: //tensorflow/compiler/plugin/poplar:all_tests
  • XLA general unit tests, using IPU Poplar backend: //tensorflow/compiler/tests:poplar_tests
  • XLA client unit tests: //tensorflow/compiler/xla/client/lib:poplar_tests
  • IPU PJRT client unit tests: //tensorflow/compiler/plugin/poplar/xla_client/tests:all_tests

All of the above test targets can be run on the IPU model with the following command (shown here for the IPU PJRT client tests; substitute any other target):

bazel test --config=monolithic --jobs=16 --verbose_failures --cache_test_results=no --test_timeout=240,360,900,3600 --test_size_filters=small,medium,large --flaky_test_attempts=1 --test_output=all --test_env='TF_POPLAR_FLAGS=--use_ipu_model --ipu_model_tiles=8 --max_compilation_threads=1 --max_infeed_threads=2' //tensorflow/compiler/plugin/poplar/xla_client/tests:all_tests
  • Using IPU hardware requires an additional test_env mapping: --test_env='IPUOF_VIPU_API_PARTITION_ID=xxx'.
  • Additional logging can be enabled with: --test_env='POPLAR_LOG_LEVEL=DEBUG' --test_env='TF_CPP_MIN_LOG_LEVEL=0'.
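The test command above can be parameterised with a small helper. This is a hedged sketch and is not part of the repository. `run_poplar_tests`, `POPLAR_TEST_TARGETS` and the `DEBUG_LOGS` switch are hypothetical names. When a partition ID is given, the helper adds the IPUOF_VIPU_API_PARTITION_ID mapping. It also drops the IPU model flags (--use_ipu_model, --ipu_model_tiles) so that the tests target hardware. This README does not state that second step, so treat it as an assumption.

```shell
#!/usr/bin/env bash
# Hypothetical test helper (not shipped with this repository).

# Test targets of interest, as listed in this README.
POPLAR_TEST_TARGETS=(
  "//tensorflow/compiler/plugin/poplar:all_tests"
  "//tensorflow/compiler/tests:poplar_tests"
  "//tensorflow/compiler/xla/client/lib:poplar_tests"
  "//tensorflow/compiler/plugin/poplar/xla_client/tests:all_tests"
)

# Usage: run_poplar_tests <target> [partition_id]
#   Without partition_id: runs on the IPU model (flags from this README).
#   With partition_id: adds IPUOF_VIPU_API_PARTITION_ID and, by assumption,
#   drops the IPU model flags so tests run on IPU hardware.
#   DEBUG_LOGS=1 additionally enables Poplar and TensorFlow logging.
run_poplar_tests() {
  local target="$1"
  local partition="${2:-}"
  local poplar_flags="--max_compilation_threads=1 --max_infeed_threads=2"
  local args=(
    test --config=monolithic --jobs=16 --verbose_failures
    --cache_test_results=no --test_timeout=240,360,900,3600
    --test_size_filters=small,medium,large --flaky_test_attempts=1
    --test_output=all
  )
  if [[ -z "$partition" ]]; then
    poplar_flags="--use_ipu_model --ipu_model_tiles=8 $poplar_flags"
  else
    args+=(--test_env="IPUOF_VIPU_API_PARTITION_ID=$partition")
  fi
  args+=(--test_env="TF_POPLAR_FLAGS=$poplar_flags")
  if [[ "${DEBUG_LOGS:-0}" == "1" ]]; then
    args+=(--test_env=POPLAR_LOG_LEVEL=DEBUG --test_env=TF_CPP_MIN_LOG_LEVEL=0)
  fi
  bazel "${args[@]}" "$target"
}
```

To run every target in turn on the IPU model: `for t in "${POPLAR_TEST_TARGETS[@]}"; do run_poplar_tests "$t"; done`.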

Failing unit tests should be reported as a GitHub issue.

Additional documentation

License: Apache License 2.0