rstudio / tensorflow

TensorFlow for R

Home Page: https://tensorflow.rstudio.com

Cannot seem to replicate tf$keras$optimizer interface in R tensorflow

njtierney opened this issue

Hi there!

I'm trying to follow the example from https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD because we're trying to get something working with TF in greta and keep running into the same error.

I'm trying to translate this Python code into R:

opt = tf.keras.optimizers.SGD(learning_rate=0.1)
var = tf.Variable(1.0)
loss = lambda: (var ** 2)/2.0         # d(loss)/d(var) = var
step_count = opt.minimize(loss, [var]).numpy()
# Step is `- learning_rate * grad`
var.numpy()

And now as R code:

library(tensorflow)
opt <- tf$keras$optimizers$SGD(learning_rate=0.1)
#> Loaded Tensorflow version 2.10.0
var <- tf$Variable(1.0)
loss <- function(var) (var ** 2)/2.0         # d(loss)/d(var1) = var1
step_count <- opt$minimize(loss, var)$numpy()
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError: Evaluation error: argument "var" is missing, with no default.
step_count <- opt$minimize(loss, list(var = var))$numpy()
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError: Evaluation error: argument "var" is missing, with no default.
step_count <- opt$minimize(loss, list(var))$numpy()
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError: Evaluation error: argument "var" is missing, with no default.
step_count <- opt$minimize(loss, c(var))$numpy()
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError: Evaluation error: argument "var" is missing, with no default.

Created on 2022-10-24 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-10-24
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version date (UTC) lib source
#>  base64enc     0.1-3   2015-07-28 [1] CRAN (R 4.2.0)
#>  cli           3.4.1   2022-09-23 [1] CRAN (R 4.2.0)
#>  digest        0.6.30  2022-10-18 [1] CRAN (R 4.2.0)
#>  evaluate      0.17    2022-10-07 [1] CRAN (R 4.2.0)
#>  fansi         1.0.3   2022-03-24 [1] CRAN (R 4.2.0)
#>  fastmap       1.1.0   2021-01-25 [1] CRAN (R 4.2.0)
#>  fs            1.5.2   2021-12-08 [1] CRAN (R 4.2.0)
#>  glue          1.6.2   2022-02-24 [1] CRAN (R 4.2.0)
#>  here          1.0.1   2020-12-13 [1] CRAN (R 4.2.0)
#>  highr         0.9     2021-04-16 [1] CRAN (R 4.2.0)
#>  htmltools     0.5.3   2022-07-18 [1] CRAN (R 4.2.0)
#>  jsonlite      1.8.2   2022-10-02 [1] CRAN (R 4.2.0)
#>  knitr         1.40    2022-08-24 [1] CRAN (R 4.2.0)
#>  lattice       0.20-45 2021-09-22 [1] CRAN (R 4.2.1)
#>  lifecycle     1.0.3   2022-10-07 [1] CRAN (R 4.2.0)
#>  magrittr      2.0.3   2022-03-30 [1] CRAN (R 4.2.0)
#>  Matrix        1.5-1   2022-09-13 [1] CRAN (R 4.2.0)
#>  pillar        1.8.1   2022-08-19 [1] CRAN (R 4.2.0)
#>  pkgconfig     2.0.3   2019-09-22 [1] CRAN (R 4.2.0)
#>  png           0.1-7   2013-12-03 [1] CRAN (R 4.2.0)
#>  purrr         0.3.5   2022-10-06 [1] CRAN (R 4.2.0)
#>  R.cache       0.16.0  2022-07-21 [1] CRAN (R 4.2.0)
#>  R.methodsS3   1.8.2   2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo          1.25.0  2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils       2.12.0  2022-06-28 [1] CRAN (R 4.2.0)
#>  Rcpp          1.0.9   2022-07-08 [1] CRAN (R 4.2.0)
#>  reprex        2.0.2   2022-08-17 [1] CRAN (R 4.2.0)
#>  reticulate    1.26    2022-08-31 [1] CRAN (R 4.2.0)
#>  rlang         1.0.6   2022-09-24 [1] CRAN (R 4.2.0)
#>  rmarkdown     2.17    2022-10-07 [1] CRAN (R 4.2.0)
#>  rprojroot     2.0.3   2022-04-02 [1] CRAN (R 4.2.0)
#>  rstudioapi    0.14    2022-08-22 [1] CRAN (R 4.2.0)
#>  sessioninfo   1.2.2   2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi       1.7.8   2022-07-11 [1] CRAN (R 4.2.0)
#>  stringr       1.4.1   2022-08-20 [1] CRAN (R 4.2.0)
#>  styler        1.7.0   2022-03-13 [1] CRAN (R 4.2.0)
#>  tensorflow  * 2.9.0   2022-05-21 [1] CRAN (R 4.2.0)
#>  tfruns        1.5.1   2022-09-05 [1] CRAN (R 4.2.0)
#>  tibble        3.1.8   2022-07-22 [1] CRAN (R 4.2.0)
#>  utf8          1.2.2   2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs         0.4.2   2022-09-29 [1] CRAN (R 4.2.0)
#>  whisker       0.4     2019-08-28 [1] CRAN (R 4.2.0)
#>  withr         2.5.0   2022-03-03 [1] CRAN (R 4.2.0)
#>  xfun          0.33    2022-09-12 [1] CRAN (R 4.2.0)
#>  yaml          2.3.5   2022-02-21 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/r-reticulate/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/r-reticulate/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/r-reticulate:/Users/nick/Library/r-miniconda-arm64/envs/r-reticulate
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16)  [Clang 12.0.1 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/r-reticulate/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.23.2
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Happy to help 😄

library(tensorflow)

opt <- tf$keras$optimizers$SGD(learning_rate = 0.1)
var <- tf$Variable(1.0)
# The loss takes no arguments; it reads `var` from the enclosing environment
loss <- function() (var^2)/2
step_count <- opt$minimize(loss, list(var))$numpy()
var$numpy()
#> [1] 0.9

Created on 2022-10-24 with reprex v2.0.2

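The opt$minimize() method expects a loss function that takes no arguments, so the loss has to pick up var from the enclosing environment instead of receiving it as a parameter. From the documentation:
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Optimizer#minimize

"If a callable, loss should take no arguments and return the value to minimize."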
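
To make the zero-argument contract concrete, here is a minimal pure-Python sketch. It is not TensorFlow's implementation: `Var` and `toy_minimize` are hypothetical stand-ins invented for illustration, and the gradient is a plain finite difference rather than autodiff. It only shows why a loss that closes over the variable works while a loss that declares a parameter fails at call time.

```python
class Var:
    """Hypothetical stand-in for tf.Variable: a mutable box holding one float."""

    def __init__(self, value):
        self.value = float(value)


def toy_minimize(loss, var_list, learning_rate=0.1, eps=1e-6):
    """Take one SGD step on each variable (illustrative, not TF's code).

    `loss` is called with NO arguments, mirroring the documented contract,
    so it must read the variables from its enclosing scope.
    """
    for v in var_list:
        original = v.value
        # Central finite difference in place of automatic differentiation.
        v.value = original + eps
        loss_up = loss()
        v.value = original - eps
        loss_down = loss()
        grad = (loss_up - loss_down) / (2 * eps)
        # SGD update: the step is -learning_rate * grad.
        v.value = original - learning_rate * grad


var = Var(1.0)
loss = lambda: (var.value ** 2) / 2.0  # closes over `var`, takes no arguments
toy_minimize(loss, [var])
print(round(var.value, 6))  # 0.9

# A loss that declares a parameter breaks, because it is called with none.
bad_loss = lambda var: (var.value ** 2) / 2.0
try:
    toy_minimize(bad_loss, [Var(1.0)])
except TypeError as err:
    print("TypeError:", err)
```

The 0.9 agrees with the hand calculation: the gradient of var^2/2 at var = 1 is 1, so one step gives 1.0 - 0.1 * 1.0. The failing call is the same situation as the R error about argument "var" being missing: the optimizer invokes the loss with no arguments.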

Thank you SO much! That was doing my head in. That brings us one step closer to solving the problem we are working on :)