probtorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration

Home Page: http://pytorch.org

Implement broadcast_all() in C/C++

fritzo opened this issue

The `distributions.utils.broadcast_all()` helper is called in every distribution constructor to clean up parameters. It is currently written in Python and is expensive. It would be nice to make it cheaper by reimplementing it in C/C++.
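For reference, here is roughly the per-call work such a helper does, as a simplified Python sketch. `broadcast_all_sketch` is a hypothetical name, and the real function may differ in details (error messages, dtype rules). Each call runs type checks, may allocate new tensors for Python scalars, and then broadcasts. That per-call Python overhead is the kind of cost a C/C++ version could avoid.

```python
import numbers

import torch


def broadcast_all_sketch(*values):
    """Hypothetical simplified version: tensors and Python numbers in,
    mutually broadcast tensors out."""
    if not all(isinstance(v, (torch.Tensor, numbers.Number)) for v in values):
        raise ValueError("inputs must be torch.Tensor or numbers.Number")
    # Take dtype/device from the first tensor argument, else use the default dtype.
    options = {"dtype": torch.get_default_dtype()}
    for v in values:
        if isinstance(v, torch.Tensor):
            options = {"dtype": v.dtype, "device": v.device}
            break
    # Wrap Python scalars as 0-dim tensors, then broadcast everything together.
    tensors = [v if isinstance(v, torch.Tensor) else torch.tensor(v, **options)
               for v in values]
    return torch.broadcast_tensors(*tensors)


loc, scale = broadcast_all_sketch(torch.zeros(3), 1.0)
print(loc.shape, scale.shape)  # torch.Size([3]) torch.Size([3])
```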

Adam suggested using `torch._C._infer_size()` under the hood, but we would need to add better support for scalar (Python number) arguments.
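Pairwise shape inference follows the standard broadcasting rule: align shapes from the right, and require each pair of dims to be equal or 1. Assuming "scalar support" means treating a Python number as a 0-dim shape `()`, the whole computation can be sketched in pure Python. `infer_size_sketch`, `shape_of` and `broadcast_shape_sketch` are hypothetical names, not PyTorch APIs.

```python
import functools
import numbers
from typing import Tuple, Union

Shape = Tuple[int, ...]


def infer_size_sketch(a: Shape, b: Shape) -> Shape:
    """Broadcast two shapes: align from the right, dims must match or be 1."""
    ndim = max(len(a), len(b))
    result = []
    for i in range(1, ndim + 1):
        # Missing leading dims behave like size 1.
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        result.append(da if db == 1 else db)
    return tuple(reversed(result))


def shape_of(value: Union[Shape, numbers.Number]) -> Shape:
    # Assumed scalar support: a Python number behaves like a 0-dim tensor.
    if isinstance(value, numbers.Number):
        return ()
    return tuple(value)


def broadcast_shape_sketch(*values: Union[Shape, numbers.Number]) -> Shape:
    # Fold the pairwise rule over all arguments, starting from the scalar shape ().
    return functools.reduce(infer_size_sketch, (shape_of(v) for v in values), ())


print(broadcast_shape_sketch((3, 1), 2.0, (4,)))  # (3, 4)
```

A C/C++ port could keep the same structure. Compared with plain pairwise shape inference, the only new piece is the scalar case.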

It would also be good to implement `Distribution._validate_log_prob_arg()` in C/C++.
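Much of that argument check is shape arithmetic. Assume it verifies two things: that the trailing dimensions of `value` equal `event_shape`, and that the full shape broadcasts against `batch_shape + event_shape`. Under that assumption, the shape part could look like the pure-Python sketch below. `validate_log_prob_shape` is a hypothetical name, and the check that `value` lies in the distribution's support is omitted.

```python
from typing import Sequence


def validate_log_prob_shape(value_shape: Sequence[int],
                            batch_shape: Sequence[int],
                            event_shape: Sequence[int]) -> None:
    """Hypothetical shape check for a log_prob argument; raises ValueError."""
    value_shape = tuple(value_shape)
    batch_shape = tuple(batch_shape)
    event_shape = tuple(event_shape)
    n_event = len(event_shape)
    # 1. The right-most dims of the value must equal event_shape exactly.
    if n_event > len(value_shape) or value_shape[len(value_shape) - n_event:] != event_shape:
        raise ValueError(f"right-most dims of {value_shape} must match event_shape {event_shape}")
    # 2. Aligned from the right, each dim must equal the expected dim or be 1.
    expected = batch_shape + event_shape
    for actual_dim, expected_dim in zip(reversed(value_shape), reversed(expected)):
        if actual_dim != 1 and expected_dim != 1 and actual_dim != expected_dim:
            raise ValueError(f"value shape {value_shape} is not broadcastable with {expected}")


# Passes: a leading sample dim of 5 in front of batch_shape (3,) and event_shape (2,).
validate_log_prob_shape((5, 3, 2), batch_shape=(3,), event_shape=(2,))
```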