mlverse / torch

R Interface to Torch

Home Page: https://torch.mlverse.org

dataset subset introducing NAs by coercion

dominicwhite opened this issue

When I use the dataset_subset() function, I get the following error when later training the model with the fit() function:

Warning: NAs introduced by coercion to integer range
Error in (function (self, target, weight, reduction, ignore_index, label_smoothing) :
  Evaluation error: missing replacement values are not allowed.

Reproducible example:

library(torch)
library(torchvision)
library(luz)

train_ds <- kmnist_dataset(
  "imagesk", 
  download = TRUE,
  transform = . %>%
    transform_to_tensor() %>%
    torch_flatten()
  )

train_ds <- dataset_subset(train_ds, indices=1:1000)
valid_ds <- dataset_subset(train_ds, indices=1001:1500)

train_dl <- dataloader(
  train_ds, 
  batch_size = 32,
  shuffle = TRUE
  )

valid_dl <- dataloader(
  valid_ds,
  batch_size = 32
  )

net <- nn_module(
  "onelayer",
  initialize = function() {
    self$net <- nn_sequential(
      nn_linear(784,128),
      nn_relu(),
      nn_linear(128,10)
    )
  },
  forward = function(x) {
    self$net(x)
  }
)

model1 <- net %>%
  setup(
    loss = nn_cross_entropy_loss(), 
    optimizer = optim_adam, 
    metrics = list(
      luz_metric_accuracy()
    )
  )

fitted1 <- fit(
  model1,
  train_dl,
  epochs = 2,
  valid_data = valid_dl,
  verbose = TRUE
)

I have found three separate "solutions" that each seem to fix this issue and allow the model to train without that error:

  • Removing the dataset_subset() lines.
  • Removing the valid_data = valid_dl argument from the fit() function.
  • Changing the index ranges for the train and validation sets so that the validation set starts with lower indices, for example:
    train_ds <- dataset_subset(train_ds, indices=1001:2024)
    valid_ds <- dataset_subset(train_ds, indices=1:512)

However, I'm not sure why the original code doesn't work. Why would switching the subset indices (my third solution) fix this?
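One possible explanation, not confirmed anywhere in this thread: in the example, train_ds is overwritten with a 1000-item subset before valid_ds is created, so indices 1001:1500 are requested from an object that only has 1000 items. The sketch below uses plain integer vectors as stand-ins for the datasets (no torch needed). It assumes that dataset_subset() looks items up by ordinary R indexing into the dataset it wraps, and that the full training set has 60000 items. Both are assumptions made for illustration.

```r
# Plain integer vectors stand in for the datasets; this is not torch code.
full_idx  <- 1:60000           # assumed size of the full training set
train_idx <- full_idx[1:1000]  # like train_ds after the first dataset_subset()

# The second subset is taken from the already subsetted object,
# so positions 1001:1500 do not exist and R returns NA for each of them.
valid_idx <- train_idx[1001:1500]
all(is.na(valid_idx))          # TRUE

# Third workaround from the post: the validation indices now fall
# inside the 1024-item training subset, so no NA appears.
train_idx2 <- full_idx[1001:2024]
valid_idx2 <- train_idx2[1:512]
anyNA(valid_idx2)              # FALSE
range(valid_idx2)              # 1001 1512
```

If this is indeed what happens, the third workaround trains without the error, but its validation items are drawn from the training subset itself. Creating both subsets from the original dataset under separate variable names would avoid both problems.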

Sorry @dominicwhite for taking so long to look at this issue.
I tried running your reproducible example using the dev version of torch and luz and could not reproduce.
I feel like this could be related to something like #961