mlverse / luz

Higher Level API for torch

Home Page: https://mlverse.github.io/luz/


Error in setup() with nn_sequential()

skeydan opened this issue · comments

This works:

net <- nn_module(
  initialize = function(d_in, d_hidden, d_out) {
    self$linear1 <- nn_linear(d_in, d_hidden)
    self$relu <- nn_relu()
    self$linear2 <- nn_linear(d_hidden, d_out)                   
  },
  forward = function(x) {
    x %>%
      self$linear1() %>%
      self$relu() %>%
      self$linear2()
  }
)

net %>%
  setup(loss = nn_mse_loss(), optimizer = optim_adam)

This throws an error:

net <- nn_sequential(
  nn_linear(d_in, d_hidden),
  nn_relu(),
  nn_linear(d_hidden, d_out)
)

net %>%
  setup(loss = nn_mse_loss(), optimizer = optim_adam)

Error in x$get_inherit() : attempt to apply non-function 
4.
get_forward(x) at utils.R#46
3.
has_forward_method(module) at module.R#48
2.
setup(., loss = nn_mse_loss(), optimizer = optim_adam) 
1.
net %>% setup(loss = nn_mse_loss(), optimizer = optim_adam) 

This is expected: luz requires a module generator that has not yet been initialized, and nn_sequential() returns an already initialized module.
You can do something like:

net <- nn_module(
  initialize = function() {
    self$net <- nn_sequential(
      nn_linear(d_in, d_hidden),
      nn_relu(),
      nn_linear(d_hidden, d_out)
    )
  },
  forward = function(x) { 
    self$net(x)
  }
)
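For completeness, here is a hedged sketch of how the wrapped module can then be passed to luz. The dimensions and the dataloader name `train_dl` are placeholders, not from the thread:

```r
library(torch)
library(luz)

# Assumed dimensions for illustration only.
d_in <- 10; d_hidden <- 32; d_out <- 1

net <- nn_module(
  initialize = function() {
    self$net <- nn_sequential(
      nn_linear(d_in, d_hidden),
      nn_relu(),
      nn_linear(d_hidden, d_out)
    )
  },
  forward = function(x) {
    self$net(x)
  }
)

# setup() receives the uninitialized module generator, not an instance.
fitted <- net %>%
  setup(loss = nn_mse_loss(), optimizer = optim_adam) %>%
  fit(train_dl, epochs = 10)  # train_dl: a torch dataloader (assumed)
```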

That makes total sense!

For posterity, adding another situation where this error occurs.

Consider an nn_module that has a hyperparameter num_classes, i.e. its definition contains the line initialize = function(num_classes) { ... }.

then,

model <- net(num_classes = 10) |> 
     setup(loss = nn_cross_entropy_loss(), optimizer = optim_adam)

throws the error,

Error in x$get_inherit() : attempt to apply non-function

whereas,

model <- net |> 
     setup(loss = nn_cross_entropy_loss(), optimizer = optim_adam) |> 
     set_hparams(num_classes = 10)

works.
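Putting the working pattern together as a minimal sketch (the module body and the input size 784 are assumptions for illustration; `train_dl` is a placeholder dataloader):

```r
library(torch)
library(luz)

# A module generator with a constructor hyperparameter.
net <- nn_module(
  initialize = function(num_classes) {
    self$fc <- nn_linear(784, num_classes)
  },
  forward = function(x) {
    self$fc(x)
  }
)

# Do NOT call net(num_classes = 10) yourself -- that produces an
# initialized module, which setup() rejects. Instead, pass the generator
# to setup() and supply constructor arguments via set_hparams().
model <- net |>
  setup(loss = nn_cross_entropy_loss(), optimizer = optim_adam) |>
  set_hparams(num_classes = 10)

# fitted <- model |> fit(train_dl, epochs = 5)
```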