TPU Starter

Everything you want to know about Google Cloud TPU

This project is inspired by Cloud Run FAQ, a community-maintained knowledge base of another Google Cloud product.

1. Community

As of 23 Feb 2022, there is no official chat group for Cloud TPUs. You can join the @cloudtpu chat group on Telegram or TPU Podcast on Discord, which are connected with each other.

2. Introduction to TPU

2.1. Why TPU?

TL;DR: TPU is to GPU as GPU is to CPU.

TPU is specialized hardware designed for machine learning. There is a performance comparison in Hugging Face Transformers.

Moreover, the TRC program provides researchers with free TPUs. As far as I know, if you have ever worried about the computing resources for training models, this is the best solution. For more details on the TRC program, see below.

2.2. TPU is so good, why haven't I seen many people using it?

If you want to use PyTorch, TPU may not be suitable for you. TPU is poorly supported by PyTorch. In one of my experiments, one batch took about 14 seconds to run on CPU, but over 4 hours to run on TPU. Twitter user @mauricetpunkt also thinks PyTorch's performance on TPUs is bad.

Another problem is that although a single TPU v3-8 device has 8 cores (16 GiB memory for each core), you need to write extra code to make use of all the 8 cores (see below). Otherwise, only the first core is used.

2.3. I know TPU is good now. Can I touch a real TPU?

Unfortunately, in most cases you cannot touch a TPU physically. TPU is only available through cloud services.

2.4. How do I get access to TPU?

You can create TPU instances on Google Cloud Platform. For more information on setting up TPU, see below.

You can also use Google Colab, but I don't recommend it. Moreover, if you get free access to TPUs from the TRC program, you will be using Google Cloud Platform, not Google Colab.

2.5. What does it mean to create a TPU instance? What do I actually get?

After creating a TPU v3-8 instance on Google Cloud Platform, you will get an Ubuntu 20.04 cloud server with sudo access, 96 cores, 335 GiB memory and one TPU device with 8 cores (128 GiB TPU memory in total).

This is similar to the way we use GPUs. In most cases, when you use a GPU, you are using a Linux server connected to a GPU; when you use a TPU, you are using a Linux server connected to a TPU.

3. Introduction to the TRC Program

3.1. How to apply for the TRC program?

Besides its homepage, Shawn has written a wonderful article about the TRC program in google/jax#2108. Anyone who is interested in TPU should read it immediately.

3.2. Is it really free?

For the first three months, it is completely free because all the fees are covered by the Google Cloud free trial. After that, I pay only about HK$13.95 (approx. US$1.78) per month for outbound Internet traffic.

4. Create a TPU VM Instance

4.1. Modify VPC firewall

You need to loosen the restrictions of the firewall so that Mosh and other programs will not be blocked.

Open the Firewall management page in VPC network.

Click the button to create a new firewall rule.

Set name to 'allow-all', targets to 'All instances in the network', source filter to 0.0.0.0/0, protocols and ports to 'Allow all', and then click 'Create'.

4.2. Create the instance

Open Google Cloud Platform, navigate to the TPU management page.

Click the console button on the top-right corner to activate Cloud Shell.

In Cloud Shell, type the following command to create a Cloud TPU VM v3-8 with TPU software version v2-nightly20210914:

gcloud alpha compute tpus tpu-vm create node-1 --project tpu-develop --zone=europe-west4-a --accelerator-type=v3-8 --version=v2-nightly20210914

If the command fails because there are no more TPUs to allocate, you can re-run the command.

4.3. Add public key to the server

In Cloud Shell, log in to the TPU VM with the gcloud command:

gcloud alpha compute tpus tpu-vm ssh node-1 --zone europe-west4-a

After logging in, add your public key to ~/.ssh/authorized_keys.

5. Environment Setup

This section assumes you have no previous knowledge about developing on a server. You can skip this section if you are already familiar with developing on a server and have your preferred setting.

5.1. Install common packages

sudo apt update
sudo apt upgrade -y
sudo apt install -y golang neofetch zsh mosh byobu
sudo reboot

5.2. Install Python 3.10

Unfortunately, the Python version shipped with Ubuntu 20.04 LTS is Python 3.8, so you need to install Python 3.10 manually.

sudo apt install -y software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt install -y python3.10-full python3.10-dev

5.3. Install Oh My Zsh

Oh My Zsh makes the terminal much easier to use.

To install Oh My Zsh, run the following command:

sh -c "$(curl -fsSL https://raw.github.com/ohmyzsh/ohmyzsh/master/tools/install.sh)"

5.4. Change timezone

timedatectl list-timezones
sudo timedatectl set-timezone Asia/Hong_Kong  # change to your timezone

5.5. Create venv

python3.10 -m venv ~/.venv310
source ~/.venv310/bin/activate

You need to run the source command every time you open a shell.

5.6. Install JAX with TPU support

pip install -U pip
pip install -U wheel
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

5.7. Install common libraries

Clone this repository. In the root directory of this repository, run:

pip install -r requirements.txt

5.8. Install TensorFlow and TensorBoard Plugin Profile

Although we are using JAX, we need to install TensorFlow as well to make jax.profiler work. Otherwise, you will get an error:

E external/org_tensorflow/tensorflow/python/profiler/internal/python_hooks.cc:369] Can't import tensorflow.python.profiler.trace

You cannot install TensorFlow in the regular way because it is not built with TPU support.

Installation method:

wget https://gist.github.com/ayaka14732/4954f64b7246beafabb45b636d96e92a/raw/d518753d166f3b77009d1f228101d93ff733d0d2/tensorflow-2.10.0-cp310-cp310-linux_x86_64.whl.0 https://gist.github.com/ayaka14732/4954f64b7246beafabb45b636d96e92a/raw/d518753d166f3b77009d1f228101d93ff733d0d2/tensorflow-2.10.0-cp310-cp310-linux_x86_64.whl.1
cat tensorflow-2.10.0-cp310-cp310-linux_x86_64.whl.0 tensorflow-2.10.0-cp310-cp310-linux_x86_64.whl.1 > tensorflow-2.10.0-cp310-cp310-linux_x86_64.whl
rm -f tensorflow-2.10.0-cp310-cp310-linux_x86_64.whl.0 tensorflow-2.10.0-cp310-cp310-linux_x86_64.whl.1
pip install tensorflow-2.10.0-cp310-cp310-linux_x86_64.whl

See the gist linked above for details.

5.9. Set up Mosh and Byobu

If you connect to the server directly with SSH, there is a risk of loss of connection. If this happens, the training script you are running in the foreground will be terminated.

Mosh and Byobu are two programs that solve this problem. Byobu ensures that the script keeps running on the server even if the connection is lost, while Mosh keeps the connection alive across network changes and interruptions.

Install Mosh on your local device, then log in to the server with:

mosh tpu1 -- byobu

You can learn more about Byobu from the video Learn Byobu while listening to Mozart.

5.10. Set up VSCode Remote-SSH

Open VSCode. Open the 'Extensions' panel on the left. Search for 'Remote - SSH' and install.

Press F1 to open the command palette. Type 'ssh', then select 'Remote-SSH: Connect to Host...'. Enter the name of the server you would like to connect to and press Enter.

Wait for VSCode to be set up on the server. After it is finished, you can develop on the server using VSCode.

5.11. How can I verify that the TPU is working?

Run this command:

python3 -c 'import jax; print(jax.devices())'  # should print TpuDevice

Note that we are using python3 instead of python here, so the command works even without activating the venv.

You can also run this command to link python to python3 by default, but I do not recommend it:

sudo apt install -y python-is-python3

This is because we should always use a venv to run our projects. If the python command points to Python 2 and we forget to source the venv, the command will usually fail, which reminds us to source the venv.

TODO: If TPU is not working...

See also google/jax#9220 (comment).

6. JAX Basics

6.1. Why JAX?

The three popular deep learning libraries supported by Hugging Face Transformers are JAX, PyTorch and TensorFlow.

As mentioned earlier, PyTorch is poorly supported on TPU. As for TensorFlow and JAX, I regard JAX as a next-generation, simplified version of TensorFlow, and it is easier to use.

JAX uses the same APIs as NumPy. There are also a number of mutually compatible libraries built on top of JAX. A comprehensive list of the JAX ecosystem can be found at n2cholas/awesome-jax.

6.2. Compute gradients with jax.grad
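
jax.grad takes a scalar-valued function and returns a new function that computes its gradient. A minimal sketch:

import jax
import jax.numpy as np

def f(x):
    return np.sum(x ** 2)  # scalar-valued function

grad_f = jax.grad(f)  # returns a function that computes the gradient of f
print(grad_f(np.array([1., 2., 3.])))  # [2. 4. 6.]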

6.3. Load training data to CPU, then send batches to TPU
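
A minimal sketch of this pattern, using a made-up dataset: keep the whole dataset in host (CPU) memory as a regular NumPy array and transfer only the current batch to the TPU with jax.device_put.

import jax
import jax.numpy as np
import numpy as onp

dataset = onp.random.rand(10000, 3).astype(onp.float32)  # stays in CPU memory
batch_size = 32

for i in range(0, len(dataset), batch_size):
    batch = jax.device_put(dataset[i:i + batch_size])  # only this batch goes to the TPU
    loss = np.mean(batch ** 2)  # placeholder for the real training step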

6.4. Data parallelism on 8 TPU cores

6.4.1. Basics of jax.pmap

There are four key points here.

1. params and opt_state should be replicated across the devices:

replicated_params = jax.device_put_replicated(params, jax.devices())

2. data and labels should be split to the devices:

n_devices = jax.device_count()
batch_size, *data_shapes = data.shape
assert batch_size % n_devices == 0, 'The data cannot be split evenly to the devices'
data = data.reshape(n_devices, batch_size // n_devices, *data_shapes)

3. Decorate the target function with jax.pmap:

@partial(jax.pmap, axis_name='num_devices')

4. In the loss function, use jax.lax.pmean to calculate the mean value across devices:

grads = jax.lax.pmean(grads, axis_name='num_devices')  # calculate mean across devices

See 01-basics/test_pmap.py for a complete working example.

See also https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#example.
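
Putting the four points together, here is a minimal, self-contained sketch with a made-up linear model and learning rate (see 01-basics/test_pmap.py for the real example):

from functools import partial
import jax
import jax.numpy as np
import numpy as onp

n_devices = jax.device_count()
params = {'w': np.zeros((3,)), 'b': np.zeros(())}  # toy model

def loss_fn(params, x, y):
    pred = x @ params['w'] + params['b']
    return np.mean((pred - y) ** 2)

@partial(jax.pmap, axis_name='num_devices')  # point 3
def update(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    grads = jax.lax.pmean(grads, axis_name='num_devices')  # point 4
    new_params = jax.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss

replicated_params = jax.device_put_replicated(params, jax.devices())  # point 1
x = onp.random.rand(n_devices * 16, 3).reshape(n_devices, 16, 3)  # point 2
y = onp.random.rand(n_devices * 16).reshape(n_devices, 16)

replicated_params, loss = update(replicated_params, x, y)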

6.4.2. What if I want to have randomness in the update function?

key, subkey = (lambda keys: (keys[0], keys[1:]))(rand.split(key, num=9))

Note that you cannot use the regular way to split the keys:

key, *subkey = rand.split(key, num=9)

This is because, with the regular way, subkey would be a Python list of keys rather than a single array with a leading axis for the devices, so it cannot be fed to a pmapped function.

6.4.3. What if I want to use optax optimizers in the update function?

opt_state should be replicated as well.
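
For example, reusing jax.device_put_replicated from point 1 above:

replicated_opt_state = jax.device_put_replicated(opt_state, jax.devices())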

6.5. Use optimizers from Optax
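
A minimal sketch of the usual Optax workflow with a made-up model: create the optimizer, initialize its state from the parameters, turn gradients into updates, and apply them.

import jax
import jax.numpy as np
import optax

params = {'w': np.zeros((3,))}  # toy parameters

def loss_fn(params, x, y):
    return np.mean((x @ params['w'] - y) ** 2)

x, y = np.ones((8, 3)), np.ones((8,))

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)  # initialize optimizer state from the parameters

grads = jax.grad(loss_fn)(params, x, y)
updates, opt_state = optimizer.update(grads, opt_state, params)  # turn gradients into updates
params = optax.apply_updates(params, updates)  # apply the updates to the parameters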

6.6. Freeze certain model parameters

Use optax.set_to_zero together with optax.multi_transform.

params = {
    'a': { 'x1': ..., 'x2': ... },
    'b': { 'x1': ..., 'x2': ... },
}

param_labels = {
    'a': { 'x1': 'freeze', 'x2': 'train' },
    'b': 'train',
}

optimizer_scheme = {
    'train': optax.adam(...),
    'freeze': optax.set_to_zero(),
}

optimizer = optax.multi_transform(optimizer_scheme, param_labels)

See Freeze Parameters Example for details.

6.7. Integration with Hugging Face Transformers

Hugging Face Transformers
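
A minimal sketch, assuming you want to run one of the Flax models shipped with Hugging Face Transformers on TPU (bert-base-uncased is an arbitrary choice here):

from transformers import AutoTokenizer, FlaxAutoModel

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = FlaxAutoModel.from_pretrained('bert-base-uncased')

inputs = tokenizer('Hello, TPU!', return_tensors='np')  # Flax models take NumPy/JAX arrays
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)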

7. Best Practices

7.1. About TPU

7.1.1. Prefer Google Cloud Platform to Google Colab

Google Colab only provides TPU v2-8 devices, while on Google Cloud Platform you can select TPU v2-8 and TPU v3-8.

Besides, on Google Colab you can only use TPU through the Jupyter Notebook interface. Even if you log in to the Colab server via SSH, you are in a Docker container and you don't have root access. On Google Cloud Platform, however, you have full access to the TPU VM.

If you really want to use TPU on Google Colab, you need to run the following script to set up TPU:

import jax
from jax.tools.colab_tpu import setup_tpu

setup_tpu()

devices = jax.devices()
print(devices)  # should print TpuDevice

7.1.2. Prefer TPU VM to TPU node

When you are creating a TPU instance, you need to choose between TPU VM and TPU node. Always prefer TPU VM because it is the new architecture in which TPU devices are connected to the host VM directly. This will make it easier to set up the TPU device.

7.1.3. Run Jupyter Notebook on TPU VM

After setting up Remote-SSH, you can work with Jupyter notebook files in VSCode.

Alternatively, you can run a regular Jupyter Notebook server on the TPU VM, forward the port to your PC and connect to it. However, you should prefer VSCode because it is more powerful, offers better integration with other tools and is easier to set up.

7.1.4. Share files across multiple TPU VM instances

TPU VM instances in the same zone are connected with internal IPs, so you can create a shared file system using NFS.

7.1.5. Monitor TPU usage

7.1.6. Start a server on TPU VM

Example: TensorBoard

Although every TPU VM is allocated a public IP, in most cases you should not expose a server to the Internet because it is insecure.

Port forwarding via SSH

ssh -C -N -L 127.0.0.1:6006:127.0.0.1:6006 tpu1

7.2. About JAX

7.2.1. Import convention

You may see two different kinds of import conventions. One is to import jax.numpy as np and import the original numpy as onp. The other is to import jax.numpy as jnp and leave the original numpy as np.

On 16 Jan 2019, Colin Raffel wrote in a blog article that the convention at that time was to import original numpy as onp.

On 5 Nov 2020, Niru Maheswaranathan said in a tweet that he thought the convention at that time was to import jax.numpy as jnp and to leave the original numpy as np.

TODO: Conclusion?

7.2.2. Manage random keys in JAX

The regular way is this:

key, *subkey = rand.split(key, num=4)
print(subkey[0])
print(subkey[1])
print(subkey[2])

7.2.3. Serialize model parameters

Normally, the model parameters are represented by a nested dictionary like this:

{
    "embedding": DeviceArray,
    "ff1": {
        "kernel": DeviceArray,
        "bias": DeviceArray
    },
    "ff2": {
        "kernel": DeviceArray,
        "bias": DeviceArray
    }
}

You can use flax.serialization.msgpack_serialize to serialize the parameters into bytes, and use flax.serialization.msgpack_restore to convert them back.
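
For example, assuming params is a nested dictionary like the one above (the file name is arbitrary):

import flax.serialization

# serialize the nested dictionary of parameters into bytes and save to disk
with open('params.msgpack', 'wb') as f:
    f.write(flax.serialization.msgpack_serialize(params))

# read the bytes back and restore the nested dictionary
with open('params.msgpack', 'rb') as f:
    params = flax.serialization.msgpack_restore(f.read())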

7.2.4. Conversion between NumPy arrays and JAX arrays

Use np.asarray and onp.asarray.

import jax.numpy as np
import numpy as onp

a = np.array([1, 2, 3])  # JAX array
b = onp.asarray(a)  # converted to NumPy array

c = onp.array([1, 2, 3])  # NumPy array
d = np.asarray(c)  # converted to JAX array

7.2.5. Type annotation

np.ndarray
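
For example, a function that takes and returns a JAX array can be annotated like this:

import jax.numpy as np

def increment(x: np.ndarray) -> np.ndarray:
    return x + 1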

7.2.6. Check if an array is either a NumPy array or a JAX array

isinstance(a, (np.ndarray, onp.ndarray))

7.2.7. Get the shapes of all parameters in a nested dictionary

jax.tree_map(lambda x: x.shape, params)

8. Confusing Syntax

8.1. What is a[:, None]?

np.newaxis
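
None here is the same as np.newaxis: it inserts a new axis of length 1. A small example:

import jax.numpy as np

a = np.arange(3)               # shape (3,)
print(a[:, None].shape)        # (3, 1)
print(a[:, np.newaxis].shape)  # (3, 1), the same thing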

8.2. How to understand np.einsum?
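
One way to read np.einsum: the letters name the axes of the inputs, a letter that does not appear after -> is summed over, and the letters after -> are the axes of the output. A small example:

import jax.numpy as np

a = np.ones((2, 3))
b = np.ones((3, 4))

# 'ij,jk->ik': sum over the shared index j, keep i and k
c = np.einsum('ij,jk->ik', a, b)  # equivalent to a @ b
print(c.shape)  # (2, 4)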

9. Common Gotchas

9.1. About TPU

9.1.1. External IP of TPU machine changes occasionally

As of 17 Feb 2022, the external IP address may change if there is a maintenance event. If this happens, you need to reconnect with the new IP address.

9.1.2. One TPU device can only be used by one process at a time

Unlike GPUs, you will get an error if you run two processes on a TPU at the same time:

I0000 00:00:1648534265.148743  625905 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.

Even if a TPU device has 8 cores and one process utilizes only the first core, other processes will not be able to use the remaining cores.

9.1.3. TCMalloc breaks several programs

TCMalloc is Google's customized memory allocation library. On TPU VM, LD_PRELOAD is set to use TCMalloc by default:

$ echo $LD_PRELOAD
/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4

However, using TCMalloc in this manner may break several programs like gsutil:

$ gsutil --help
/snap/google-cloud-sdk/232/platform/bundledpythonunix/bin/python3: /snap/google-cloud-sdk/232/platform/bundledpythonunix/bin/../../../lib/x86_64-linux-gnu/libm.so.6: version `GLIBC_2.29' not found (required by /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4)

The homepage of TCMalloc also indicates that LD_PRELOAD is tricky and this mode of usage is not recommended.

If you encounter problems related to TCMalloc, you can disable it in the current shell using the command:

unset LD_PRELOAD

9.1.4. There is no TPU counterpart of nvidia-smi

See google/jax#9756.

9.2. About JAX

9.2.1. Indexing an array with an array

import jax.numpy as np
import numpy as onp

a = onp.arange(12).reshape((6, 2))
b = onp.arange(6).reshape((2, 3))

a_ = np.asarray(a)
b_ = np.asarray(b)

a[b]  # success
a_[b_]  # success
a_[b]  # success
a[b_]  # error: index 3 is out of bounds for axis 1 with size 2

Generally speaking, JAX supports NumPy arrays, but NumPy does not support JAX arrays.

9.2.2. np.dot and torch.dot are different

import numpy as onp
import torch

a = onp.random.rand(3, 4, 5)
b = onp.random.rand(4, 5, 6)
onp.dot(a, b)  # success

a_ = torch.from_numpy(a)
b_ = torch.from_numpy(b)
torch.dot(a_, b_)  # error: 1D tensors expected, but got 3D and 3D tensors

9.2.3. np.std and torch.std are different

import torch

x = torch.tensor([[-1., 1.]])

print(x.std(-1).numpy())  # [1.4142135]
print(x.numpy().std(-1))  # [1.]

This is because in np.std the denominator is n, while in torch.std it is n-1. See pytorch/pytorch#1854 for details.
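
If you want the two to agree, choose the denominator explicitly, with unbiased in torch.std or ddof in np.std:

import torch

x = torch.tensor([[-1., 1.]])

print(x.std(-1, unbiased=False).numpy())  # [1.], denominator n
print(x.numpy().std(-1, ddof=1))          # [1.4142135], denominator n-1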

9.2.4. Computations on TPU are in low precision by default

JAX uses bfloat16 for matrix multiplication on TPU by default, even if the data type is float32.

import jax.numpy as np

print(4176 * 5996)  # 25039296

a = np.array(0.4176, dtype=np.float32)
b = np.array(0.5996, dtype=np.float32)
print((a * b).item())  # 0.25039297342300415

To do matrix multiplication in float32, you need to add this line at the top of the script:

jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)

Other precision values can be found in jax.lax.Precision. See google/jax#9973 for details.

9.2.5. Weight matrix of linear layer is transposed in PyTorch

The weight matrix of a linear layer is transposed in PyTorch, but not in Flax. Therefore, if you want to convert model parameters between PyTorch and Flax, you need to transpose the weight matrices.

In Flax:

import flax.linen as nn
import jax.numpy as np
import jax.random as rand
linear = nn.Dense(5)
key = rand.PRNGKey(42)
params = linear.init(key, np.zeros((3,)))
print(params['params']['kernel'].shape)  # (3, 5)

In PyTorch:

import torch.nn as nn
linear = nn.Linear(3, 5)
print(linear.weight.shape)  # (5, 3), not (3, 5)
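
Therefore, a conversion from the PyTorch weight above to a Flax kernel is just a transpose. A small sketch continuing the PyTorch snippet:

kernel = linear.weight.detach().numpy().T  # (5, 3) -> (3, 5)
print(kernel.shape)  # (3, 5)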
