Implement broadcast_all() in C/C++
fritzo opened this issue
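The `distributions.utils.broadcast_all()` helper is used in all distribution constructors to clean up parameters (upcasting Python numbers to tensors and broadcasting all parameters to a common shape). This is expensive, and it would be nice to make it cheaper by reimplementing it in C/C++.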
Adam suggested using `torch._C._infer_size()` under the hood, but since it operates on tensor sizes, we would need to add more support for scalar (Python number) parameters.
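To make the scalar gap concrete, here is a minimal pure-Python sketch of the shape logic involved. `infer_size` mirrors the right-aligned broadcasting rule for two shapes, and `broadcast_shape_all` folds it over every parameter, treating a Python number as the 0-dim shape `()`. Both names are hypothetical: this is not PyTorch's implementation, and it only computes the result shape without allocating or expanding any tensors.

```python
from functools import reduce
from numbers import Number
from typing import Any, Sequence, Tuple

Shape = Tuple[int, ...]


def infer_size(a: Sequence[int], b: Sequence[int]) -> Shape:
    """Broadcast two shapes, aligning them from the right (NumPy/PyTorch rule)."""
    ndim = max(len(a), len(b))
    result = []
    for i in range(1, ndim + 1):
        # Missing leading dimensions behave like size 1.
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcastable")
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


def param_shape(p: Any) -> Shape:
    """Hypothetical helper: a Python number is a 0-dim shape; anything else must expose .shape."""
    if isinstance(p, Number):
        return ()
    return tuple(p.shape)


def broadcast_shape_all(*params: Any) -> Shape:
    """Fold infer_size over every parameter's shape, starting from the scalar shape ()."""
    return reduce(infer_size, (param_shape(p) for p in params), ())
```

For example, `broadcast_shape_all(0.5, x)` with `x.shape == (2, 1)` and another tensor of shape `(3,)` would give `(2, 3)`. A native version could follow the same fold over size vectors, which is one possible way to handle scalars without wrapping them in tensors first.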
It would also be good to implement `Distribution._validate_log_prob_arg()` in C/C++.
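As a rough illustration of the shape checks a native validator would need, the sketch below works purely on shapes. It assumes the rule is that the rightmost dimensions of the value must equal `event_shape` and that the whole value shape must broadcast against `batch_shape + event_shape`. `validate_value_shape` is a hypothetical name, not the actual PyTorch method, which may also perform other checks (for example on the value's type).

```python
from typing import Sequence


def validate_value_shape(
    value_shape: Sequence[int],
    batch_shape: Sequence[int],
    event_shape: Sequence[int],
) -> None:
    """Raise ValueError if value_shape is incompatible with the batch/event shapes."""
    value_shape = tuple(value_shape)
    event_shape = tuple(event_shape)
    expected = tuple(batch_shape) + event_shape
    event_dim = len(event_shape)

    # The rightmost event_dim dims must equal event_shape exactly.
    # Slicing from len - event_dim avoids the [-0:] pitfall when event_dim == 0.
    if len(value_shape) < event_dim or value_shape[len(value_shape) - event_dim:] != event_shape:
        raise ValueError(
            f"value shape {value_shape} does not end with event shape {event_shape}"
        )

    # Compare right-aligned dims: each pair must be equal or contain a 1.
    for v, e in zip(reversed(value_shape), reversed(expected)):
        if v != 1 and e != 1 and v != e:
            raise ValueError(
                f"value shape {value_shape} is not broadcastable with {expected}"
            )
```

Because the check touches only integer shapes, it is the kind of small, hot-path logic that should be cheap to move into C/C++.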