tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.

Home Page: https://burn.dev

Backward pass has mismatched dimensions

galenoshea opened this issue · comments

Describe the bug
I am trying to load an ONNX model and get the gradient with respect to the input tensor. The forward pass works fine; it breaks at the backward pass.

  • I have tried with the LibTorch and NdArray backends.
  • I have tried 3 different ONNX models, which all work for inference.

    type B = Autodiff<LibTorch>;
    let device = <B as Backend>::Device::default();

    let model = GlizzyGaze::<B>::new();

    let x: Tensor<Autodiff<LibTorch>, 4> = preprocess::<B>("src/images/cat.jpg", &device);

    let x: Tensor<Autodiff<LibTorch>, 4> = x.require_grad();
    let y: Tensor<Autodiff<LibTorch>, 2> = model.predict(x.clone());
    println!("y: {:?}", y.shape());

    let grads = y.backward();

    let x_grad: Tensor<LibTorch, 4> = x.grad(&grads).unwrap();

NdArray Output:

    y: Shape { dims: [1, 1] }
    thread 'main' panicked at /home/goshea/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/lib.rs:1529:13:
    ndarray: could not broadcast array from shape: [1, 1, 4, 4] to: [1, 1, 3, 3]
    note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

LibTorch Output:

        Finished `dev` profile [unoptimized + debuginfo] target(s) in 9.08s
         Running `target/debug/burn_test`
    y: Shape { dims: [1, 1] }
    thread 'main' panicked at /home/goshea/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tch-0.15.0/src/wrappers/tensor.rs:535:27:
    called `Result::unwrap()` on an `Err` value: Torch("The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 3
    Exception raised from infer_size_impl at ../aten/src/ATen/ExpandUtils.cpp:31 (most recent call first):
    frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6c (0x7f45930e0a0c in /home/goshea/libtorch/lib/libc10.so)
    frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xfa (0x7f459308a8bc in /home/goshea/libtorch/lib/libc10.so)
    frame #2: at::infer_size_dimvector(c10::ArrayRef<long>, c10::ArrayRef<long>) + 0x3d4 (0x7f45949acfa4 in /home/goshea/libtorch/lib/libtorch_cpu.so)
    frame #3: at::TensorIteratorBase::compute_shape(at::TensorIteratorConfig const&) + 0xb8 (0x7f4594a5f3a8 in /home/goshea/libtorch/lib/libtorch_cpu.so)
    frame #4: at::TensorIteratorBase::build(at::TensorIteratorConfig&) + 0x6d (0x7f4594a602ad in /home/goshea/libtorch/lib/libtorch_cpu.so)
    frame #5: <unknown function> + 0x1cf43ba (0x7f4594e353ba in /home/goshea/libtorch/lib/libtorch_cpu.so)
    frame #6: at::native::copy_(at::Tensor&, at::Tensor const&, bool) + 0x7a (0x7f4594e36daa in /home/goshea/libtorch/lib/libtorch_cpu.so)
    frame #7: at::_ops::copy_::redispatch(c10::DispatchKeySet, at::Tensor&, at::Tensor const&, bool) + 0x8f (0x7f4595bf600f in /home/goshea/libtorch/lib/libtorch_cpu.so)
    frame #8: <unknown function> + 0x5d82a25 (0x7f4598ec3a25 in /home/goshea/libtorch/lib/libtorch_cpu.so)
    frame #9: at::_ops::copy_::redispatch(c10::DispatchKeySet, at::Tensor&, at::Tensor const&, bool) + 0x8f (0x7f4595bf600f in /home/goshea/libtorch/lib/libtorch_cpu.so)
    frame #10: <unknown function> + 0x5d84b14 (0x7f4598ec5b14 in /home/goshea/libtorch/lib/libtorch_cpu.so)
    frame #11: at::_ops::copy_::call(at::Tensor&, at::Tensor const&, bool) + 0x16f (0x7f4595c6668f in /home/goshea/libtorch/lib/libtorch_cpu.so)
    [frames #12 to #46 omitted: unresolved addresses in target/debug/burn_test and libc.so.6]")
    note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

Are you using the version of burn from crates.io or the version on main?

Is the model you used hosted somewhere? If not, do you mind if I take a look at it?

@skewballfox I'm using version 0.13.2 from crates.io, and I've been using some torchvision models such as resnet18, efficientnet_v2_s, and mobilenet_v2, each modified for binary classification (see below).

  # Imports needed by the snippet below
  import torch.nn as nn
  from torchvision import models

  class Net(nn.Module):
      def __init__(self, model='mn'):
          super(Net, self).__init__()
          if model == 'resnet':
              self.backbone = models.resnet18(weights='IMAGENET1K_V1')
              num_features = self.backbone.fc.in_features
              self.backbone.fc = nn.Linear(num_features, 1) 
          elif model == "efficientnet":
              self.backbone = models.efficientnet_v2_s(weights='IMAGENET1K_V1', )
              num_features = self.backbone.classifier[1].in_features
              self.backbone.classifier[1] = nn.Linear(num_features, 1)
          elif model == "mn":
              self.backbone = models.mobilenet_v2(weights="DEFAULT")
              idx = 1
              num_features = self.backbone.classifier[idx].in_features
              self.backbone.classifier[idx] = nn.Linear(num_features, 1)
  
          self.sigmoid = nn.Sigmoid()
  
      def forward(self, x):
          x = self.backbone(x)
          x = self.sigmoid(x)
          return x
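
For reference, a model like this would typically be exported to ONNX with torch.onnx.export before being loaded through burn-import. A minimal sketch of such an export, assuming a 1x3x224x224 dummy input; the output file name and opset are picked here for illustration, not taken from the reporter's setup:

    import torch

    # Hypothetical export script (illustrative): the dummy input size, opset,
    # and file name are assumptions, not read from the original project.
    model = Net(model="mn")
    model.eval()

    dummy = torch.randn(1, 3, 224, 224)
    torch.onnx.export(
        model,
        dummy,
        "mobilenet_v2.onnx",
        input_names=["input"],
        output_names=["output"],
        opset_version=16,
    )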

@galenoshea This is a runtime error, right? Does cargo build run successfully for the project generating the error?

If it's not too much trouble, and you would rather not send us the ONNX file, could you try recreating it with just one of the modified models? Trying to narrow down the search a bit.

Cargo successfully builds and it is a runtime error.

I just tried reproducing and found that Resnet18 works while Mobilenet and efficientnet have issues.

I'm using images of size 224, but I've seen similar issues before when using awkward input sizes. Specifically, this happens when, at a given layer, an odd number of channels is halved (note the error above: [1, 1, 4, 4] to [1, 1, 3, 3]), which could be related to the size of a previous layer (1, 1, 7, 7). In this case, the forward pass works fine for all models, but the backward pass might be encountering a similar issue.
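
A quick way to see why awkward or odd sizes are suspicious: the Conv2d output size uses floor division, so two different input sizes can map to the same output size, and anything in the backward pass that reconstructs a shape by inverting the formula can come up one element short. A small illustration with the standard conv arithmetic (kernel 3, stride 2, padding 1 are chosen for illustration, not read from the model):

    def conv2d_out(size, kernel=3, stride=2, padding=1):
        # Standard Conv2d output-size formula (floor division).
        return (size + 2 * padding - kernel) // stride + 1

    # Two different spatial sizes map to the same output size...
    print(conv2d_out(7), conv2d_out(8))   # 4 4

    # ...so naively inverting the formula cannot tell them apart:
    # (out - 1) * stride - 2 * padding + kernel gives 7 in both cases,
    # one short of the original 8.
    print((conv2d_out(8) - 1) * 2 - 2 * 1 + 3)   # 7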

Can you navigate to the specific location pointed to in the traceback (/home/goshea/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/lib.rs:1529, or whatever location is pointed to by the latest version), and tell me the name of the variable passed and what function it's being passed to?

I apologize, I'm not sure how to get the variable info, but these are the two functions that break when using the two backends.

LibTorch Backend
tensor.rs, line 535

    /// Copies values from the argument tensor to the input tensor.
    pub fn copy_(&mut self, src: &Tensor) {
        self.f_copy_(src).unwrap()
    }

NdArray Backend
lib.rs line 1529

/// Private Methods
impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    #[inline]
    fn broadcast_unwrap<E>(&self, dim: E) -> ArrayView<'_, A, E>
    where
        E: Dimension,
    {
        #[cold]
        #[inline(never)]
        fn broadcast_panic<D, E>(from: &D, to: &E) -> !
        where
            D: Dimension,
            E: Dimension,
        {
            panic!(
                "ndarray: could not broadcast array from shape: {:?} to: {:?}",
                from.slice(),
                to.slice()
            )
        }

        match self.broadcast(dim.clone()) {
            Some(it) => it,
            None => broadcast_panic(&self.dim, &dim),
        }
    }
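
For context on the panic itself: ndarray follows NumPy-style broadcasting rules, where a dimension can only be stretched if it has size 1, so a shape of [1, 1, 4, 4] can never be broadcast to [1, 1, 3, 3]. The same failure can be reproduced with NumPy (illustration only, not the actual code path in burn):

    import numpy as np

    a = np.zeros((1, 1, 4, 4))
    # Broadcasting can only expand size-1 dimensions, so 4 -> 3 is invalid
    # and this raises a ValueError, mirroring the ndarray panic above.
    np.broadcast_to(a, (1, 1, 3, 3))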

> I apologize, I'm not sure how to get the variable info, but these are the two functions that break when using the two backends.

You're good. I was hoping that traceback pointed to something in the generated model (so we could figure out what step in burn-import needs work), but that's not the case here.

Sounds like it's a bug in the backward pass of one of Burn's ops.

CCing @nathanielsimard, @louisfd, and @laggui; maybe they have some idea.

> Cargo successfully builds and it is a runtime error.
>
> I just tried reproducing and found that Resnet18 works while Mobilenet and efficientnet have issues.
>
> I'm using images of size 224, but I've seen similar issues before when using awkward input sizes. Specifically, this happens when, at a given layer, an odd number of channels is halved (note the error above: [1, 1, 4, 4] to [1, 1, 3, 3]), which could be related to the size of a previous layer (1, 1, 7, 7). In this case, the forward pass works fine for all models, but the backward pass might be encountering a similar issue.

Could you share one of the ONNX models so we can try to reproduce this issue?

@laggui GitHub doesn't allow me to drop it here; where do I share the model?

> We don’t support that file type.
>
> Try again with GIF, JPEG, JPG, MOV, MP4, PNG, SVG, WEBM, CPUPROFILE, CSV, DMP, DOCX, FODG, FODP, FODS, FODT, GZ, JSON, JSONC, LOG, MD, ODF, ODG, ODP, ODS, ODT, PATCH, PDF, PPTX, TGZ, TXT, XLS, XLSX or ZIP.

You could upload it somewhere (e.g., Google Drive) and share the link.

Or, if it's a torchvision model, you could share the script you used to generate the ONNX model with PyTorch.

/edit: ah right, as pointed out below, GitHub supports the zip format, so you can zip the ONNX file to upload it here.

You need to zip it

mobilenet_v2.zip

Here's the zip; you'll find the function for creating the model above.

I can reproduce the issue with the provided ONNX model on both the ndarray and torch backends.

ndarray: could not broadcast array from shape: [1, 1, 4, 4] to: [1, 1, 3, 3]

The issue happens in the conv2d backward with groups, specifically this line.
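
For what it's worth, this lines up with which models fail: MobileNetV2 and EfficientNetV2 rely heavily on depthwise convolutions (grouped convs with groups equal to the channel count), while ResNet18 uses no grouped convolutions at all. A minimal PyTorch sketch of the kind of operation whose backward is involved (illustrative only, not the burn code path; the channel count and sizes are made up):

    import torch
    import torch.nn as nn

    # Depthwise conv as used in MobileNetV2: groups == in_channels,
    # stride 2 over an odd spatial size.
    conv = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1, groups=32)

    x = torch.randn(1, 32, 7, 7, requires_grad=True)
    y = conv(x)            # forward output: [1, 32, 4, 4]
    y.sum().backward()     # the input gradient must come back as [1, 32, 7, 7]
    print(x.grad.shape)    # torch.Size([1, 32, 7, 7])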

Found the bug! Thanks a lot for filing the issue. Fixed with PR #1891.