Some questions about training and testing the shift model
ruizewang opened this issue
Hello, I have some questions about training and testing that are bothering me.
- Are the parameters you used to train the shift model the default parameters in the code you provided?
- You said: "In my experiments, the model took something like 2K iterations to reach chance performance (loss:label = 0.693), and 11K iterations to do better than chance (loss:label = 0.692). So, for a long time it looked like the model was stuck at chance."
  So my question is: once the model does better than chance, will "loss:label" then decrease faster?
- About testing, I mean computing the accuracy on the test set as described in your paper. The paper says: "We found that the model obtained 59.9% accuracy on held-out videos for its alignment task (chance = 50%)." My question is: should the parameter `do_shift` be set to `False` or `True`? When I set it to `True`, the accuracy is 0.50633484; when I set it to `False`, I get an accuracy of 0.43133482. Both are quite different from the 0.599 reported in your paper. By the way, I use the same code to read the dataset, and I use the pre-trained model you provided. The dataset is generated from AudioSet.
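For reference on the loss values quoted above: 0.693 is ln 2, the binary cross-entropy of a classifier that puts probability 0.5 on the correct label for every example, so a mean loss of 0.692 means the model is only marginally better than chance. A minimal sketch of that arithmetic, assuming `loss:label` is a mean binary cross-entropy measured in nats (the helper names are hypothetical):

```python
import math


def bce(p_correct: float) -> float:
    """Per-example binary cross-entropy (in nats), given the probability on the true label."""
    return -math.log(p_correct)


def implied_geo_mean_p(mean_loss: float) -> float:
    """Geometric mean of the probability on the true label implied by a mean cross-entropy."""
    return math.exp(-mean_loss)


# A classifier at chance puts p = 0.5 on every label, which gives ln 2.
chance_loss = bce(0.5)
print("loss at chance:", round(chance_loss, 4))  # 0.6931

# A mean loss slightly below ln 2 implies a probability slightly above 0.5.
print("p implied by loss 0.692:", round(implied_geo_mean_p(0.692), 4))
```

Because the loss is logarithmic, the drop from 0.693 to 0.692 corresponds to only a tiny gain in the average probability assigned to the correct label.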
Here is my testing code. I only added one function to `class NetClf` in `shift_net.py`:
```python
def test_accuracy(self, reset=True):
    gpus = mu.set_gpus(self.gpu)
    print('Loading Model')
    if self.sess is None:
        print 'Running on:', gpus
        with tf.device(gpus[0]):
            if reset:
                tf.reset_default_graph()
                tf.Graph().as_default()
            pr = self.pr
            pr_test = pr.copy()
            self.sess = tf.Session()
            pr_test.augment_ims = False
            print 'pr_test ='
            print pr_test
            print('loading dataset...')
            with tf.device('/cpu:0'):
                rec_files = shift_dset.rec_files_from_path(pr_test.test_list)
                total_examples = len(rec_files)*8841
                total_batch = int(total_examples/pr.test_batch)
                print('the number of total examples:', total_examples)
                print('the number of total batch:', total_batch)
            self.test_ims, self.test_samples = mu.on_cpu(
                lambda: shift_dset.make_db_reader(
                    pr_test.test_list, pr_test, pr.test_batch, ['im', 'samples'], one_pass=True))
            print 'sample shape:', shape(self.test_samples)  # [10, 87587, 2]
            if pr_test.do_shift:
                print('do shifting...')
                # Random 0/1 label per example; label 1 selects samples[:, 1], label 0 selects samples[:, 0]
                self.test_labels = tf.random_uniform([shape(self.test_ims, 0)], 0, 2, dtype=tf.int64)
                self.test_samples = tf.where(tf.equal(self.test_labels, 1), self.test_samples[:, 1],
                                             self.test_samples[:, 0])
            else:
                self.test_labels = tf.ones(shape(self.test_ims, 0), dtype=tf.int64)
                # self.test_samples = tf.where(tf.equal(self.test_labels, 1), self.test_samples[:, 1], self.test_samples[:, 0])
            print('make net')
            self.test_net = make_net(self.test_ims, self.test_samples, pr_test, reuse=False, train=False)
            self.coord = tf.train.Coordinator()
            self.init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
            self.sess.run(self.init_op)
            tf.train.Saver().restore(self.sess, self.model_path)
            tf.get_default_graph().finalize()
    print('Start testing...')
    tf.train.start_queue_runners(self.sess, coord=self.coord)
    self.total_acc = []
    i = 0
    try:
        while not self.coord.should_stop():
            start = ut.now_sec()
            predict_logits = self.sess.run(self.test_net.logits)
            predict_logits = np.squeeze(predict_logits)
            # A positive logit is treated as predicted label 1
            predict_labels = np.array(predict_logits > 0).astype(np.int64)
            labels = self.sess.run(self.test_labels)
            correct_list = (predict_labels == labels)
            acc = np.mean(np.array(correct_list).astype(np.float32))
            self.total_acc.append(acc)
            i += 1
            print 'Iter: %d/%d, Accuracy: %s, time: %.3f' % (i, total_batch, acc, ut.now_sec() - start)
    except tf.errors.OutOfRangeError:
        # Raised by the one-pass reader once the test set is exhausted
        print('Test Done!')
    return np.mean(np.array(self.total_acc))
```
And I run the test like this:
```python
import shift_net, shift_params, numpy as np
import time

pr = shift_params.shift_v1()
model_file = '../results/nets/shift/net.tf-650000'
gpu = '3'
start_time = time.time()
clf = shift_net.NetClf(pr, model_file, gpu=gpu)
accuracy = clf.test_accuracy()
end_time = time.time()
print('pr.test_list:', pr.test_list)
print('model_file:', model_file)
print('pr.do_shift:', pr.do_shift)
print('accuracy:', accuracy)
print('cost time: {} s'.format(end_time - start_time))
```
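One thing that may be worth checking in `test_accuracy` (a hypothesis, not a confirmed cause): the logits and the labels are fetched in two separate `self.sess.run(...)` calls. In TF 1.x each `run` call re-executes stateful ops such as `tf.random_uniform` (and, if the fetched tensor depends on the input queue, dequeues a new batch), so with `do_shift=True` the labels being compared may be a fresh random draw rather than the ones used to choose between `self.test_samples[:, 1]` and `self.test_samples[:, 0]`. Fetching both in one call, e.g. `logits, labels = self.sess.run([self.test_net.logits, self.test_labels])`, keeps them paired. The NumPy sketch below imitates the effect without TensorFlow; the arrays, the 3-sample shift, the oracle classifier, and the assumption that index 1 holds the shifted audio are all made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
BATCH, N = 2000, 16

# Hypothetical stand-in for one batch of audio.
# Assumption: index 0 = aligned audio, index 1 = time-shifted audio.
aligned = rng.normal(size=(BATCH, N))
shifted = np.roll(aligned, 3, axis=1)
samples = np.stack([aligned, shifted], axis=1)  # [BATCH, 2, N]


def draw_labels():
    """Random 0/1 labels, like tf.random_uniform([batch], 0, 2, dtype=tf.int64)."""
    return rng.integers(0, 2, size=BATCH)


def select(samples, labels):
    """Row-wise selection, like tf.where(labels == 1, samples[:, 1], samples[:, 0])."""
    return np.where(labels[:, None] == 1, samples[:, 1], samples[:, 0])


def oracle_predict(chosen):
    """Hypothetical perfect classifier: predicts 1 when the row is not the aligned audio."""
    return (~np.all(chosen == aligned, axis=1)).astype(np.int64)


labels = draw_labels()
pred = oracle_predict(select(samples, labels))

# Paired: labels come from the same draw that built the inputs.
acc_paired = np.mean(pred == labels)

# Unpaired: labels redrawn in a separate step, as with a second sess.run call.
acc_unpaired = np.mean(pred == draw_labels())

print("paired accuracy:", acc_paired)      # 1.0
print("unpaired accuracy:", acc_unpaired)  # close to 0.5
```

Even a perfect classifier scores about 0.5 once the labels are redrawn independently of the inputs, so a result near chance under this setup would not by itself say much about the model.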