mlverse / torchvision

R interface to torchvision

Home page: https://torchvision.mlverse.org

Using prebuilt/-trained models with luz

jemus42 opened this issue:

I wanted to use a prebuilt architecture with luz, either pretrained or "fresh".
In luz::setup() I ran into an error, which I traced to a line in luz that checks for a forward() method in the module.

library(torch)
library(torchvision)
library(luz)

# Getting a prebuilt model (pretrained = FALSE, so the weights are randomly initialized)
model <- model_alexnet(pretrained = FALSE, num_classes = 10)

model %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam
  )
#> Error in x$get_inherit(): attempt to apply non-function

I then tried to figure out the difference between the pre-made models and self-defined torch::nn_module() models, which I had used successfully with luz before. So I extracted the code for AlexNet and defined it manually, which works fine:

# Defining alexnet manually:
# https://github.com/mlverse/torchvision/blob/main/R/models-alexnet.R#L2-L37
# Should be identical to calling torchvision:::alexnet
alexnet <- torch::nn_module(
  "AlexNet",
  initialize = function(num_classes = 1000) {
    self$features <- torch::nn_sequential(
      torch::nn_conv2d(3, 64, kernel_size = 11, stride = 4, padding = 2),
      torch::nn_relu(inplace = TRUE),
      torch::nn_max_pool2d(kernel_size = 3, stride = 2),
      torch::nn_conv2d(64, 192, kernel_size = 5, padding = 2),
      torch::nn_relu(inplace = TRUE),
      torch::nn_max_pool2d(kernel_size = 3, stride = 2),
      torch::nn_conv2d(192, 384, kernel_size = 3, padding = 1),
      torch::nn_relu(inplace = TRUE),
      torch::nn_conv2d(384, 256, kernel_size = 3, padding = 1),
      torch::nn_relu(inplace = TRUE),
      torch::nn_conv2d(256, 256, kernel_size = 3, padding = 1),
      torch::nn_relu(inplace = TRUE),
      torch::nn_max_pool2d(kernel_size = 3, stride = 2)
    )
    self$avgpool <- torch::nn_adaptive_avg_pool2d(c(6,6))
    self$classifier <- torch::nn_sequential(
      torch::nn_dropout(),
      torch::nn_linear(256 * 6 * 6, 4096),
      torch::nn_relu(inplace = TRUE),
      torch::nn_dropout(),
      torch::nn_linear(4096, 4096),
      torch::nn_relu(inplace = TRUE),
      torch::nn_linear(4096, num_classes)
    )
  },
  forward = function(x) {
    x <- self$features(x)
    x <- self$avgpool(x)
    x <- torch_flatten(x, start_dim = 2)
    x <- self$classifier(x)
  }
)

alexnet %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam
  )
#> <luz_module_generator>

The only difference I can point to is the classes of each model:

# Why is model_alexnet() missing a class?
class(alexnet)
#> [1] "AlexNet"             "nn_module"           "nn_module_generator"
class(torchvision:::alexnet)
#> [1] "AlexNet"             "nn_module"           "nn_module_generator"
class(torchvision::model_alexnet(pretrained = FALSE))
#> [1] "AlexNet"   "nn_module"
class(torchvision::model_alexnet(pretrained = TRUE))
#> [1] "AlexNet"   "nn_module"

Given the source code of model_alexnet(), I don't understand why model_alexnet(pretrained = FALSE) and torchvision:::alexnet should differ at all, since the function just returns torchvision:::alexnet if pretrained = FALSE.

I am not sure whether this is an issue with the model setup in torchvision or with luz::setup(), but I thought I'd start here.

Session info
# Session info ------------------------------------------------------------
sessioninfo::session_info(pkgs = c("torch", "torchvision", "luz"))
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.1.2 (2021-11-01)
#>  os       Ubuntu 20.04.3 LTS
#>  system   x86_64, linux-gnu
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/Berlin
#>  date     2022-02-07
#>  pandoc   2.14.0.3 @ /usr/lib/rstudio-server/bin/pandoc/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version date (UTC) lib source
#>  abind         1.4-5   2016-07-21 [1] RSPM (R 4.1.0)
#>  bit           4.0.4   2020-08-04 [1] RSPM (R 4.1.0)
#>  bit64         4.0.5   2020-08-30 [1] RSPM (R 4.1.0)
#>  callr         3.7.0   2021-04-20 [1] RSPM (R 4.1.0)
#>  cli           3.1.1   2022-01-20 [1] RSPM (R 4.1.0)
#>  coro          1.0.2   2021-12-03 [1] RSPM (R 4.1.0)
#>  crayon        1.4.2   2021-10-29 [1] RSPM (R 4.1.0)
#>  ellipsis      0.3.2   2021-04-29 [1] RSPM (R 4.1.0)
#>  fs            1.5.2   2021-12-08 [1] RSPM (R 4.1.0)
#>  generics      0.1.2   2022-01-31 [1] RSPM (R 4.1.0)
#>  glue          1.6.1   2022-01-22 [1] CRAN (R 4.1.1)
#>  hms           1.1.1   2021-09-26 [1] RSPM (R 4.1.0)
#>  jpeg          0.1-9   2021-07-24 [1] CRAN (R 4.1.1)
#>  lifecycle     1.0.1   2021-09-24 [1] RSPM (R 4.1.0)
#>  luz         * 0.2.0   2021-10-07 [1] RSPM (R 4.1.0)
#>  magrittr      2.0.2   2022-01-26 [1] RSPM (R 4.1.0)
#>  pkgconfig     2.0.3   2019-09-22 [1] RSPM (R 4.1.0)
#>  png           0.1-7   2013-12-03 [1] RSPM (R 4.1.0)
#>  prettyunits   1.1.1   2020-01-24 [1] RSPM (R 4.1.0)
#>  processx      3.5.2   2021-04-30 [1] RSPM (R 4.1.0)
#>  progress      1.2.2   2019-05-16 [1] RSPM (R 4.1.0)
#>  ps            1.6.0   2021-02-28 [1] RSPM (R 4.1.0)
#>  purrr         0.3.4   2020-04-17 [1] RSPM (R 4.1.0)
#>  R6            2.5.1   2021-08-19 [1] CRAN (R 4.1.1)
#>  rappdirs      0.3.3   2021-01-31 [1] RSPM (R 4.1.0)
#>  Rcpp          1.0.8   2022-01-13 [1] CRAN (R 4.1.1)
#>  rlang         1.0.1   2022-02-03 [1] CRAN (R 4.1.2)
#>  torch       * 0.6.0   2021-10-07 [1] RSPM (R 4.1.0)
#>  torchvision * 0.4.1   2022-01-28 [1] RSPM (R 4.1.0)
#>  vctrs         0.3.8   2021-04-29 [1] RSPM (R 4.1.0)
#>  withr         2.4.3   2021-11-30 [1] CRAN (R 4.1.1)
#>  zeallot       0.1.0   2018-01-28 [1] RSPM (R 4.1.0)
#> 
#>  [1] /home/burk/R/x86_64-pc-linux-gnu-library/4.1
#>  [2] /opt/R/4.1.2/lib/R/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────
Created on 2022-02-07 by the [reprex package](https://reprex.tidyverse.org) (v2.0.1)

You can find an example of using a pre-trained model with luz here: https://mlverse.github.io/luz/articles/examples/dogs-vs-casts-binary-classification.html

net <- torch::nn_module(
  initialize = function(num_classes) {
    # Load the pretrained AlexNet inside initialize(), so every call of the
    # generator builds its own fresh copy of the model
    self$model <- model_alexnet(pretrained = TRUE)
  },
  forward = function(x) {
    # Keep only the first output column
    self$model(x)[,1]
  }
)

luz's setup(), fit(), etc. only work with module generators, i.e. the objects returned by torch::nn_module(). Calling a generator, as in torch::nn_linear(10, 10), creates an nn_module instance. This ensures that luz always owns the initialization of its module, so it can safely perform in-place operations, such as moving parameters to a specific device or modifying parameter values, without affecting objects in the global environment.
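A minimal sketch of the generator/instance distinction, using torch's built-in nn_linear generator. The object names (gen, layer_a, layer_b, alias_a) are illustrative only. It shows that each call of a generator yields an independent module, while an existing instance is shared by reference, so an in-place change made through one name is visible through every other name pointing at the same module. That is the kind of side effect luz avoids by creating the module itself.

```r
library(torch)

# nn_linear is a module generator, like the objects returned by nn_module()
gen <- nn_linear

# Each call of the generator creates an independent nn_module instance
layer_a <- gen(10, 10)
layer_b <- gen(10, 10)

# Assigning an instance does not copy it: alias_a refers to the same module
alias_a <- layer_a

# In-place modification through the alias (outside autograd tracking)
with_no_grad({
  alias_a$weight$zero_()
})

# layer_a shares its parameters with alias_a, so it sees the change ...
cat("layer_a |weight| sum:", layer_a$weight$abs()$sum()$item(), "\n")
# ... while layer_b, created by a separate generator call, is untouched
cat("layer_b |weight| sum:", layer_b$weight$abs()$sum()$item(), "\n")
```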

We could make luz work with nn_modules (as opposed to nn_module_generators) to make code less verbose. In general, though, when using pretrained models like those from torchvision, you will need to modify the model head, freeze parameters, etc., and thus need a custom torch::nn_module anyway.
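As a sketch of such a custom module, the generator below wraps torchvision's AlexNet, freezes its parameters, and replaces the last classifier layer, all inside initialize(). This is an illustration under assumptions, not code from luz or torchvision: alexnet_finetune is a hypothetical name, and it assumes the final nn_linear of AlexNet's classifier is the submodule named `6`. The generator itself is then passed to luz, with constructor arguments supplied through set_hparams().

```r
library(torch)
library(torchvision)
library(luz)

# Hypothetical generator: wraps torchvision's AlexNet so that luz owns its
# initialization (weights are loaded and modified inside initialize()).
alexnet_finetune <- nn_module(
  "AlexNetFinetune",
  initialize = function(num_classes = 10, pretrained = TRUE) {
    self$model <- model_alexnet(pretrained = pretrained)
    # Freeze every parameter of the prebuilt network
    for (par in self$model$parameters) {
      par$requires_grad_(FALSE)
    }
    # Assumption: classifier$`6` is the final nn_linear. Replace it with a
    # new, trainable head producing `num_classes` outputs.
    self$model$classifier$`6` <- nn_linear(
      in_features = self$model$classifier$`6`$in_features,
      out_features = num_classes
    )
  },
  forward = function(x) {
    self$model(x)
  }
)

# Pass the generator (not an instance) to luz; hyperparameters for
# initialize() go through set_hparams(), and luz instantiates the module later
spec <- alexnet_finetune %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam
  ) %>%
  set_hparams(num_classes = 10, pretrained = TRUE)
```

Keeping all of the model surgery inside initialize() means every instantiation starts from a freshly built copy, rather than from a module that was modified in the global environment.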

Ah, thanks a lot, that clears it up.

For using a pretrained model I had so far resorted to doing something like

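# Freeze all parameters of the pretrained model
for (par in model$parameters) {
  par$requires_grad_(FALSE)
}

# Replace the last classifier layer with a new layer that has 10 outputs
model$classifier$`6` <- torch::nn_linear(
  in_features = model$classifier$`6`$in_features,
  out_features = 10
)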

but I see how wrapping a pretrained model in a new nn_module is preferred.