mit-han-lab / temporal-shift-module

[ICCV 2019] TSM: Temporal Shift Module for Efficient Video Understanding

Home Page: https://arxiv.org/abs/1811.08383

When training on UCF101, the return value of `__getitem__` does not match what the dataloader returns

dengfenglai321 opened this issue

Hi, while training on UCF101 I found that the values returned by the `__getitem__` function do not match the values returned by the dataloader.

  1. I print the return values inside `__getitem__` as follows:
    ```python
    def __getitem__(self, index):
        record = self.video_list[index]
        # check this is a legit video folder

        if self.image_tmpl == 'flow_{}_{:05d}.jpg':
            file_name = self.image_tmpl.format('x', 1)
            full_path = os.path.join(self.root_path, record.path, file_name)
        elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
            file_name = self.image_tmpl.format(int(record.path), 'x', 1)
            full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
        else:
            file_name = self.image_tmpl.format(1)
            full_path = os.path.join(self.root_path, record.path, file_name)

        # if the first frame is missing, fall back to a randomly chosen video
        while not os.path.exists(full_path):
            print('################## Not Found:', os.path.join(self.root_path, record.path, file_name))
            index = np.random.randint(len(self.video_list))
            record = self.video_list[index]
            if self.image_tmpl == 'flow_{}_{:05d}.jpg':
                file_name = self.image_tmpl.format('x', 1)
                full_path = os.path.join(self.root_path, record.path, file_name)
            elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
                file_name = self.image_tmpl.format(int(record.path), 'x', 1)
                full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
            else:
                file_name = self.image_tmpl.format(1)
                full_path = os.path.join(self.root_path, record.path, file_name)

        if not self.test_mode:
            segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
        else:
            segment_indices = self._get_test_indices(record)
        # print('record : {}'.format(record))
        print('segment_indices : {}'.format(segment_indices))
        data, label = self.get(record, segment_indices)
        print('data : {}'.format(data.size()))
        print('label : {}'.format(label))
        # NOTE: self.get() is called a second time here, so the returned
        # sample comes from this call, not from the one printed above
        return self.get(record, segment_indices)
    ```
    
  2. I print the values returned by the dataloader as follows:

    ```python
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        target = target.cuda()
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)
        print('\n')
        print('input_var : {}'.format(input_var.size()))
        print('target_var : {}'.format(target_var))
    ```

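  3. The actual output is shown below:
    (screenshot: 企业微信截图_16431896458396)
    I have checked this many times, and the two sets of return values never match. What could be the cause?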
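For anyone trying to reproduce the comparison, here is a minimal pure-Python sketch of how a batching loader groups per-sample `__getitem__` results. It is only an imitation under stated assumptions, not PyTorch's implementation and not the code from this repository: `ToyVideoDataset` and `iterate_batches` are hypothetical names, and the fake data is chosen so each sample can be traced back to its index.

```python
import random


class ToyVideoDataset:
    """Hypothetical stand-in for the video dataset: one (data, label) pair per index."""

    def __init__(self, num_videos):
        self.num_videos = num_videos

    def __len__(self):
        return self.num_videos

    def __getitem__(self, index):
        data = [float(index)] * 3  # fake "frames", identifiable by index
        label = index % 2          # fake class label
        return data, label


def iterate_batches(dataset, batch_size, shuffle, seed=0):
    """Simplified imitation of a batching loader (an assumption, not torch code).

    Yields (batch_data, batch_labels), each grouping up to batch_size samples.
    """
    indices = list(range(len(dataset)))
    if shuffle:
        # with shuffling, samples are no longer visited in index order
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        samples = [dataset[i] for i in indices[start:start + batch_size]]
        batch_data = [data for data, _ in samples]
        batch_labels = [label for _, label in samples]
        yield batch_data, batch_labels


dataset = ToyVideoDataset(num_videos=6)

# Without shuffling, the first batch holds samples 0 and 1 in index order.
ordered = list(iterate_batches(dataset, batch_size=2, shuffle=False))
first_data, first_labels = ordered[0]
print(len(first_data), first_labels)  # prints: 2 [0, 1]
```

The point of the sketch is that one loader iteration corresponds to several `__getitem__` calls, so a single print inside `__getitem__` describes one sample while a print in the training loop describes a whole batch. In PyTorch, `torch.utils.data.DataLoader` additionally stacks tensors along a new leading batch dimension, and with `shuffle=True` or `num_workers > 0` the order of prints inside `__getitem__` generally will not line up with the order in which batches are consumed. Setting `shuffle=False`, `num_workers=0` and `batch_size=1` while debugging may make the two sets of prints easier to compare.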