google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home page: http://jax.readthedocs.io/


Numerical differences between shardings in random algorithm

shawnwang18 opened this issue

Description

We are seeing numerical differences between shardings in random number initialization on GPUs. For example, with a mesh of DP, FSDP, and TP axes, the numerical output of our initialization changes drastically depending on how many devices we allocate to each axis. As a result, we see divergence in the network when using TP.

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.27.dev20240514
jaxlib: 0.4.27.dev20240420
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='ipp1-2023.nvidia.com', release='5.15.0-88-generic', version='#98-Ubuntu SMP Mon Oct 2 15:18:56 UTC 2023', machine='x86_64')


$ nvidia-smi
Tue May 14 23:37:29 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A30                     On  |   00000000:01:00.0 Off |                    0 |
| N/A   27C    P0             31W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A30                     On  |   00000000:23:00.0 Off |                    0 |
| N/A   27C    P0             32W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A30                     On  |   00000000:41:00.0 Off |                    0 |
| N/A   28C    P0             33W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A30                     On  |   00000000:61:00.0 Off |                    0 |
| N/A   26C    P0             32W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA A30                     On  |   00000000:81:00.0 Off |                    0 |
| N/A   27C    P0             34W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA A30                     On  |   00000000:A1:00.0 Off |                    0 |
| N/A   28C    P0             33W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA A30                     On  |   00000000:C1:00.0 Off |                    0 |
| N/A   28C    P0             32W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA A30                     On  |   00000000:E1:00.0 Off |                    0 |
| N/A   28C    P0             33W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A        41      C   python                                          0MiB |
|    1   N/A  N/A        41      C   python                                          0MiB |
|    2   N/A  N/A        41      C   python                                          0MiB |
|    3   N/A  N/A        41      C   python                                          0MiB |
|    4   N/A  N/A        41      C   python                                          0MiB |
|    5   N/A  N/A        41      C   python                                          0MiB |
|    6   N/A  N/A        41      C   python                                          0MiB |
|    7   N/A  N/A        41      C   python                                          0MiB |
+-----------------------------------------------------------------------------------------+
```

The reproducer is below; it must be run on a node with 8 GPUs.

```python
import jax.numpy as jnp
import jax
from jax.experimental import mesh_utils as jax_mesh_utils
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh

MESH_DATA_AXIS = 'data'
MESH_TENSOR_AXIS = 'tensor'
MESH_FSDP_AXIS = 'pipeline'  # the FSDP axis is named "pipeline" in these meshes


# Create an FSDP mesh.
ici_mesh = (2, 4, 1)  # DP, FSDP, TP
dcn_mesh = (1, 1, 1)  # DP, FSDP, TP
devices = jax_mesh_utils.create_hybrid_device_mesh(ici_mesh, dcn_mesh)
fsdp_mesh = Mesh(devices, (MESH_DATA_AXIS, MESH_FSDP_AXIS, MESH_TENSOR_AXIS))
print(fsdp_mesh.shape)  # axis sizes: data=2, pipeline=4, tensor=1

# Create an FSDP, TP mesh.
ici_mesh = (1, 4, 2)  # DP, FSDP, TP
dcn_mesh = (1, 1, 1)  # DP, FSDP, TP
devices = jax_mesh_utils.create_hybrid_device_mesh(ici_mesh, dcn_mesh)
fsdp_tp_mesh = Mesh(devices, (MESH_DATA_AXIS, MESH_FSDP_AXIS, MESH_TENSOR_AXIS))
print(fsdp_tp_mesh.shape)  # axis sizes: data=1, pipeline=4, tensor=2

# Create an FSDP, TP, DP mesh.
ici_mesh = (2, 2, 2)  # DP, FSDP, TP
dcn_mesh = (1, 1, 1)  # DP, FSDP, TP
devices = jax_mesh_utils.create_hybrid_device_mesh(ici_mesh, dcn_mesh)
fsdp_tp_dp_mesh = Mesh(devices, (MESH_DATA_AXIS, MESH_FSDP_AXIS, MESH_TENSOR_AXIS))
print(fsdp_tp_dp_mesh.shape)  # axis sizes: data=2, pipeline=2, tensor=2

# Generate the data.
batch_size = 32
seq_len = 8192
n_heads = 32
head_dim = 128
emb_dim = 4096
DATA_SUBMESH = (MESH_DATA_AXIS, MESH_FSDP_AXIS)

def gen_data_fn():
    key = jax.random.PRNGKey(43)
    scale = 0.05
    # The same key is reused for both arrays, as in the original report;
    # real code would normally split it with jax.random.split.
    activations = scale * jax.random.normal(key, shape=(batch_size, seq_len, emb_dim), dtype=jnp.bfloat16)
    weights = scale * jax.random.normal(key, shape=(emb_dim, n_heads, head_dim), dtype=jnp.bfloat16)
    return activations, weights

# Activations: batch sharded over (data, FSDP), embedding over TP.
# Weights: embedding sharded over FSDP, heads over TP.
data_fn = pjit(
    gen_data_fn,
    out_shardings=(P(DATA_SUBMESH, None, MESH_TENSOR_AXIS), P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None)),
)

# Run the same generator under each mesh.
with fsdp_mesh:
    act1, weights1 = data_fn()

with fsdp_tp_mesh:
    act2, weights2 = data_fn()

with fsdp_tp_dp_mesh:
    act3, weights3 = data_fn()

# Summarize the elementwise differences between two arrays.
def get_diffs(x, y):
    # Upcast to float32 so the differences are not themselves rounded to bfloat16.
    abs_diff = jnp.abs(x.astype(jnp.float32) - y.astype(jnp.float32))
    max_difference = jnp.round(jnp.max(abs_diff), 5)
    min_difference = jnp.round(jnp.min(abs_diff), 5)
    avg_difference = jnp.round(jnp.mean(abs_diff), 5)
    return max_difference, min_difference, avg_difference

max_diff, min_diff, avg_diff = jax.jit(get_diffs)(act1, act2)
print(f"Differences b/w FSDP and FSDP,TP: Max -- {max_diff}, Min -- {min_diff}, Average -- {avg_diff}")

max_diff, min_diff, avg_diff = jax.jit(get_diffs)(act1, act3)
print(f"Differences b/w FSDP and FSDP,TP,DP: Max -- {max_diff}, Min -- {min_diff}, Average -- {avg_diff}")
```
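The reproducer reports how large the differences are, but not how many elements differ or where they start. A minimal sketch of a stricter, element-exact check follows; `compare_exact` is a hypothetical helper, not part of the original report. It gathers both arrays to the host, upcasts them to float32 (exact for bfloat16), and returns the number of mismatched elements plus the index of the first mismatch.

```python
import numpy as np
import jax.numpy as jnp


def compare_exact(x, y):
    """Return (number of mismatched elements, index of first mismatch or None).

    Hypothetical helper. Values are upcast to float32 on the host, which is
    exact for bfloat16 inputs. NaN never compares equal, so NaNs count as
    mismatches.
    """
    # np.asarray gathers every shard of a (possibly sharded) jax.Array to host memory.
    x_h = np.asarray(x, dtype=np.float32)
    y_h = np.asarray(y, dtype=np.float32)
    if x_h.shape != y_h.shape:
        raise ValueError(f"shape mismatch: {x_h.shape} vs {y_h.shape}")
    mismatch = x_h != y_h
    count = int(np.count_nonzero(mismatch))
    first = tuple(int(i) for i in np.argwhere(mismatch)[0]) if count else None
    return count, first


# Usage on small, unsharded bfloat16 arrays.
x = jnp.arange(6, dtype=jnp.bfloat16).reshape(2, 3)
y = x.at[1, 2].set(0)
print(compare_exact(x, x))  # (0, None)
print(compare_exact(x, y))  # (1, (1, 2))
```

Applied to the reproducer's outputs, e.g. `compare_exact(act1, act2)`, it would show whether the two shardings disagree on a few elements or on the whole array, and where the results first diverge.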

This is fixed by upgrading to partitionable threefry, e.g. by adding the following line to the top of the file (after imports):

```python
jax.config.update('jax_threefry_partitionable', True)
```
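A minimal sketch for checking the flag's effect without an 8-GPU node, under stated assumptions. It emulates 8 devices on the host CPU via the `--xla_force_host_platform_device_count` XLA flag. It then enables `jax_threefry_partitionable` and draws the same `jax.random.normal` sample under the three DP/FSDP/TP layouts from the reproducer, verifying that the values are bitwise identical. The axis names, the `sample` helper, and the array shape are illustrative choices, not taken from the report. Because the original problem was observed on GPUs, a passing CPU run does not by itself confirm the GPU behavior.

```python
import os

# Assumption: emulate 8 devices on the host CPU; this must be set before jax is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# The fix suggested in this thread.
jax.config.update("jax_threefry_partitionable", True)

AXES = ("data", "fsdp", "tensor")           # illustrative axis names
SPEC = P(("data", "fsdp"), None, "tensor")  # same layout as the reproducer's activations


def sample(mesh_shape, shape=(64, 32, 16)):
    """Draw jax.random.normal with its output sharded over a mesh of `mesh_shape`."""
    devices = np.array(jax.devices("cpu")[:8]).reshape(mesh_shape)
    sharding = NamedSharding(Mesh(devices, AXES), SPEC)
    draw = jax.jit(
        lambda key: jax.random.normal(key, shape, dtype=jnp.float32),
        out_shardings=sharding,
    )
    return draw(jax.random.PRNGKey(43))


# The three DP/FSDP/TP layouts from the reproducer.
a = sample((2, 4, 1))
b = sample((1, 4, 2))
c = sample((2, 2, 2))

# np.asarray gathers all shards to host memory for an exact comparison.
print("(2,4,1) vs (1,4,2):", np.array_equal(np.asarray(a), np.asarray(b)))
print("(2,4,1) vs (2,2,2):", np.array_equal(np.asarray(a), np.asarray(c)))
```

Enabling the flag generally changes the sampled values relative to the non-partitionable implementation, so the check compares shardings against each other rather than against values produced before the switch.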

See #18480 for more on the upgrade (which was delayed a bit, but is still planned).

IIUC this is a bug (unintended behavior) even with `jax_threefry_partitionable=False`, and we don't yet know what's causing it. Good to know that setting `jax_threefry_partitionable=True` fixes it, though!

Yes, I consider it a bug as well, but still undiagnosed.