google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Memory issue when randomly initializing large parameters, sharding cannot help

hr0nix opened this issue · comments

Description

This is a duplicate of this issue, but I think it might be flax-related, so also posting here.

Consider the following code snippet:

import jax
import flax.linen as nn
from jax.sharding import Mesh
import functools


class Model(nn.Module):
    output_dim: int = 32768 * 8

    @nn.compact
    def __call__(self, inputs):
        # Unsharded dense layer: f32 kernel of shape (32768, 262144), i.e. 32 GiB.
        block = nn.Dense(features=self.output_dim, use_bias=False)
        return block(inputs)


class ShardedModel(nn.Module):
    output_dim: int = 32768 * 8

    @nn.compact
    def __call__(self, inputs):
        init_fn = nn.initializers.lecun_normal()
        # init_fn = nn.initializers.zeros  # all-zeros init sidesteps the RNG entirely
        block = nn.Dense(
            features=self.output_dim,
            use_bias=False,
            kernel_init=nn.with_logical_partitioning(
                init_fn, ("logical_axis", "unmodelled")
            ),
        )
        return block(inputs)


def test_model(model: nn.Module):
    key = jax.random.PRNGKey(0)
    input_shape = (1, 32768)
    inputs = jax.random.normal(key, input_shape)

    # 1D device mesh over all local devices.
    devices = jax.devices()
    mesh = Mesh(devices, ("mesh_axis",))
    print(f"Device mesh: {mesh}")
    sharding_rules = [
        ("logical_axis", "mesh_axis"),
    ]

    # Compute parameter shapes without materializing them, then derive the
    # intended sharding from the logical partitioning annotations.
    abstract_params = jax.eval_shape(model.init, key, inputs)
    params_partition_spec = nn.get_partition_spec(abstract_params)
    params_sharding = nn.logical_to_mesh_sharding(
        params_partition_spec,
        mesh,
        rules=sharding_rules,
    )
    print(f"Intended sharding: {params_sharding}")

    # Jit the initializer with out_shardings so parameters are created already sharded.
    init_fn = functools.partial(model.init, key, inputs)
    init_fn = jax.jit(
        init_fn,
        out_shardings=params_sharding,
    )
    params = init_fn()

    actual_sharding = jax.tree_util.tree_map(lambda leaf: leaf.sharding, params)
    print(f"Actual sharding: {actual_sharding}")


def main():
    test_model(Model())
    # test_model(ShardedModel())  # swap with the line above for the sharded variant


if __name__ == "__main__":
    main()

Here I'm trying to initialize a very large dense layer. Despite the layer weights requiring only 32 GiB of memory (and I'm running on an 80 GB H100), this code fails because JAX tries to simultaneously allocate quite a few buffers of random bits for the RNG, bringing the total memory consumption to 112 GiB!

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 85899347204 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:         0B
              constant allocation:         8B
        maybe_live_out allocation:   32.00GiB
     preallocated temp allocation:   80.00GiB
  preallocated temp fragmentation:       124B (0.00%)
                 total allocation:  112.00GiB
              total fragmentation:   16.00GiB (14.29%)
Peak buffers:
        Buffer 1:
                Size: 32.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/Model/Dense_0/mul" source_file="/usr/local/lib/python3.10/dist-packages/flax/core/scope.py" source_line=968
                XLA Label: fusion
                Shape: f32[32768,262144]
                ==========================

        Buffer 2:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/Model/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/core/scope.py" source_line=968
                XLA Label: custom-call
                Shape: u32[2,2147483648]
                ==========================

        Buffer 3:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/Model/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/core/scope.py" source_line=968
                XLA Label: custom-call
                Shape: u32[2,2147483648]
                ==========================

        Buffer 4:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/Model/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/core/scope.py" source_line=968
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================

        Buffer 5:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/Model/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/core/scope.py" source_line=968
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================

        Buffer 6:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/Model/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/core/scope.py" source_line=968
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================
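
For reference, those numbers add up exactly. A quick sanity check, using plain arithmetic on the f32 kernel and u32 threefry buffers reported above:

GIB = 1024 ** 3

kernel_bytes = 32768 * 262144 * 4      # f32[32768, 262144] kernel
rng_buffer_bytes = 2 * 2147483648 * 4  # one u32[2, 2147483648] threefry buffer
total_bytes = kernel_bytes + 5 * rng_buffer_bytes

print(kernel_bytes / GIB, rng_buffer_bytes / GIB, total_bytes / GIB)
# -> 32.0 16.0 112.0, matching the BufferAssignment stats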

Is this intended? Do we really need to store these buffers in memory simultaneously to initialize the layer?

In any case, we can try to fix the problem by sharding the layer over the available devices (8x 80 GB H100s; in main() above, comment out test_model(Model()) and uncomment test_model(ShardedModel())). Interestingly, while this change reduces the size of the parameter tensor as intended, the RNG buffers are still allocated in full!

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 103079216656 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:         0B
              constant allocation:        40B
        maybe_live_out allocation:    4.00GiB
     preallocated temp allocation:   96.00GiB
  preallocated temp fragmentation:       124B (0.00%)
                 total allocation:  100.00GiB
              total fragmentation:    4.00GiB (4.00%)
Peak buffers:
        Buffer 1:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: custom-call
                Shape: u32[2,2147483648]
                ==========================

        Buffer 2:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: custom-call
                Shape: u32[2,2147483648]
                ==========================

        Buffer 3:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================

        Buffer 4:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================

        Buffer 5:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================

        Buffer 6:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================

        Buffer 7:
                Size: 4.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/mul" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: fusion
                Shape: f32[4096,262144]
                ==========================

This seems to be a bug: why is JAX trying to materialize the full RNG tensor on each shard if it's not needed in full there?
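
One way to see what the compiler actually produced (a rough sketch only; it assumes we're inside test_model above, after init_fn has been wrapped with jax.jit and out_shardings) is to dump the optimized HLO and grep for the threefry ops:

compiled = init_fn.lower().compile()
hlo_text = compiled.as_text()
for line in hlo_text.splitlines():
    if "threefry" in line:
        print(line)  # the u32[2,2147483648] shapes show the RNG is not partitioned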

Finally, if I use all-zeros initialization (uncomment the init_fn = nn.initializers.zeros line in ShardedModel above), the issue goes away.

So, to summarize, I have the following questions:

  • Is it expected that JAX simultaneously allocates so many RNG buffers for weight initialization?
  • Why does the RNG buffer allocation not respect sharding?
  • Are there any workarounds I can use to achieve what I need without resorting to all-zeros initialization? (One host-side sketch I've considered follows this list.)
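
The only workaround sketch I have so far is a rough one: it trades device memory for host memory, and only approximates lecun_normal (which is really a truncated normal). The idea is to build the kernel on the host and place it with the intended sharding via jax.device_put:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())
mesh = Mesh(devices, ("mesh_axis",))
# Shard the kernel's first axis across the mesh, matching the logical annotation above.
kernel_sharding = NamedSharding(mesh, PartitionSpec("mesh_axis", None))

fan_in, fan_out = 32768, 32768 * 8
rng = np.random.default_rng(0)
scale = np.float32(1.0 / np.sqrt(fan_in))  # roughly lecun_normal's stddev
kernel_host = rng.standard_normal((fan_in, fan_out), dtype=np.float32) * scale
kernel = jax.device_put(kernel_host, kernel_sharding)  # each device holds a (4096, 262144) slice

Wiring the result back into the flax params dict by hand is clunky, though, which is why I'd prefer a proper fix.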

The example above, while artificial, is inspired by a real problem that we've encountered while trying to initialize a large model.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.20
jaxlib: 0.4.20
numpy:  1.24.3
flax:  0.8.1
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1

$ nvidia-smi
Tue Feb 20 16:53:25 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:8D:00.0 Off |                    0 |
| N/A   34C    P0             117W / 700W |    539MiB / 81559MiB |      1%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:91:00.0 Off |                    0 |
| N/A   30C    P0             113W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:95:00.0 Off |                    0 |
| N/A   33C    P0             112W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:99:00.0 Off |                    0 |
| N/A   30C    P0             115W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:AB:00.0 Off |                    0 |
| N/A   34C    P0             119W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:AF:00.0 Off |                    0 |
| N/A   30C    P0             113W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:B3:00.0 Off |                    0 |
| N/A   33C    P0             112W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:B7:00.0 Off |                    0 |
| N/A   30C    P0             114W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

Closing this issue since the solution in the duplicate issue seems to work.
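
The fix itself isn't quoted in this thread; presumably it's JAX's partitionable threefry RNG, which lets the random bits be generated shard-locally instead of materializing the full counter buffer. A minimal sketch under that assumption:

import jax

# Assumption: enabling the partitionable threefry implementation before initializing
# lets XLA shard the RNG computation along with the kernel.
jax.config.update("jax_threefry_partitionable", True)

# then, as before:
# test_model(ShardedModel())

With the flag enabled, the out_shardings constraint on model.init should be enough for XLA to avoid allocating the full u32[2,2147483648] buffers on every device.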