tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.

Home Page: https://burn.dev

[WGPU Backend] Weird burn-jit/fusion/kernel panic bug in latest commit

jclmnop opened this issue

Describe the bug

Today I updated from 0.13 to 0.14 (commit #98a58c8) because I needed the new logger functionality for the learner. I'm using the WGPU backend.

The following code, which worked fine in 0.13, started to panic:

```rust
impl<B: Backend, const D: usize> TensorExtFloat<B, D> for Tensor<B, D, Float> {
    // . . .

    fn nan_quantile(self, q: f64) -> Tensor<B, D, Float> {
        assert!(q >= 0.0 && q <= 1.0, "Quantile must be between 0 and 1");
        let input_shape = self.dims();
        let device = self.device();
        let tensor = self.sort(D - 1);

        let nans = tensor.clone().count_nans();
        let not_nans: Tensor<B, D, Int> =
            Tensor::full(nans.shape(), input_shape[D - 1] as i32, &device) - nans;
        let single_elems = not_nans.clone().equal_elem(1);

        let position = (not_nans.float() - 1.0) * q;
        let low = position.clone().int();
        let high = (low.clone() + 1).mask_fill(single_elems, 0);
        let weight = position.clone() - low.clone().float();

        let low_vals = tensor.clone().gather(D - 1, low);
        let high_vals = tensor.gather(D - 1, high);

        low_vals * (weight.ones_like() - weight.clone()) + high_vals * weight
    }

    // . . .
}
```

The panic:

```
---- utils::tensor_ext::tests::test_nan_quantile stdout ----
thread 'utils::tensor_ext::tests::test_nan_quantile' panicked at /Users/jclmnop/.cargo/git/checkouts/burn-178c6829f420dae1/98a58c8/crates/burn-jit/src/fusion/kernel.rs:215:47:
range end index 1 out of range for slice of length 0
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
```
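The message itself is Rust's standard out-of-bounds slice-range panic: the code at kernel.rs:215:47 took a range ending at index 1 (such as [..1] or [0..1]) of a slice that was empty. What that slice holds inside burn-jit's fusion kernel isn't visible from the report; the sketch below only reproduces the same message with a hypothetical empty list (first_elem_panic_message is an illustrative name, not Burn code).

```rust
use std::panic;

/// Slices the first element of `items` and returns the panic message, if any.
/// `items` is a hypothetical stand-in for whatever list the fused kernel indexed.
fn first_elem_panic_message(items: &[u32]) -> Option<String> {
    // Silence the default hook so the expected panic isn't printed to stderr.
    let prev_hook = panic::take_hook();
    panic::set_hook(Box::new(|_| {}));
    let result = panic::catch_unwind(|| items[..1].len());
    panic::set_hook(prev_hook);

    match result {
        Ok(_) => None,
        // Formatted panic messages arrive as String; literal ones as &str.
        Err(payload) => Some(
            payload
                .downcast_ref::<String>()
                .cloned()
                .or_else(|| payload.downcast_ref::<&str>().map(|s| s.to_string()))
                .unwrap_or_default(),
        ),
    }
}

fn main() {
    // Empty input: slicing [..1] panics with the same kind of message as the report.
    println!("{:?}", first_elem_panic_message(&[]));
    // Non-empty input: slicing succeeds, so there is no panic message.
    println!("{:?}", first_elem_panic_message(&[7]));
}
```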

I narrowed the panic down to the following two lines:

```rust
let high = (low.clone() + 1).mask_fill(single_elems, 0);
let weight = position.clone() - low.clone().float();
```

If I remove either of those lines, the panic goes away. I don't know whether it's related to calling mask_fill() on a tensor derived from low just before, or whether that's a coincidence and something else is causing it.
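For comparison, here is a minimal scalar sketch of what nan_quantile computes for a single row: drop the NaNs, sort, and linearly interpolate between the two neighbouring order statistics at position (n - 1) * q. It is not Burn code, and nan_quantile_ref is a hypothetical name. It filters NaNs out instead of relying on sort placing them last, and it clamps the upper index instead of using mask_fill. A plain reference like this can be handy for checking tensor results across backends.

```rust
/// Hypothetical scalar reference for the `nan_quantile` logic:
/// ignore NaNs, sort, then linearly interpolate between neighbouring values.
/// Returns NaN if every value is NaN.
fn nan_quantile_ref(values: &[f64], q: f64) -> f64 {
    assert!((0.0..=1.0).contains(&q), "Quantile must be between 0 and 1");

    // Keep only the non-NaN values (the `not_nans` count in the tensor version).
    let mut sorted: Vec<f64> = values.iter().copied().filter(|v| !v.is_nan()).collect();
    if sorted.is_empty() {
        return f64::NAN;
    }
    sorted.sort_by(|a, b| a.total_cmp(b));

    let n = sorted.len();
    let position = (n - 1) as f64 * q;
    let low = position.floor() as usize;
    // Clamp rather than mask: covers n == 1 and q == 1.0, where low + 1 == n.
    let high = (low + 1).min(n - 1);
    let weight = position - low as f64;

    sorted[low] * (1.0 - weight) + sorted[high] * weight
}

fn main() {
    let row = [3.0, f64::NAN, 1.0, 2.0, 4.0];
    // 4 non-NaN values -> position = 3 * 0.5 = 1.5 -> halfway between 2.0 and 3.0
    println!("{}", nan_quantile_ref(&row, 0.5)); // 2.5
}
```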

To Reproduce
Here's a minimal reproducible example:

```rust
#[cfg(test)]
mod tests {
    use burn::prelude::*;
    type TestBackend = burn::backend::Wgpu<f32, i32>;

    // Panics
    #[test]
    fn test_bug_1() {
        let device = Default::default();
        let data = TensorData::from([0.5]);

        let data_bool = TensorData::from([false]);
        let mask = Tensor::<TestBackend, 1, Bool>::from_data(data_bool, &device);

        let tensor = Tensor::<TestBackend, 1>::from_data(data, &device);
        let tensor_int = tensor.clone().int();
        let tensor_masked = (tensor_int.clone() + 1).mask_fill(mask, 0);
        let tensor_float = tensor_int.clone().float();
    }

    // Passes
    #[test]
    fn test_bug_2() {
        let device = Default::default();
        let data = TensorData::from([0.5]);

        let data_bool = TensorData::from([false]);
        // let mask = Tensor::<TestBackend, 1, Bool>::from_data(data_bool, &device);

        let tensor = Tensor::<TestBackend, 1>::from_data(data, &device);
        let tensor_int = tensor.clone().int();
        // let tensor_masked = (tensor_int.clone() + 1).mask_fill(mask, 0);
        let tensor_float = tensor_int.clone().float();
    }

    // Passes
    #[test]
    fn test_bug_3() {
        let device = Default::default();
        let data = TensorData::from([0.5]);

        let data_bool = TensorData::from([false]);
        let mask = Tensor::<TestBackend, 1, Bool>::from_data(data_bool, &device);

        let tensor = Tensor::<TestBackend, 1>::from_data(data, &device);
        let tensor_int = tensor.clone().int();
        let tensor_masked = (tensor_int.clone() + 1).mask_fill(mask, 0);
        // let tensor_float = tensor_int.clone().float();
    }

    // Passes
    #[test]
    fn test_bug_4() {
        let device = Default::default();
        let data = TensorData::from([0.5]);

        let data_bool = TensorData::from([false]);
        let mask = Tensor::<TestBackend, 1, Bool>::from_data(data_bool, &device);

        let tensor = Tensor::<TestBackend, 1>::from_data(data, &device);
        let tensor_int = tensor.clone().int();

        // Swapped the order of these two lines
        let tensor_float = tensor_int.clone().float();
        let tensor_masked = (tensor_int.clone() + 1).mask_fill(mask, 0);
    }
}
```

Expected behavior
All four tests should pass without panicking, but test_bug_1 panics while the other three pass.
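For reference, here is what test_bug_1 should compute, traced with plain Rust Vecs instead of Burn tensors. The mask_fill function below is a hypothetical scalar stand-in for the tensor method (replace values where the mask is true), not Burn's implementation, and the float-to-int cast is assumed to truncate toward zero like Rust's as. The masked result and the float cast are independent values, so the order of those two lines shouldn't matter.

```rust
/// Hypothetical scalar stand-in for `mask_fill`: wherever `mask` is true,
/// replace the value with `fill`; elsewhere keep the original value.
fn mask_fill(values: &[i32], mask: &[bool], fill: i32) -> Vec<i32> {
    values
        .iter()
        .zip(mask)
        .map(|(&v, &m)| if m { fill } else { v })
        .collect()
}

fn main() {
    let data = [0.5_f32];
    let mask = [false];

    // `.int()`: assumed truncation toward zero, so 0.5 -> 0.
    let tensor_int: Vec<i32> = data.iter().map(|&x| x as i32).collect();
    // `(tensor_int + 1).mask_fill(mask, 0)`: the mask is all false, so nothing is replaced.
    let plus_one: Vec<i32> = tensor_int.iter().map(|&x| x + 1).collect();
    let tensor_masked = mask_fill(&plus_one, &mask, 0);
    // `.float()` on the unmodified tensor_int.
    let tensor_float: Vec<f32> = tensor_int.iter().map(|&x| x as f32).collect();

    println!("{tensor_int:?} {tensor_masked:?} {tensor_float:?}"); // [0] [1] [0.0]
}
```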

Desktop:

  • OS: MacOS Sonoma 14.5
  • Chip: M1 Pro

Additional context
I've only tested this with the WGPU backend and am not sure whether other backends are affected.

If I learn anything more about this bug (if it is a bug and not just me missing something obvious) while I try to work around it for my own code, I'll update this issue.

Update: I just tested with the ndarray backend and it's fine (no panic).

Thanks for the MWE! This will help.

I'll investigate to find the cause.

I thought a recent PR might have been the root of the issue, but it works on ndarray (and I checked with a commit from just before that PR), so it's something else.

Edit: it seems to be the tensor + 1 element-wise op that fails on wgpu.

@jclmnop This should be fixed by the linked PR. If you want to give it a shot before it merges, you can check out the branch.

Nice one thanks, I'll give it a go either tonight or tomorrow.

@laggui Just tested it out with the linked PR branch and it works fine, thanks for that. I'll just keep burn patched with that branch until it's merged.

(will leave this issue open so it can get closed by the PR)