tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust, with extreme flexibility, compute efficiency, and portability as its primary goals.

Home Page: https://burn.dev


The `model.load_record()` method turns off the activation function during the forward pass

syunar opened this issue

Describe the bug
The model.load_record() method turns off the activation function during the forward pass.

To Reproduce

  1. Initialize the model with the ReLU activation function.
  2. Perform a forward pass with input data before loading weights.
  3. Load the model weights using model.load_record(record).
  4. Perform a forward pass with input data after loading weights.
  5. Observe that the activation function is not applied as expected after loading the weights.
use burn::{
    backend::NdArray, 
    module::Module, 
    nn::{
        conv::{Conv2d, Conv2dConfig}, 
        BatchNorm, BatchNormConfig, PaddingConfig2d, Relu
    }, 
    tensor::{backend::Backend, Device, Tensor, Distribution, Shape},
    record::{FullPrecisionSettings, Recorder}
};
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};


#[derive(Module, Debug)]
pub struct RustCustomModel<B: Backend> {
    conv: Conv2d<B>,
    bn: BatchNorm<B, 2>,
    relu: Option<Relu>,
    activation: bool
}

impl<B: Backend> RustCustomModel<B> {
    pub fn new(activation: bool, device: &Device<B>) -> Self {

        let conv: Conv2d<B> = Conv2dConfig::new([3, 64], [1, 1])
            .with_stride([1, 1])
            .with_padding(PaddingConfig2d::Explicit(1, 1))
            .with_bias(false)
            .init(device);
        let bn: BatchNorm<B, 2> = BatchNormConfig::new(64).init(device);

        let relu: Option<Relu> = if activation { Some(Relu::new()) } else { None };
        println!("init relu: {:?}", relu);
        Self {
            conv,
            bn,
            relu,
            activation
        }
    }

    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {

        let out: Tensor<B, 4> = self.conv.forward(input);
        let out: Tensor<B, 4> = self.bn.forward(out);
        println!("self.relu: {:?}", self.relu);
        if let Some(relu) = &self.relu {
            println!("with activation");
            relu.forward(out)
        } else {
            println!("without activation");
            out
        }
    }
}

fn main() {
    let device = Default::default();
    
    let model = RustCustomModel::<NdArray>::new(true, &device);
    
    let input_shape = Shape::new([1, 64, 56, 56]);
    let input_tensor = Tensor::<NdArray, 4>::random(input_shape, Distribution::Default, &device);
    
    println!("\n### before load weight ### ");
    println!("input shape: {:?}", input_tensor.shape());
    let output_tensor = model.forward(input_tensor.clone());
    println!("output shape: {:?}", output_tensor.shape());
    
    println!("\n### after load weight ###");
    let load_args = LoadArgs::new("model.pt".into());
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load(load_args, &device)
        .expect("Should load PyTorch model weights");
    let model = model.load_record(record);

    println!("input shape: {:?}", input_tensor.shape());
    let output_tensor = model.forward(input_tensor.clone());
    println!("output shape: {:?}", output_tensor.shape());

}

Expected behavior
The activation function (ReLU) should be correctly applied during the forward pass both before and after loading the model weights.

Actual behavior
Before loading the weights, the forward pass correctly applies the ReLU activation function. After loading the weights, the ReLU activation function is set to None, resulting in the forward pass running without the activation function.

// output log
init relu: Some(Relu)

### before load weight ### 
input shape: Shape { dims: [1, 64, 56, 56] }
+ self.relu: Some(Relu)
+ with activation
output shape: Shape { dims: [1, 64, 58, 58] }

### after load weight ###
input shape: Shape { dims: [1, 64, 56, 56] }
- self.relu: None
- without activation
output shape: Shape { dims: [1, 64, 58, 58] }

Desktop (please complete the following information):

  • OS: Ubuntu 24.04
  • rustc 1.78.0 (9b00956e5 2024-04-29)
[dependencies]
burn = { version = "0.13.2", features = ["ndarray"] }
burn-import = { version = "0.13.2" }

Hmm, well ReLU has no parameters, so it is not saved with the weights. That explains why it gets initialized to its default value None when the state is loaded from the saved record. For an optional layer that does have parameters this will work. In the meantime, you can manually set the activation yourself after loading the weights, as in the sketch below.
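A minimal sketch of that workaround, assuming the relu field is reachable from the call site (it is in the reproduction above, where the struct and main live in the same file); all names reuse the reproduction code:

// Hypothetical workaround sketch: restore the constant module by hand after loading.
// `model` and `record` are the values from the reproduction's main().
let mut model = model.load_record(record);
model.relu = Some(Relu::new()); // re-enable the activation that the record reset to None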

@nathanielsimard how do we want to handle this? 🤔

While trying to reproduce the bug, it seems you are loading a record in which the relu module was saved as None. Even though relu is a constant module, Some(ConstantModule) is supposed to be present in the saved file; otherwise, relu would not be included. The question is: do we want to handle optional constants differently? Maybe. I made a PR to illustrate the behavior.
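For context, a sketch of the round trip being described, assuming Burn's NamedMpkFileRecorder (the choice of recorder and the "checkpoint" path are illustrative, not from the issue); the model and device are the ones from the reproduction above:

use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};

// Save the model's own record, then load it back. The Option<Relu> constant is
// serialized as Some(Relu) in Burn's native format, so the activation survives.
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
recorder
    .record(model.clone().into_record(), "checkpoint".into())
    .expect("Should save the record");
let record = recorder
    .load("checkpoint".into(), &device)
    .expect("Should load the record");
let model = model.load_record(record); // model.relu is still Some(Relu)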