Avoid casting with `numpy()` in `multiprocessing.py`
Peiffap opened this issue
Outline & Motivation
Currently, `get_extra_results()` casts callback metrics to numpy to avoid problems with memory sharing. Then `update_main_process_results()` casts them back to `Tensor`. It would be neater (and part of a greater goal of not depending on the numpy package, see #16649) to avoid this round trip.
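For reference, the round trip described above boils down to something like the following minimal sketch. The helper names `to_numpy` and `back_to_tensor` are hypothetical simplifications, not the actual functions in `multiprocessing.py`:

```python
import torch

def to_numpy(metrics):
    # Worker-side cast (simplified): detach from the graph, move to CPU,
    # and convert each tensor to a numpy array before sending it back.
    return {k: v.detach().cpu().numpy() for k, v in metrics.items()}

def back_to_tensor(metrics):
    # Main-process side (simplified): wrap each array back into a tensor.
    return {k: torch.from_numpy(v) for k, v in metrics.items()}

metrics = {"val_loss": torch.tensor(0.25)}
restored = back_to_tensor(to_numpy(metrics))
```

The proposal is to eliminate both conversion steps while still avoiding shared-memory issues.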
Pitch
Remove the cast to numpy without introducing errors, and remove the numpy dependency in `multiprocessing.py`.
Additional context
I looked at the code of `get_extra_results()`. It solely converts the torch tensors to detached numpy arrays. If we remove the use of numpy, the function reduces to the following body:
```python
def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
    """Gather extra state from the Trainer and return it as a dictionary for sending back to the main process.

    Args:
        trainer: reference to the Trainer.

    Returns:
        A dictionary with items to send back to the main process where :meth:`update_main_process_results` will
        process this output.
    """
    return {"callback_metrics": trainer.callback_metrics}
```
I think we can use simple Python containers like `list` and `dict` to avoid using numpy. In that case, the better function to work with would be `apply_to_collection`.
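To illustrate the idea, here is a hedged, self-contained stand-in for `apply_to_collection` (the real implementation lives in `lightning_utilities` and handles more container types); it recurses over nested collections and applies a function to every element of the given type, so tensors could be detached and moved to CPU instead of being cast to numpy:

```python
from typing import Any, Callable

import torch

def apply_to_collection(data: Any, dtype: type, fn: Callable) -> Any:
    # Simplified sketch: apply `fn` to every element of type `dtype`,
    # recursing through dicts, lists, and tuples.
    if isinstance(data, dtype):
        return fn(data)
    if isinstance(data, dict):
        return {k: apply_to_collection(v, dtype, fn) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_collection(v, dtype, fn) for v in data)
    return data

# Detach tensors and move them to CPU without ever touching numpy:
metrics = {"loss": torch.tensor(1.5), "nested": {"acc": torch.tensor(0.9)}}
shared_safe = apply_to_collection(metrics, torch.Tensor, lambda t: t.detach().cpu())
```

Whether detached CPU tensors are enough to avoid the memory-sharing problems that motivated the numpy cast is exactly the question to settle here.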
Please share your thoughts.
Thanks