haanjack / mnist-cudnn

CUDA for MNIST training/inference


Did not use all the data (pictures) during training or inference

linlll opened this issue

Hi,
Thanks very much for your project; I learned a lot from it. But it seems that the project does not use all of the data (pictures) during training or inference. I modified the code slightly to keep track of the indices of all the pictures used:

  1. I defined a public variable in class MNIST: public: std::vector<int> idx_store;

  2. I stored the index of every picture used (in function void MNIST::get_batch()):

    for (int i = 0; i < batch_size_; i++) {
        std::copy(data_pool_[data_idx + i].data(),
            &data_pool_[data_idx + i].data()[data_size],
            &data_->ptr()[data_size * i]);
        idx_store.push_back(data_idx + i); // added
    }
  3. After training, I searched for an index greater than 500, but found none:

    while (step < num_steps_train) {
        /* training... */
    }
    
    auto idx = train_data_loader.idx_store;
    for (int i = 0; i < idx.size(); i++) {
        if (idx[i] > 500) {
            std::cout << idx[i] << std::endl; // debug here
        }
    }

Then I figured out what the problem was: the following code in function void MNIST::get_batch()

int data_idx = (step_ * batch_size_) % num_steps_;

This code limits data_idx to the range [0, num_steps_), but it should span [0, 60000) for training (or [0, 10000) for the test set). So it only needs to be modified like this:

int data_idx = step_ % num_steps_ * batch_size_;

After this modification, here is the result running on my machine:

[INFERENCE]
loading ./dataset/t10k-images.idx3-ubyte
loaded 10000 items..
conv1: Available Algorithm Count [FWD]: 10
conv1: Available Algorithm Count [BWD-filter]: 9
conv1: Available Algorithm Count [BWD-data]: 8
conv2: Available Algorithm Count [FWD]: 10
conv2: Available Algorithm Count [BWD-filter]: 9
conv2: Available Algorithm Count [BWD-data]: 8
loss: 0.145, accuracy: 90.050%
Done.

Thanks for your detailed review and contribution.

I'm closing this issue as I've merged your PR :)