[BUG] Tensors are not "modifiable lvalue" in extended lambdas.
sksg opened this issue · comments
Describe the bug
When using tensors inside extended lambdas, the tensor elements cannot be assigned to.
To Reproduce
This minimal example fails:
#include <matx.h>

template <typename FUNC>
__global__ void kernel_launch(FUNC kernel) {
    kernel();
}

int main() {
    auto a = matx::make_tensor<float>({10});
    kernel_launch<<<1, 10>>>([=] __device__() { a(threadIdx.x) = threadIdx.x; }); // <-- error line
}
The build fails with error: expression must be a modifiable lvalue
Expected behavior
The above expression should not fail, and a should be filled with the numbers 0..9.
System details (please complete the following information):
- OS: Ubuntu 22.04.1 LTS
- CUDA version: 11.8
- g++ version: 11.3.0
Hi @sksg, when you're passing a tensor to a device/global function, the tensor must first be converted to a tensor_impl_t. The reason is that tensor_t maintains some state, like a reference count, that is not usable or wanted in device code. tensor_impl_t strips a tensor down to the bare minimum needed in device code. To convert, you use the base_type trait:
See MatX/include/matx/operators/set.h, line 66 at d9a32d5.
So your code may be something like:

int main() {
    auto a = matx::make_tensor<float>({10});
    base_type<decltype(a)>::type a_base = a;  // strip to the device-usable type
    kernel_launch<<<1, 10>>>([=] __device__() { a_base(threadIdx.x) = threadIdx.x; });
}
I have not tried this yet to see if it works, but I will test and get back to you today.
Hi @sksg, the issue is that there are two definitions of operator(): a const and a non-const version. Since a non-mutable lambda's by-value captures are const, overload resolution picks the const version, whose result is not a modifiable lvalue. Instead you can pass the tensor as a function parameter:
using namespace matx;

template <typename T, typename FUNC>
__global__ void kernel_launch(T t, FUNC kernel) {
    kernel(t);
}

kernel_launch<<<1, 10>>>(a_base, [=] __device__(auto ab) {
    ab(threadIdx.x) = threadIdx.x;
});
And this works as well. Let me know if that suffices.
Thanks for the speedy feedback.
I did not know that lambda captures were const by default; it makes perfect sense now. I did try it with a mutable lambda instead:
#include <matx.h>

template <typename FUNC>
__global__ void kernel_launch(FUNC kernel) {
    kernel();
}

int main() {
    auto a = matx::make_tensor<float>({10});
    kernel_launch<<<1, 10>>>([=] __device__() mutable { a(threadIdx.x) = threadIdx.x; }); // <-- No errors!
}
Technically, I can verify that the device array a has the correct values. But I see your point w.r.t. matx::base_type: I do not know whether any of the host state will be corrupted by using a directly rather than a_base. The above syntax is very concise, though.
Thanks for helping out. This is perfectly usable for me.
Thanks for the report!