Error loading loop from a checkpoint
nwbvt opened this issue · comments
Nicholas Brown commented
Description
I'm getting an error while trying to load a training loop from a checkpoint.
File ~/srrf/srrf.py:104, in make_trainer(model, train_data, test_data, loss, metrics, name, optimizer, schedule, steps_per_cp, eval_batch_size)
94 train_task = trax.supervised.training.TrainTask(labeled_data=train_data,
95 loss_layer=loss,
96 optimizer=optimizer,
97 lr_schedule=schedule,
98 n_steps_per_checkpoint=500)
100 eval_task = trax.supervised.training.EvalTask(labeled_data=test_data,
101 metrics=metrics,
102 n_eval_batches=eval_batch_size)
--> 104 return trax.supervised.training.Loop(model, train_task, eval_tasks=[eval_task], output_dir=name)
File ~/.local/lib/python3.10/site-packages/trax/supervised/training.py:294, in Loop.__init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
289 layer.weights, layer.state = tl.on_cpu(self._unreplicate(
290 _make_weights_and_state_same_across_hosts(
291 self._for_n_devices(weights_and_state))))
293 # Load checkpoint if it exists.
--> 294 self.load_checkpoint()
296 # Prepare eval components.
297 self._eval_at = eval_at or default_at
File ~/.local/lib/python3.10/site-packages/trax/supervised/training.py:944, in Loop.load_checkpoint(self, directory, filename)
940 for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']):
941 matched_flat_slots = _match_by_shape(
942 self._to_bits(_flatten_and_remove_empty(trainer.slots)),
943 _flatten_and_remove_empty(slots))
--> 944 matched_slots, _ = fastmath.tree_unflatten(
945 self._from_bits(matched_flat_slots),
946 trainer.slots, copy_from_tree=[None, ()])
947 trainer.slots = matched_slots
948 self._step = d['step']
File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
242 new_tree, rest = [], flat
243 for t in tree:
--> 244 new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
245 new_tree.append(new_t)
246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree
File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
242 new_tree, rest = [], flat
243 for t in tree:
--> 244 new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
245 new_tree.append(new_t)
246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree
File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:239, in tree_unflatten(flat, tree, copy_from_tree)
216 def tree_unflatten(flat, tree, copy_from_tree=None):
217 """Unflatten a list into a tree given the tree shape as second argument.
218
219 Args:
(...)
237 more were provided than the number of leaves of tree (useful for recursion).
238 """
--> 239 if copy_from_tree is not None and tree in copy_from_tree:
240 return tree, flat
241 if isinstance(tree, (list, tuple)):
File ~/.local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:260, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
257 # Note: don't use isinstance here, because we don't want to raise for
258 # subclasses, e.g. NamedTuple objects that may override operators.
259 if type(other) in _rejected_binop_types:
--> 260 raise TypeError(f"unsupported operand type(s) for {opchar}: "
261 f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
262 return NotImplemented
TypeError: unsupported operand type(s) for ==: 'ArrayImpl' and 'tuple'
Environment information
OS: Ubuntu
$ pip freeze | grep trax
trax==1.4.1
$ pip freeze | grep tensor
tensorboard==2.14.1
tensorboard-data-server==0.7.1
tensorflow==2.14.0
tensorflow-datasets==4.9.3
tensorflow-estimator==2.14.0
tensorflow-hub==0.15.0
tensorflow-io-gcs-filesystem==0.34.0
tensorflow-metadata==1.14.0
tensorflow-text==2.14.0
$ pip freeze | grep jax
jax==0.4.19
jaxlib==0.4.19+cuda12.cudnn89
$ python3 -V
Python 3.10.12
For bugs: reproduction and error logs
# Steps to reproduce:
Create a training loop, give it an output dir, and then train the model.
Then create a new training loop and give it that same output dir. When the new loop tries to load from the existing checkpoint, it fails with the error below.
# Error logs:
/home/nick/.local/lib/python3.10/site-packages/trax/supervised/training.py:1388: SyntaxWarning: "is not" with a literal. Did you mean "!="?
return [f for f in flat if f is not None and f is not ()] # pylint: disable=literal-comparison
/home/nick/.local/lib/python3.10/site-packages/trax/supervised/training.py:1388: SyntaxWarning: "is not" with a literal. Did you mean "!="?
return [f for f in flat if f is not None and f is not ()] # pylint: disable=literal-comparison
/home/nick/.local/lib/python3.10/site-packages/trax/supervised/training.py:1388: SyntaxWarning: "is not" with a literal. Did you mean "!="?
return [f for f in flat if f is not None and f is not ()] # pylint: disable=literal-comparison
/home/nick/.local/lib/python3.10/site-packages/trax/supervised/training.py:1388: SyntaxWarning: "is not" with a literal. Did you mean "!="?
return [f for f in flat if f is not None and f is not ()] # pylint: disable=literal-comparison
/home/nick/.local/lib/python3.10/site-packages/trax/supervised/training.py:1388: SyntaxWarning: "is not" with a literal. Did you mean "!="?
return [f for f in flat if f is not None and f is not ()] # pylint: disable=literal-comparison
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[57], line 1
----> 1 training_loop = make_trainer(model, train_data, test_data,
2 trax.layers.Fn("ScaledLoss", ScaledLoss),
3 [trax.layers.Fn("ScaledLoss", ScaledLoss),
4 trax.layers.Fn("InRange", InRange),
5 trax.layers.Fn("ClippedMAE", ClippedMAE)],
6 "model_dff=4096")
File ~/srrf/srrf.py:104, in make_trainer(model, train_data, test_data, loss, metrics, name, optimizer, schedule, steps_per_cp, eval_batch_size)
94 train_task = trax.supervised.training.TrainTask(labeled_data=train_data,
95 loss_layer=loss,
96 optimizer=optimizer,
97 lr_schedule=schedule,
98 n_steps_per_checkpoint=500)
100 eval_task = trax.supervised.training.EvalTask(labeled_data=test_data,
101 metrics=metrics,
102 n_eval_batches=eval_batch_size)
--> 104 return trax.supervised.training.Loop(model, train_task, eval_tasks=[eval_task], output_dir=name)
File ~/.local/lib/python3.10/site-packages/trax/supervised/training.py:294, in Loop.__init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
289 layer.weights, layer.state = tl.on_cpu(self._unreplicate(
290 _make_weights_and_state_same_across_hosts(
291 self._for_n_devices(weights_and_state))))
293 # Load checkpoint if it exists.
--> 294 self.load_checkpoint()
296 # Prepare eval components.
297 self._eval_at = eval_at or default_at
File ~/.local/lib/python3.10/site-packages/trax/supervised/training.py:944, in Loop.load_checkpoint(self, directory, filename)
940 for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']):
941 matched_flat_slots = _match_by_shape(
942 self._to_bits(_flatten_and_remove_empty(trainer.slots)),
943 _flatten_and_remove_empty(slots))
--> 944 matched_slots, _ = fastmath.tree_unflatten(
945 self._from_bits(matched_flat_slots),
946 trainer.slots, copy_from_tree=[None, ()])
947 trainer.slots = matched_slots
948 self._step = d['step']
File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
242 new_tree, rest = [], flat
243 for t in tree:
--> 244 new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
245 new_tree.append(new_t)
246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree
File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
242 new_tree, rest = [], flat
243 for t in tree:
--> 244 new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
245 new_tree.append(new_t)
246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree
File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:239, in tree_unflatten(flat, tree, copy_from_tree)
216 def tree_unflatten(flat, tree, copy_from_tree=None):
217 """Unflatten a list into a tree given the tree shape as second argument.
218
219 Args:
(...)
237 more were provided than the number of leaves of tree (useful for recursion).
238 """
--> 239 if copy_from_tree is not None and tree in copy_from_tree:
240 return tree, flat
241 if isinstance(tree, (list, tuple)):
File ~/.local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:260, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
257 # Note: don't use isinstance here, because we don't want to raise for
258 # subclasses, e.g. NamedTuple objects that may override operators.
259 if type(other) in _rejected_binop_types:
--> 260 raise TypeError(f"unsupported operand type(s) for {opchar}: "
261 f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
262 return NotImplemented
TypeError: unsupported operand type(s) for ==: 'ArrayImpl' and 'tuple'
Nicholas Brown commented
Note that I can load the model itself with model.init_from_file(f"{outputdir}/model.pkl.gz") just fine. It only fails when the training loop is created.