mlverse / luz

Higher Level API for torch

Home Page: https://mlverse.github.io/luz/

Rerunning the same 'setup' code a second time results in an error

Backlonhw2468 opened this issue · comments

Rerunning the 'setup' part of the code in Chapter 16.2.1 of the book Deep Learning and Scientific Computing with R torch a second time leads to the error "element 0 of tensors does not require grad and does not have a grad_fn". Please help.

```r
convnet <- nn_module(
  "convnet",
  initialize = function() {
    # nn_conv2d(in_channels, out_channels, kernel_size, stride)
    self$conv1 <- nn_conv2d(1, 32, 3, 1)
    self$conv2 <- nn_conv2d(32, 64, 3, 2)
    self$conv3 <- nn_conv2d(64, 128, 3, 1)
    self$conv4 <- nn_conv2d(128, 256, 3, 2)
    self$conv5 <- nn_conv2d(256, 10, 3, 2)
  },
  forward = function(x) {
    x %>%
      self$conv1() %>%
      nnf_relu() %>%
      self$conv2() %>%
      nnf_relu() %>%
      self$conv3() %>%
      nnf_relu() %>%
      self$conv4() %>%
      nnf_relu() %>%
      self$conv5() %>%
      torch_squeeze()
  }
)

fitted <- convnet %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_accuracy()
    )
  ) %>%
  fit(train_dl, epochs = 5, valid_data = valid_dl)
```
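
(For context, `train_dl` and `valid_dl` are built earlier in the chapter and are assumed to exist by the snippet above. A minimal sketch of a compatible setup, with MNIST via torchvision and a batch size chosen here purely as assumptions:)

```r
library(torch)
library(torchvision)
library(luz)

# Grayscale 28x28 images with 10 classes match the convnet above.
train_ds <- mnist_dataset(
  root = "data", train = TRUE, download = TRUE,
  transform = transform_to_tensor
)
valid_ds <- mnist_dataset(
  root = "data", train = FALSE,
  transform = transform_to_tensor
)

train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
valid_dl <- dataloader(valid_ds, batch_size = 128)
```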

The error occurs if the fitting process was manually stopped during an epoch.

Thanks for reporting!

Which version of luz are you running?

luz version 0.4.0;
torch version 0.12.0.

The same problem occurs when running the code in Chapter 14.2.3, as long as I manually stop the fitting process and then rerun the same code.
I guess it might have something to do with a memory leak in `fit()`, though I'm not sure.
Currently the only way I can get the code running again is to close the whole RStudio session and restart it, which is fine for relatively small datasets. Once the dataset gets big, that workaround becomes problematic, since preprocessing the data before feeding it to the network is time consuming.

This is probably a bug in how we reset the grad mode after an interrupt.
A workaround might be to run `torch::autograd_set_grad_mode(TRUE)` after the interrupt.
I haven't been able to reproduce it yet. Can you post your full sessionInfo()?
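
(To illustrate the behavior described above: if an interrupt leaves grad mode disabled, any loss computed afterwards carries no grad_fn, so calling backward() on it raises exactly the reported error. A minimal sketch outside of luz; the tensor names are illustrative only:)

```r
library(torch)

# Simulate the state an interrupt may leave behind: grad mode disabled.
autograd_set_grad_mode(FALSE)

w <- torch_randn(2, 2, requires_grad = TRUE)
loss <- (w * 2)$sum()
loss$requires_grad   # FALSE: no graph was recorded while grad mode was off
# loss$backward()    # errors: "element 0 of tensors does not require grad
#                    #          and does not have a grad_fn"

# The suggested workaround: re-enable grad mode before calling fit() again.
autograd_set_grad_mode(TRUE)
```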

`torch::autograd_set_grad_mode(TRUE)` does solve the problem. Thanks!

Session Info:

```
R version 4.1.3 (2022-03-10)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows Server x64 (build 14393)

Matrix products: default

attached base packages:
[1] stats graphics grDevices utils datasets methods base

other attached packages:
[1] torchvision_0.5.1 luz_0.4.0 torch_0.12.0

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.9 rstudioapi_0.13 magrittr_2.0.3 hms_1.1.1 progress_1.2.2 bit_4.0.4 R6_2.5.1 rlang_1.1.3 tools_4.1.3 coro_1.0.3 cli_3.6.2 withr_2.5.0
[13] ellipsis_0.3.2 remotes_2.5.0 bit64_4.0.5 lifecycle_1.0.4 crayon_1.5.1 processx_3.8.3 purrr_0.3.4 callr_3.7.3 vctrs_0.6.1 fs_1.5.2 ps_1.7.5 zeallot_0.1.0
[25] glue_1.6.2 compiler_4.1.3 generics_0.1.3 prettyunits_1.1.1 pkgconfig_2.0.3
```