coreylowman / dfdx

Deep learning in Rust, with shape checked tensors and neural networks

ConcatAlong is very slow

opfromthestart opened this issue · comments

Using (a,b).concat_along(Axis::<1>), each inference takes about 300ms, whereas if I manually convert the tensors to vectors, concatenate them myself, and convert the result back into a tensor, it only takes about 6ms. I checked that both approaches produce the same output, so I'm not missing anything. I'm not sure what the problem is, but I do not think this is acceptable.
This is using the CUDA version.
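Roughly, the manual workaround I mean looks like this (a sketch with placeholder shapes, assuming the as_vec / tensor_from_vec APIs; my real network uses different dimensions):

use dfdx::prelude::*;

// Sketch of the manual workaround: pull both tensors back to host Vecs,
// concatenate on the CPU, and upload the result as a single tensor.
// For (1, N) row tensors, appending the element vectors is equivalent
// to concatenating along axis 1.
fn manual_concat(
    dev: &AutoDevice,
    a: Tensor<Rank2<1, 4>, f32, AutoDevice>,
    b: Tensor<Rank2<1, 2>, f32, AutoDevice>,
) -> Tensor<Rank2<1, 6>, f32, AutoDevice> {
    let mut data = a.as_vec(); // copy a's elements to the host
    data.extend(b.as_vec());   // append b's elements
    dev.tensor_from_vec(data, (Const::<1>, Const::<6>)) // re-upload as one tensor
}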

Can you give the shapes of the a & b tensors? I'll take a look

Also, is this after repeated calls? ConcatAlong JIT compiles the kernel on the first execution, which probably takes somewhere in that range

Yeah this is from JIT

I get these timings from this simple example:

use std::time::Instant;

use dfdx::prelude::*;

fn main() {
    let dev: AutoDevice = Default::default();
    // a is (2, 2) and b is (2, 4); concatenating along axis 1 yields (2, 6)
    let a: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 2));
    let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
    for _ in 0..10 {
        let start = Instant::now();
        let _: Tensor<Rank2<2, 6>, f32, _> =
            (a.clone(), b.clone()).concat_along(Axis::<1>).realize();
        println!("{:?}", start.elapsed());
    }
}
375.891925ms
20.9µs
11.16µs
11.64µs
9.771µs
9.75µs
9.27µs
9.43µs
9.73µs
9.68µs

Going to close this for now - If you want to investigate if there are ways to speed up JIT compilation times feel free to open a different issue!

The shapes I am concatenating are (Const::<1>, Const::<6240>) with (Const::<1>, Const::<6>). Even in a loop, it appears to always incur this cost, so I think it is either recompiling something every time or never caching the compiled kernel.
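Adapting your snippet to those shapes (following the same zeros_like / realize pattern; this is just a sketch, not my actual network code) would look roughly like:

use std::time::Instant;

use dfdx::prelude::*;

fn main() {
    let dev: AutoDevice = Default::default();
    // same pattern as the snippet above, but with the (1, 6240) and (1, 6) shapes
    let a: Tensor<(Const<1>, usize), f32, _> = dev.zeros_like(&(Const, 6240));
    let b: Tensor<(Const<1>, usize), f32, _> = dev.zeros_like(&(Const, 6));
    for _ in 0..10 {
        let start = Instant::now();
        let _: Tensor<Rank2<1, 6246>, f32, _> =
            (a.clone(), b.clone()).concat_along(Axis::<1>).realize();
        println!("{:?}", start.elapsed());
    }
}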

Running the same test you did, but with my network, it consistently takes about the same time regardless of the number of calls. Is it possible the function was cleared out of the GPU cache, and how would I prevent that?
Network times:

310.782126ms
295.890484ms
291.42786ms
292.147541ms
291.059603ms
303.778666ms
293.815021ms
310.522738ms
301.337522ms
308.999231ms
322.108602ms
321.610593ms
329.548947ms
312.1137ms
344.742673ms
296.267473ms
292.671337ms
291.348365ms
300.508978ms
328.355806ms

Times for CPU hover around 16ms.

I added some debug statements to the ConcatAlong kernel, and

        if !self.dev.has_func(&module_name, "fwd") {

always returns false, so it recompiles the kernel on every single iteration. Is there a way to prevent this?

I added my own caching using a once_cell::sync::Lazy and got an iteration down to 5ms. I think some transparency in how kernels are loaded and unloaded would help.

Ah hmm good info. Definitely should not be compiled every time. What dtype are you using? And can you send a simple snippet that reproduces the behavior?

Are you recreating the device inside the loop (or whatever calls concat_along)? Each device instantiation will need to recompile the kernels
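I.e. a pattern like this (hypothetical sketch, reusing the shapes from above) hits the JIT compiler on every iteration, since the kernel cache lives on the device:

use dfdx::prelude::*;

fn main() {
    for _ in 0..20 {
        // constructing a fresh device every iteration discards its kernel cache,
        // so ConcatAlong has to JIT compile its "fwd" function again each time
        let dev: AutoDevice = Default::default();
        let a: Tensor<(Const<1>, usize), f32, _> = dev.zeros_like(&(Const, 6240));
        let b: Tensor<(Const<1>, usize), f32, _> = dev.zeros_like(&(Const, 6));
        let _: Tensor<Rank2<1, 6246>, f32, _> =
            (a, b).concat_along(Axis::<1>).realize();
    }
    // hoisting `let dev: AutoDevice = Default::default();` above the loop lets
    // every iteration after the first reuse the already-compiled kernel
}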

I was unaware of that. I was creating a new Cuda object every iteration of the loop.
I had thought it was a zero-sized/marker type; I didn't know it mattered.
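That explains it. Combined with the once_cell approach above, the fix is just to construct the device once and reuse it; something like this sketch (assuming AutoDevice can live in a static, i.e. is Send + Sync):

use dfdx::prelude::*;
use once_cell::sync::Lazy;

// Sketch: build the device once and reuse it for every call, so the
// JIT-compiled ConcatAlong kernel stays in the device's cache.
static DEVICE: Lazy<AutoDevice> = Lazy::new(|| AutoDevice::default());

fn inference_step() {
    // reuse the shared device instead of constructing a new one per call
    let a: Tensor<(Const<1>, usize), f32, _> = DEVICE.zeros_like(&(Const, 6240));
    let b: Tensor<(Const<1>, usize), f32, _> = DEVICE.zeros_like(&(Const, 6));
    let _: Tensor<Rank2<1, 6246>, f32, _> =
        (a, b).concat_along(Axis::<1>).realize();
}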

Should probably move to that TBH; this is a fairly common misconception.