Cannot seem to replicate tf$keras$optimizers interface in R tensorflow
njtierney opened this issue
Hi there!
I'm trying to follow the examples from https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD, because we're trying to get something similar working with TF in greta and keep running into the same error.
I'm trying to translate this Python code into the R code below:
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
var = tf.Variable(1.0)
loss = lambda: (var ** 2)/2.0 # d(loss)/d(var1) = var1
step_count = opt.minimize(loss, [var]).numpy()
# Step is `- learning_rate * grad`
var.numpy()
And now as R code:
library(tensorflow)
opt <- tf$keras$optimizers$SGD(learning_rate=0.1)
#> Loaded Tensorflow version 2.10.0
var <- tf$Variable(1.0)
loss <- function(var) (var ** 2)/2.0 # d(loss)/d(var1) = var1
step_count <- opt$minimize(loss, var)$numpy()
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError: Evaluation error: argument "var" is missing, with no default.
step_count <- opt$minimize(loss, list(var = var))$numpy()
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError: Evaluation error: argument "var" is missing, with no default.
step_count <- opt$minimize(loss, list(var))$numpy()
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError: Evaluation error: argument "var" is missing, with no default.
step_count <- opt$minimize(loss, c(var))$numpy()
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError: Evaluation error: argument "var" is missing, with no default.
Created on 2022-10-24 with reprex v2.0.2
Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.2.1 (2022-06-23)
#> os macOS Monterey 12.3.1
#> system aarch64, darwin20
#> ui X11
#> language (EN)
#> collate en_AU.UTF-8
#> ctype en_AU.UTF-8
#> tz Australia/Perth
#> date 2022-10-24
#> pandoc 2.19.2 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> base64enc 0.1-3 2015-07-28 [1] CRAN (R 4.2.0)
#> cli 3.4.1 2022-09-23 [1] CRAN (R 4.2.0)
#> digest 0.6.30 2022-10-18 [1] CRAN (R 4.2.0)
#> evaluate 0.17 2022-10-07 [1] CRAN (R 4.2.0)
#> fansi 1.0.3 2022-03-24 [1] CRAN (R 4.2.0)
#> fastmap 1.1.0 2021-01-25 [1] CRAN (R 4.2.0)
#> fs 1.5.2 2021-12-08 [1] CRAN (R 4.2.0)
#> glue 1.6.2 2022-02-24 [1] CRAN (R 4.2.0)
#> here 1.0.1 2020-12-13 [1] CRAN (R 4.2.0)
#> highr 0.9 2021-04-16 [1] CRAN (R 4.2.0)
#> htmltools 0.5.3 2022-07-18 [1] CRAN (R 4.2.0)
#> jsonlite 1.8.2 2022-10-02 [1] CRAN (R 4.2.0)
#> knitr 1.40 2022-08-24 [1] CRAN (R 4.2.0)
#> lattice 0.20-45 2021-09-22 [1] CRAN (R 4.2.1)
#> lifecycle 1.0.3 2022-10-07 [1] CRAN (R 4.2.0)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.2.0)
#> Matrix 1.5-1 2022-09-13 [1] CRAN (R 4.2.0)
#> pillar 1.8.1 2022-08-19 [1] CRAN (R 4.2.0)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.2.0)
#> png 0.1-7 2013-12-03 [1] CRAN (R 4.2.0)
#> purrr 0.3.5 2022-10-06 [1] CRAN (R 4.2.0)
#> R.cache 0.16.0 2022-07-21 [1] CRAN (R 4.2.0)
#> R.methodsS3 1.8.2 2022-06-13 [1] CRAN (R 4.2.0)
#> R.oo 1.25.0 2022-06-12 [1] CRAN (R 4.2.0)
#> R.utils 2.12.0 2022-06-28 [1] CRAN (R 4.2.0)
#> Rcpp 1.0.9 2022-07-08 [1] CRAN (R 4.2.0)
#> reprex 2.0.2 2022-08-17 [1] CRAN (R 4.2.0)
#> reticulate 1.26 2022-08-31 [1] CRAN (R 4.2.0)
#> rlang 1.0.6 2022-09-24 [1] CRAN (R 4.2.0)
#> rmarkdown 2.17 2022-10-07 [1] CRAN (R 4.2.0)
#> rprojroot 2.0.3 2022-04-02 [1] CRAN (R 4.2.0)
#> rstudioapi 0.14 2022-08-22 [1] CRAN (R 4.2.0)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.2.0)
#> stringi 1.7.8 2022-07-11 [1] CRAN (R 4.2.0)
#> stringr 1.4.1 2022-08-20 [1] CRAN (R 4.2.0)
#> styler 1.7.0 2022-03-13 [1] CRAN (R 4.2.0)
#> tensorflow * 2.9.0 2022-05-21 [1] CRAN (R 4.2.0)
#> tfruns 1.5.1 2022-09-05 [1] CRAN (R 4.2.0)
#> tibble 3.1.8 2022-07-22 [1] CRAN (R 4.2.0)
#> utf8 1.2.2 2021-07-24 [1] CRAN (R 4.2.0)
#> vctrs 0.4.2 2022-09-29 [1] CRAN (R 4.2.0)
#> whisker 0.4 2019-08-28 [1] CRAN (R 4.2.0)
#> withr 2.5.0 2022-03-03 [1] CRAN (R 4.2.0)
#> xfun 0.33 2022-09-12 [1] CRAN (R 4.2.0)
#> yaml 2.3.5 2022-02-21 [1] CRAN (R 4.2.0)
#>
#> [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#>
#> ─ Python configuration ───────────────────────────────────────────────────────
#> python: /Users/nick/Library/r-miniconda-arm64/envs/r-reticulate/bin/python
#> libpython: /Users/nick/Library/r-miniconda-arm64/envs/r-reticulate/lib/libpython3.8.dylib
#> pythonhome: /Users/nick/Library/r-miniconda-arm64/envs/r-reticulate:/Users/nick/Library/r-miniconda-arm64/envs/r-reticulate
#> version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) [Clang 12.0.1 ]
#> numpy: /Users/nick/Library/r-miniconda-arm64/envs/r-reticulate/lib/python3.8/site-packages/numpy
#> numpy_version: 1.23.2
#>
#> ──────────────────────────────────────────────────────────────────────────────
Happy to help 😄
library(tensorflow)
opt <- tf$keras$optimizers$SGD(learning_rate = 0.1)
var <- tf$Variable(1.0)
loss <- function() (var^2)/2
step_count <- opt$minimize(loss, list(var))$numpy()
var$numpy()
#> [1] 0.9
Created on 2022-10-24 with reprex v2.0.2
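For reference, the 0.9 printed above is consistent with a single plain SGD step, using the derivative noted in the Python snippet's comment. This is a worked sketch that assumes no momentum, which is the SGD default:

```latex
\begin{aligned}
\frac{d\,\mathrm{loss}}{d\,\mathrm{var}}
  &= \frac{d}{d\,\mathrm{var}}\left(\frac{\mathrm{var}^2}{2}\right)
   = \mathrm{var} = 1.0 \\
\mathrm{var}_{\text{new}}
  &= \mathrm{var} - \text{learning\_rate} \cdot \frac{d\,\mathrm{loss}}{d\,\mathrm{var}}
   = 1.0 - 0.1 \times 1.0 = 0.9
\end{aligned}
```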
The opt$minimize() method expects a loss function that takes no arguments.
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Optimizer#minimize
If a callable, loss should take no arguments and return the value to minimize
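To make that rule concrete, here is a minimal base-R sketch (no TensorFlow needed) of why a one-argument loss fails and how a closure avoids it. Both `call_with_no_args()` and `make_loss()` are hypothetical helpers introduced only for illustration. The first merely imitates a caller that invokes the loss with no arguments; it is not part of tensorflow or reticulate.

```r
# Hypothetical stand-in for a caller that invokes the loss with no arguments,
# as the quoted documentation describes for a callable loss.
call_with_no_args <- function(f) f()

# A loss that declares a parameter fails when it is called with no arguments.
loss_with_arg <- function(var) (var^2) / 2
msg <- tryCatch(
  call_with_no_args(loss_with_arg),
  error = function(e) conditionMessage(e)
)
print(msg)

# Hypothetical helper: build a zero-argument loss that captures `v` in a closure.
make_loss <- function(v) {
  force(v)  # evaluate the argument now so the closure holds the value
  function() (v^2) / 2
}

loss <- make_loss(3)
result <- call_with_no_args(loss)
print(result)
```

The first call reproduces the "argument "var" is missing" message from the original post, and the closure version evaluates without arguments. With tensorflow, the same pattern should presumably carry over as `opt$minimize(make_loss(var), list(var))`, since the closure still refers to the same `tf$Variable`, though that variant was not run in this thread.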
Thank you SO much! That was doing my head in. That brings us one step closer to solving the problem we are working on :)