Numerical differences between shardings in random algorithm
shawnwang18 opened this issue · comments
Description
We are seeing numerical differences between shardings during random-number initialization on GPUs. For example, with a mesh of DP, FSDP, and TP axes, the numerical output of the initialization changes drastically depending on how many devices are allocated to each axis. As a result, when using TP we see divergences in the network.
System info (python version, jaxlib version, accelerator, etc.)
```
jax: 0.4.27.dev20240514
jaxlib: 0.4.27.dev20240420
numpy: 1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='ipp1-2023.nvidia.com', release='5.15.0-88-generic', version='#98-Ubuntu SMP Mon Oct 2 15:18:56 UTC 2023', machine='x86_64')
$ nvidia-smi
Tue May 14 23:37:29 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14 Driver Version: 550.54.14 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A30 On | 00000000:01:00.0 Off | 0 |
| N/A 27C P0 31W / 165W | 234MiB / 24576MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA A30 On | 00000000:23:00.0 Off | 0 |
| N/A 27C P0 32W / 165W | 234MiB / 24576MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA A30 On | 00000000:41:00.0 Off | 0 |
| N/A 28C P0 33W / 165W | 234MiB / 24576MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA A30 On | 00000000:61:00.0 Off | 0 |
| N/A 26C P0 32W / 165W | 234MiB / 24576MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA A30 On | 00000000:81:00.0 Off | 0 |
| N/A 27C P0 34W / 165W | 234MiB / 24576MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA A30 On | 00000000:A1:00.0 Off | 0 |
| N/A 28C P0 33W / 165W | 234MiB / 24576MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 6 NVIDIA A30 On | 00000000:C1:00.0 Off | 0 |
| N/A 28C P0 32W / 165W | 234MiB / 24576MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 7 NVIDIA A30 On | 00000000:E1:00.0 Off | 0 |
| N/A 28C P0 33W / 165W | 234MiB / 24576MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 41 C python 0MiB |
| 1 N/A N/A 41 C python 0MiB |
| 2 N/A N/A 41 C python 0MiB |
| 3 N/A N/A 41 C python 0MiB |
| 4 N/A N/A 41 C python 0MiB |
| 5 N/A N/A 41 C python 0MiB |
| 6 N/A N/A 41 C python 0MiB |
| 7 N/A N/A 41 C python 0MiB |
+-----------------------------------------------------------------------------------------+
```
The reproducing unit test is below; it must be run on a node with 8 GPUs:
```python
import jax.numpy as jnp
import jax
from jax.experimental import mesh_utils as jax_mesh_utils
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
MESH_DATA_AXIS = 'data'
MESH_TENSOR_AXIS = 'tensor'
MESH_FSDP_AXIS = 'pipeline'  # note: named 'pipeline' but used as the FSDP axis
# create an FSDP mesh
ici_mesh = (2, 4, 1) # DP, FSDP, TP
dcn_mesh = (1, 1, 1) # DP, FSDP, TP
devices = jax_mesh_utils.create_hybrid_device_mesh(ici_mesh, dcn_mesh)
fsdp_mesh = Mesh(devices, (MESH_DATA_AXIS, MESH_FSDP_AXIS, MESH_TENSOR_AXIS))
print(fsdp_mesh.shape)  # axis sizes: (2, 4, 1)
# create an FSDP, TP mesh
ici_mesh = (1, 4, 2) # DP, FSDP, TP
dcn_mesh = (1, 1, 1) # DP, FSDP, TP
devices = jax_mesh_utils.create_hybrid_device_mesh(ici_mesh, dcn_mesh)
fsdp_tp_mesh = Mesh(devices, (MESH_DATA_AXIS, MESH_FSDP_AXIS, MESH_TENSOR_AXIS))
print(fsdp_tp_mesh.shape)  # axis sizes: (1, 4, 2)
# create an FSDP, TP, DP mesh
ici_mesh = (2, 2, 2) # DP, FSDP, TP
dcn_mesh = (1, 1, 1) # DP, FSDP, TP
devices = jax_mesh_utils.create_hybrid_device_mesh(ici_mesh, dcn_mesh)
fsdp_tp_dp_mesh = Mesh(devices, (MESH_DATA_AXIS, MESH_FSDP_AXIS, MESH_TENSOR_AXIS))
print(fsdp_tp_dp_mesh.shape)  # axis sizes: (2, 2, 2)
# generate the data
batch_size = 32
seq_len = 8192
n_heads = 32
head_dim = 128
emb_dim = 4096
DATA_SUBMESH = (MESH_DATA_AXIS, MESH_FSDP_AXIS)
def gen_data_fn():
    key = jax.random.PRNGKey(43)
    scale = 0.05
    activations = scale * jax.random.normal(key, shape=(batch_size, seq_len, emb_dim), dtype=jnp.bfloat16)
    weights = scale * jax.random.normal(key, shape=(emb_dim, n_heads, head_dim), dtype=jnp.bfloat16)
    return activations, weights
data_fn = pjit(
    gen_data_fn,
    out_shardings=(P(DATA_SUBMESH, None, MESH_TENSOR_AXIS), P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None)),
)
# fsdp outputs
with fsdp_mesh:
    act1, weights1 = data_fn()
with fsdp_tp_mesh:
    act2, weights2 = data_fn()
with fsdp_tp_dp_mesh:
    act3, weights3 = data_fn()
# diff b/w fsdp and fsdp,tp
def get_diffs(x, y):
    abs_diff = jnp.abs(x - y)
    max_difference = round(jnp.max(abs_diff), 5)
    min_difference = round(jnp.min(abs_diff), 5)
    avg_difference = round(jnp.mean(abs_diff), 5)
    return max_difference, min_difference, avg_difference
max_diff, min_diff, avg_diff = jax.jit(get_diffs)(act1, act2)
print(f"Differences b/w FSDP and FSDP,TP: Max -- {max_diff}, Min -- {min_diff}, Average -- {avg_diff}")
max_diff, min_diff, avg_diff = jax.jit(get_diffs)(act1, act3)
print(f"Differences b/w FSDP and FSDP,TP,DP: Max -- {max_diff}, Min -- {min_diff}, Average -- {avg_diff}")
```
This is fixed by upgrading to the partitionable threefry PRNG, e.g. by adding the following line to the top of the file (after the imports):

```python
jax.config.update('jax_threefry_partitionable', True)
```
See #18480 for more on the upgrade (which was delayed a bit, but is still planned).
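As a minimal single-device sketch of applying that flag (it does not demonstrate the cross-sharding invariance itself — the multi-GPU reproducer above covers that), the key point is that the flag must be set before any random generation happens:

```python
import jax
import jax.numpy as jnp

# Enable the partitionable threefry PRNG before any random generation.
# (Assumption: a JAX version recent enough to have this flag; it was
# later made the default behavior.)
jax.config.update('jax_threefry_partitionable', True)

key = jax.random.PRNGKey(43)
scale = 0.05

# Same pattern as the reproducer, with small shapes so it runs anywhere.
activations = scale * jax.random.normal(key, shape=(4, 16, 8), dtype=jnp.bfloat16)

# With the flag on, the values for a given key are a pure function of the
# element indices, so regenerating the array (under any sharding layout)
# matches bitwise; here we can only check the single-device case.
again = scale * jax.random.normal(key, shape=(4, 16, 8), dtype=jnp.bfloat16)
assert jnp.array_equal(activations, again)
```

With the flag enabled, each output element's random bits are derived independently from the key and the element's index, which is what makes the result independent of how the array is partitioned across devices.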
IIUC this is a bug (unintended behavior) even with jax_threefry_partitionable=False, and also we don't yet know what's causing this bug. Good to know that setting jax_threefry_partitionable=True fixes it though!
Yes, I consider it a bug as well, but still undiagnosed.