Associative scan with non-scalar inputs
yolky opened this issue
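Is there any plan to extend the associative scan to support non-scalar inputs? At the moment, an associative scan over an input of shape [L, D] performs D independent associative scans of length L, each acting on scalars. It would be useful if combine_fn could accept tensor inputs, so that a single combine step could mix values across the D dimension. This would enable, for example, prefix scans whose combine step involves matrix-vector products, such as the linear recurrence h_t = A_t h_{t-1} + b_t, where each scanned element is a (matrix, vector) pair rather than a scalar.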
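To make the request concrete, here is a minimal plain-Python sketch (not any library's API) of a scan whose elements are (A, b) pairs representing the affine map x -> A x + b. The `combine_fn` name mirrors the issue, but its signature, the helpers `matmul`/`matvec`, and both scan functions are hypothetical illustrations. Composing two affine maps gives another affine map, and composition is associative, so the same prefix results come out of a sequential loop and a Hillis-Steele style log-depth scan. A purely elementwise combine over [L, D] cannot express this, because the matrix product mixes entries across D.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]
Elem = Tuple[Matrix, Vector]  # hypothetical element: (A, b) encodes x -> A @ x + b


def matmul(A: Matrix, B: Matrix) -> Matrix:
    """Dense matrix product A @ B on nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def matvec(A: Matrix, x: Vector) -> Vector:
    """Dense matrix-vector product A @ x on nested lists."""
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


def combine_fn(earlier: Elem, later: Elem) -> Elem:
    """Compose two affine maps: apply `earlier` first, then `later`.

    later(earlier(x)) = A2 (A1 x + b1) + b2 = (A2 A1) x + (A2 b1 + b2)
    """
    A1, b1 = earlier
    A2, b2 = later
    return matmul(A2, A1), [u + v for u, v in zip(matvec(A2, b1), b2)]


def sequential_scan(elems: List[Elem]) -> List[Elem]:
    """Reference inclusive scan: one combine per step, left to right."""
    if not elems:
        return []
    out = [elems[0]]
    for e in elems[1:]:
        out.append(combine_fn(out[-1], e))
    return out


def parallel_style_scan(elems: List[Elem]) -> List[Elem]:
    """Hillis-Steele inclusive scan: ceil(log2 L) rounds.

    Within a round every combine reads only the previous round's values,
    so the combines are independent and could run in parallel. The
    operand order (earlier, later) is preserved, which matters because
    matrix products are not commutative.
    """
    out = list(elems)
    step = 1
    while step < len(out):
        out = [combine_fn(out[i - step], out[i]) if i >= step else out[i]
               for i in range(len(out))]
        step *= 2
    return out


if __name__ == "__main__":
    # Hypothetical sequence of L = 4 steps with 2x2 matrices (D = 2).
    elems = [
        ([[1, 2], [0, 1]], [1, 0]),
        ([[0, 1], [1, 0]], [0, 2]),
        ([[2, 0], [0, 3]], [1, 1]),
        ([[1, 1], [1, 1]], [0, -1]),
    ]
    h0 = [1, 1]
    prefixes = parallel_style_scan(elems)
    # Each prefix maps h0 directly to h_t, so all states come out of one scan.
    states = [[u + v for u, v in zip(matvec(A, h0), b)] for A, b in prefixes]
    print(states)
```

The Hillis-Steele form is used here only because it is short; it does O(L log L) combines, whereas a work-efficient (Blelloch-style) scan needs O(L). Either way, each combine is a D x D matrix product, which is exactly the kind of tensor-valued combine_fn the issue asks for.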