flashlight / flashlight

A C++ standalone library for machine learning

Home Page: https://fl.readthedocs.io/en/latest/

keep_dims not respected when a reduction is over the entire input shape

bwasti opened this issue

Bug Description

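As stated in the title, `keep_dims` is ignored when a reduction covers every axis of the input. In the ArrayFire backend's `mean` reduction, the all-axis branch always builds its result with `numDims = 0`. It therefore returns a 0-dimensional (scalar) tensor whatever `keep_dims` was requested: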

```cpp
if (isAllAxisReduction(input, axes)) {
  // Reduce along all axes returning a singleton tensor
  // TODO: modify this to af::mean<af::array> to take advantage of the
  // ArrayFire reduce_all kernels once available
  return toTensor<ArrayFireTensor>(
      detail::condenseIndices(
          af::mean(af::mean(af::mean(af::mean(toArray(input)))))),
      /* numDims = */ 0);
} else {
  // ... (per-axis reduction, not shown in the original excerpt)
}
```

(Current workaround in shumai: https://github.com/facebookresearch/shumai/blob/main/shumai/cpp/binding_gen.inl#L454-L458)
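For reference, the expected behaviour can be written as a pure shape computation. This is a minimal sketch and not Flashlight's API. `coversAllAxes` and `reducedShape` are hypothetical helpers, and treating an empty `axes` list as "reduce over everything" is an assumption. When `keep_dims` is requested, an all-axis reduction of a rank-3 input should give shape (1, 1, 1) rather than a rank-0 scalar.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of Flashlight): true if `axes` names every
// dimension of a rank-`ndim` tensor. An empty list is assumed to mean
// "reduce over all axes".
bool coversAllAxes(std::size_t ndim, std::vector<int> axes) {
  if (axes.empty()) {
    return true;
  }
  // Sort and deduplicate so {2, 0, 1} and {0, 1, 1, 2} are handled uniformly.
  std::sort(axes.begin(), axes.end());
  axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
  if (axes.size() != ndim) {
    return false;
  }
  for (std::size_t i = 0; i < axes.size(); ++i) {
    if (axes[i] != static_cast<int>(i)) {
      return false;
    }
  }
  return true;
}

// Hypothetical helper: output shape of a reduction over `axes`.
// Reduced dimensions are kept as size 1 when keepDims is true and dropped
// otherwise. An all-axis reduction therefore yields an all-ones shape of
// the input's rank (keepDims) or an empty, rank-0 shape (no keepDims).
std::vector<long long> reducedShape(const std::vector<long long>& inShape,
                                    const std::vector<int>& axes,
                                    bool keepDims) {
  const bool all = coversAllAxes(inShape.size(), axes);
  std::vector<long long> out;
  for (std::size_t i = 0; i < inShape.size(); ++i) {
    const bool reduced =
        all ||
        std::find(axes.begin(), axes.end(), static_cast<int>(i)) != axes.end();
    if (!reduced) {
      out.push_back(inShape[i]);  // untouched dimension keeps its size
    } else if (keepDims) {
      out.push_back(1);  // reduced dimension kept as a singleton
    }
  }
  return out;
}
```

In terms of the excerpt above, one plausible fix is to stop hard-coding `/* numDims = */ 0`. The branch would pass the input's rank when `keep_dims` is set and 0 otherwise, so the condensed 1x1x1x1 ArrayFire result is interpreted with the expected number of dimensions.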