google / trax

Trax — Deep Learning with Clear Code and Speed

Error loading loop from a checkpoint

nwbvt opened this issue

Description

I'm getting an error while trying to load a training loop from a checkpoint. The full traceback is included under "Error logs" below.

Environment information

OS: Ubuntu

$ pip freeze | grep trax
trax==1.4.1

$ pip freeze | grep tensor
tensorboard==2.14.1
tensorboard-data-server==0.7.1
tensorflow==2.14.0
tensorflow-datasets==4.9.3
tensorflow-estimator==2.14.0
tensorflow-hub==0.15.0
tensorflow-io-gcs-filesystem==0.34.0
tensorflow-metadata==1.14.0
tensorflow-text==2.14.0

$ pip freeze | grep jax
jax==0.4.19
jaxlib==0.4.19+cuda12.cudnn89

$ python3 -V
Python 3.10.12

For bugs: reproduction and error logs

# Steps to reproduce:
Create a training loop, give it an output dir, and then train the model.
Then create a new training loop, give it that same output dir. It will try to load from a checkpoint but fail.
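Something like the following minimal sketch should trigger it (toy model and synthetic data; data_gen and the hyperparameters here are illustrative, not my actual setup):

import numpy as np
import trax
from trax import layers as tl

def data_gen(batch_size=8):
  # Synthetic (inputs, targets, loss weights) batches.
  while True:
    x = np.random.rand(batch_size, 4).astype(np.float32)
    y = x.sum(axis=1, keepdims=True)
    yield (x, y, np.ones_like(y))

model = tl.Serial(tl.Dense(16), tl.Relu(), tl.Dense(1))

train_task = trax.supervised.training.TrainTask(
    labeled_data=data_gen(),
    loss_layer=tl.L2Loss(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=10)

eval_task = trax.supervised.training.EvalTask(
    labeled_data=data_gen(),
    metrics=[tl.L2Loss()])

# First loop: trains and writes checkpoints into output_dir.
loop = trax.supervised.training.Loop(
    model, train_task, eval_tasks=[eval_task], output_dir='./ckpt_dir')
loop.run(20)

# Second loop pointed at the same output_dir: Loop.__init__ calls
# load_checkpoint(), which raises the TypeError below.
trax.supervised.training.Loop(
    model, train_task, eval_tasks=[eval_task], output_dir='./ckpt_dir')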
# Error logs:
/home/nick/.local/lib/python3.10/site-packages/trax/supervised/training.py:1388: SyntaxWarning: "is not" with a literal. Did you mean "!="?
  return [f for f in flat if f is not None and f is not ()]  # pylint: disable=literal-comparison
(the warning above is emitted five times)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[57], line 1
----> 1 training_loop = make_trainer(model, train_data, test_data,
      2                              trax.layers.Fn("ScaledLoss", ScaledLoss),
      3                              [trax.layers.Fn("ScaledLoss", ScaledLoss),
      4                               trax.layers.Fn("InRange", InRange), 
      5                               trax.layers.Fn("ClippedMAE", ClippedMAE)],
      6                              "model_dff=4096")

File ~/srrf/srrf.py:104, in make_trainer(model, train_data, test_data, loss, metrics, name, optimizer, schedule, steps_per_cp, eval_batch_size)
     94 train_task = trax.supervised.training.TrainTask(labeled_data=train_data,
     95                                                 loss_layer=loss,
     96                                                 optimizer=optimizer,
     97                                                 lr_schedule=schedule,
     98                                                 n_steps_per_checkpoint=500)
    100 eval_task = trax.supervised.training.EvalTask(labeled_data=test_data,
    101                                               metrics=metrics,
    102                                               n_eval_batches=eval_batch_size)
--> 104 return trax.supervised.training.Loop(model, train_task, eval_tasks=[eval_task], output_dir=name)

File ~/.local/lib/python3.10/site-packages/trax/supervised/training.py:294, in Loop.__init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    289       layer.weights, layer.state = tl.on_cpu(self._unreplicate(
    290           _make_weights_and_state_same_across_hosts(
    291               self._for_n_devices(weights_and_state))))
    293 # Load checkpoint if it exists.
--> 294 self.load_checkpoint()
    296 # Prepare eval components.
    297 self._eval_at = eval_at or default_at

File ~/.local/lib/python3.10/site-packages/trax/supervised/training.py:944, in Loop.load_checkpoint(self, directory, filename)
    940 for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']):
    941   matched_flat_slots = _match_by_shape(
    942       self._to_bits(_flatten_and_remove_empty(trainer.slots)),
    943       _flatten_and_remove_empty(slots))
--> 944   matched_slots, _ = fastmath.tree_unflatten(
    945       self._from_bits(matched_flat_slots),
    946       trainer.slots, copy_from_tree=[None, ()])
    947   trainer.slots = matched_slots
    948 self._step = d['step']

File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
    242 new_tree, rest = [], flat
    243 for t in tree:
--> 244   new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
    245   new_tree.append(new_t)
    246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree

File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
    242 new_tree, rest = [], flat
    243 for t in tree:
--> 244   new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
    245   new_tree.append(new_t)
    246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree

File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:239, in tree_unflatten(flat, tree, copy_from_tree)
    216 def tree_unflatten(flat, tree, copy_from_tree=None):
    217   """Unflatten a list into a tree given the tree shape as second argument.
    218 
    219   Args:
   (...)
    237     more were provided than the number of leaves of tree (useful for recursion).
    238   """
--> 239   if copy_from_tree is not None and tree in copy_from_tree:
    240     return tree, flat
    241   if isinstance(tree, (list, tuple)):

File ~/.local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:260, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
    257 # Note: don't use isinstance here, because we don't want to raise for
    258 # subclasses, e.g. NamedTuple objects that may override operators.
    259 if type(other) in _rejected_binop_types:
--> 260   raise TypeError(f"unsupported operand type(s) for {opchar}: "
    261                   f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
    262 return NotImplemented

TypeError: unsupported operand type(s) for ==: 'ArrayImpl' and 'tuple'

Note that I can load the model itself with model.init_from_file(f"{outputdir}/model.pkl.gz") just fine; it's only when the training loop is created that it fails.
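For what it's worth, the failure seems to come from tree_unflatten in trax/fastmath/numpy.py: load_checkpoint passes copy_from_tree=[None, ()], and the membership test tree in copy_from_tree ends up evaluating tree == (). When tree is a JAX array, jax 0.4.x raises TypeError for Array == tuple (presumably older jax returned NotImplemented here, which fell back to False). The "is not ()" SyntaxWarnings above come from a similar literal comparison in training.py. Below is a sketch of one possible local patch; it is untested, and the list/tuple/dict recursion just mirrors what's visible in the traceback, so treat it as an assumption about the 1.4.1 source:

# Sketch of a possible local patch to trax/fastmath/numpy.py (untested; the
# list/tuple/dict recursion mirrors the structure shown in the traceback).
def tree_unflatten(flat, tree, copy_from_tree=None):
  """Unflatten a list into a tree given the tree shape as second argument."""
  # Was: `tree in copy_from_tree`, which evaluates `tree == ()` and raises
  # under jax 0.4.x when `tree` is a jax.Array. Only apply `==` when `tree`
  # has the same Python type as the sentinel, so arrays never reach it.
  if copy_from_tree is not None and any(
      tree is c or (type(tree) is type(c) and tree == c)
      for c in copy_from_tree):
    return tree, flat
  if isinstance(tree, (list, tuple)):
    new_tree, rest = [], flat
    for t in tree:
      new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
      new_tree.append(new_t)
    new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree
    return new_tree, rest
  if isinstance(tree, dict):
    new_tree, rest = {}, flat
    for k in tree:
      new_t, rest = tree_unflatten(rest, tree[k], copy_from_tree=copy_from_tree)
      new_tree[k] = new_t
    return new_tree, rest
  return flat[0], flat[1:]

With that change the test short-circuits on identity for None and on a type-gated equality for (), so the optimizer slots should unflatten without ever comparing an array to a tuple. I haven't verified this beyond reading the traceback.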