google / orbax

Orbax provides common utility libraries for JAX users.

Home Page: https://orbax.readthedocs.io/

How to restore a variable from a checkpoint saved on CPU back onto CPU when you have both GPU and CPU?

PriyeshV opened this issue

I get the following error when I try to restore:

ValueError: SingleDeviceSharding with Device=TFRT_CPU_0 was not found in jax.local_devices().

Even though I enclose the statements in a CPU device scope, as shown below, the only visible devices are CUDA devices, not the CPU.

        with jax.default_device(jax.devices('cpu')[0]):
            print(jax.devices())
            print(jax.local_devices())
            variables = Mngr.restore(start_step)
 

Could you give me some pointers on how to handle this?

PS: This is for a DQN code where I'm trying to save the replay buffer (Flashbax) from the CPU and the network parameters from the GPU. Saving the buffer and parameters works, but restoring them has been an issue.

Thank You

You might have saved an array, inside the CPU scope, with a sharding that is incompatible with your current device setup. You'll need to specify restore_args to communicate the sharding that you want for each array in the tree.

Hi,

I don't have any sharding settings specified for the variable.
Below is the entirety of the code.
PS: I'm running on a machine with a GPU

import jax
import jax.numpy as jnp
import chex
import flashbax as fbx
import orbax.checkpoint as ocp


@chex.dataclass(frozen=True)
class TimeStep:
    observation: chex.Array
    action: chex.Array
    reward: chex.Array
    done: chex.Array


with jax.default_device(jax.devices('cpu')[0]):
    rb = fbx.make_flat_buffer(max_length=10000, min_length=1000,
                              sample_batch_size=512, add_sequences=False, add_batch_size=None)
    rb = rb.replace(init=jax.jit(rb.init), add=jax.jit(rb.add, donate_argnums=0), sample=jax.jit(rb.sample),
                    can_sample=jax.jit(rb.can_sample))
    dummy_timestep = TimeStep(observation=jnp.ones((84, 84, 4), dtype=jnp.uint8), action=jnp.int32(0),
                              reward=jnp.float32(0.0), done=jnp.bool_(True))
    rb_state = rb.init(dummy_timestep)

mngr_options = ocp.CheckpointManagerOptions(max_to_keep=1, save_interval_steps=1)
Mngr = ocp.CheckpointManager('/home/mila/v/vijayanp/Test', {'rb_state': ocp.PyTreeCheckpointer()}, mngr_options)

Mngr.save(0, {'rb_state': rb_state})
Mngr.wait_until_finished()

rb_variables = Mngr.restore(Mngr.latest_step())


Here's an example if you're trying to restore on device:

def make_restore_arg(arr):
  return ocp.ArrayRestoreArgs(sharding=...)

restore_args = jax.tree_util.tree_map(make_restore_arg, rb_state)
Mngr.restore(Mngr.latest_step(), restore_kwargs={'rb_state': {'restore_args': restore_args}})

I'm unclear on what you want to do exactly. Maybe you want to restore in CPU memory (as numpy arrays).

import numpy as np

def make_restore_arg(arr):
  return ocp.RestoreArgs(restore_type=np.ndarray)

restore_args = jax.tree_util.tree_map(make_restore_arg, rb_state)
Mngr.restore(Mngr.latest_step(), restore_kwargs={'rb_state': {'restore_args': restore_args}})
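
If you then need jax arrays on the CPU again (for example, to keep using the Flashbax buffer), here's a minimal follow-up sketch. It reuses restore_args from above, assumes the manager returns a dict keyed by item name as in the save call, and introduces cpu and rb_state_cpu purely for illustration:

restored = Mngr.restore(Mngr.latest_step(), restore_kwargs={'rb_state': {'restore_args': restore_args}})

# Hypothetical follow-up: place the restored numpy arrays on the CPU device so the
# buffer state consists of jax arrays again.
cpu = jax.devices('cpu')[0]
rb_state_cpu = jax.tree_util.tree_map(lambda x: jax.device_put(x, cpu), restored['rb_state'])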

I'm sorry for being unclear. I'll try to explain again, if you don't mind.

Objective: Load a variable (originally in CPU) into CPU memory.
Issue: When I call restore, it tries to load the CPU-resident arrays but fails, throwing the following error:
ValueError: SingleDeviceSharding with Device=TFRT_CPU_0 was not found in jax.local_devices().

My understanding:

  • Orbax seems to look for the CPU device in jax.local_devices() rather than jax.devices() when restoring, but no CPU device is listed there (see the quick check below).
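
A quick check (a sketch; the comments describe what this machine reports, consistent with the prints above) shows the CPU devices are still reachable through jax.devices('cpu') even though the default backend is the GPU:

print(jax.local_devices())   # only the CUDA device(s) show up here
print(jax.devices('cpu'))    # the CPU device(s) are still reachable this way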

I think you do want to do this then:

import numpy as np

def make_restore_arg(arr):
  return ocp.RestoreArgs(restore_type=np.ndarray)

restore_args = jax.tree_util.tree_map(make_restore_arg, rb_state)
Mngr.restore(Mngr.latest_step(), restore_kwargs={'rb_state': {'restore_args': restore_args}})

I think what's happening is that with jax.default_device(jax.devices('cpu')[0]) creates arrays with SingleDeviceSharding(device=TFRT_CPU_0), and that sharding gets recorded in the checkpoint's sharding metadata. When you don't provide restore_args with the sharding property specified, Orbax tries to use the sharding metadata to restore the arrays. But because you're not entering the with block again at restore time, it isn't able to reconstruct the original sharding recorded in the metadata.

To restore, you would need to either:

  1. Include the with block around the restore call (not completely sure whether this would work or not).
  2. Provide restore_args specifying the same sharding that was used at save time (see the sketch after this list).
  3. Provide restore_type=np.ndarray to restore as numpy arrays in host memory.
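
A minimal sketch of option 2, assuming everything should land back on the first CPU device; cpu_sharding is a name introduced here just for illustration, and whether restoring directly onto the CPU backend plays well with a GPU default backend is worth verifying:

# Build a sharding that pins every restored array to the first CPU device,
# mirroring what jax.default_device produced at save time.
cpu_sharding = jax.sharding.SingleDeviceSharding(jax.devices('cpu')[0])

def make_restore_arg(arr):
  return ocp.ArrayRestoreArgs(sharding=cpu_sharding)

restore_args = jax.tree_util.tree_map(make_restore_arg, rb_state)
Mngr.restore(Mngr.latest_step(), restore_kwargs={'rb_state': {'restore_args': restore_args}})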