mindspore-lab / mindocr

A toolbox of OCR models, algorithms, and pipelines based on MindSpore

Home Page:https://mindspore-lab.github.io/mindocr/


RuntimeError: The 'getitem' operation does not support the type [None, Int64].

sevennotmouse opened this issue · comments

2023 Ascend AI Innovation Contest, Algorithm Innovation track: VisionLAN model migration and reproduction
Background: We are migrating train_LF_1.py from the PyTorch source code to MindSpore. Following the MindSpore training paradigm, we define the model, dataset, optimizer, and loss function separately, connect them with a WithLossCell, instantiate the Model class, and train with model.train.


Problem encountered: training raised the following error, concerning the getitem operation:
RuntimeError: The 'getitem' operation does not support the type [None, Int64].The supported types of overload function getitem is: [Tuple, Slice], [List, Slice], [Tensor, Ellipsis], [Tuple, Tensor], [List, Number], [Tensor, Slice], [Dictionary, String], [Tensor, Tensor], [String, Number], [Tensor, Tuple], [Tensor, None], [Tuple, Number], [Tensor, Number], [Tensor, List].
Note: debugging environment: a notebook on the Huawei Cloud ModelArts platform; image: mindspore_1.10.0-cann_6.0.1-py_3.7-euler_2.8.3; flavor: Ascend: 1*Ascend910 | ARM: 24 cores, 96 GB.


The getitem in question is the __getitem__ method defined in our custom lmdbDataset class:

class lmdbDataset():
   def __init__(xx):
		xxx
   def __fromwhich__(xx):
		xxx
   def keepratio_resize(xx):
		xxx
   def __len__(self):
        return self.nSamples
   def __getitem__(self, index):
        fromwhich = self.__fromwhich__()
        if self.global_state == 'Train':
            index = random.randint(0,self.maxlen - 1)
        index = index % self.lengths[fromwhich]
        assert index <= len(self), 'index range error'
        index += 1
        with self.envs[fromwhich].begin(write=False) as txn:
            img_key = 'image-%09d' % index
            try:
                imgbuf = txn.get(img_key.encode())
                buf = six.BytesIO()
                buf.write(imgbuf)
                buf.seek(0)
                img = Image.open(buf).convert('RGB')
            except Exception:
                print('Corrupted image for %d' % index)
                return self[index + 1]
            label_key = 'label-%09d' % index
            # label = str(txn.get(label_key.encode()))
            # if python3
            label = str(txn.get(label_key.encode()), 'utf-8')
            label = re.sub('[^0-9a-zA-Z]+', '', label)
            
            if (len(label) > 25 or len(label) <= 0) and self.global_state == 'Train':
                print(len(label))
                print(label)
                print('sample too long')
                print(self.global_state)
                return self[index + 1]
            
            img = self.keepratio_resize(img, self.global_state)
            if self.transform:
                img = self.transform(img)
            # generate masked_id masked_character remain_string
            label_res, label_sub, label_id =  des_orderlabel(label)
            sample = {'image': img, 'label': label, 'label_res': label_res, 'label_sub': label_sub, 'label_id': label_id}
            #return sample
            return (img, label, label_res, label_sub, label_id)  # return a tuple
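Before chasing the error into the network, it may be worth ruling out the dataset itself. The sketch below is a hypothetical, framework-free sanity check (all names are placeholders, not part of the original code): GeneratorDataset expects every column returned by __getitem__ to be a concrete value, never None, and a None column is one way an unsupported "[None, Int64]" getitem could surface downstream.

```python
# Hedged sanity check (hypothetical helper): verify that __getitem__ returns
# a tuple and that no column in it is None.
def check_sample(dataset, index=0):
    sample = dataset[index]
    assert isinstance(sample, tuple), "expected __getitem__ to return a tuple"
    for i, col in enumerate(sample):
        assert col is not None, f"column {i} of sample {index} is None"
    return sample

# Minimal stub standing in for the real lmdbDataset, just to show the check:
class StubDataset:
    def __len__(self):
        return 1
    def __getitem__(self, index):
        img = [[0.0, 0.0]]                  # placeholder "image"
        return (img, "abc", "abc", "a", 1)  # image, label, label_res, label_sub, label_id

sample = check_sample(StubDataset())
```

Running such a check over a few indices of the real dataset would quickly confirm whether the None originates in the data pipeline or elsewhere.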

We load the dataset with a load_dataset function, which does three things: instantiate the lmdbDataset class, wrap it in MindSpore's GeneratorDataset, and batch it by batch size:

def load_dataset():
    # instantiate the lmdbDataset class
    train_data_set = cfgs.dataset_cfgs['dataset_train'](**cfgs.dataset_cfgs['dataset_train_args'])
    # i.e. train_data_set = lmdbDataset(roots=['./datasets/train/SynthText', './datasets/train/MJSynth'],
    #                                   img_height=64, img_width=256,
    #                                   transform=dataset.transforms.Compose([vision.ToTensor()]),
    #                                   global_state='Train')
    # wrap it in GeneratorDataset
    train_loader = ds.GeneratorDataset(train_data_set, column_names=["image", "label", "label_res", "label_sub", "label_id"],
                                       num_parallel_workers=32, shuffle=True)
    # split into batches
    train_loader = train_loader.batch(batch_size=384)
    
    test_data_set = cfgs.dataset_cfgs['dataset_test'](**cfgs.dataset_cfgs['dataset_test_args'])
    test_loader = ds.GeneratorDataset(test_data_set, column_names=["image","label"],num_parallel_workers=16,shuffle=False)
    test_loader = test_loader.batch(batch_size=64)
    
    return train_data_set, train_loader, test_data_set, test_loader

We then load the training set and pass it to model.train:

    # load the training and test sets
    train_data_set, train_loader, test_data_set, test_loader = load_dataset()

    # define the multi-label loss function
    loss = VisionLAN_Loss()     # custom
    # loss = nn.SoftmaxCrossEntropyWithLogits()  # equivalent

    # define the loss network, connecting the forward network and the multi-label loss
    loss_net = CustomWithLossCell(net, loss)

    # define the Model; in the multi-label case no loss function needs to be passed to Model
    model = Model(network=loss_net, optimizer=optimizer)

    # train the model
    model.train(epoch=8, train_dataset=train_loader, callbacks=[LossMonitor()])

We are not sure whether the problem lies in how the __getitem__ method itself is written, or whether some caller is passing in an unexpected data type; in particular, we are puzzled about where the None in [None, Int64] comes from. We would be grateful for any help from the experts and peers here, and we look forward to your reply. Thank you very much!

@sevennotmouse
Hello, and thank you for the report. Sorry for the late reply; have you already resolved the issue?

The supported types of overload function getitem is: [Tuple, Slice], [List, Slice], [Tensor, Ellipsis], [Tuple, Tensor], [List, Number], [Tensor, Slice], [Dictionary, String], [Tensor, Tensor], [String, Number], [Tensor, Tuple], [Tensor, None], [Tuple, Number], [Tensor, Number], [Tensor, List]. RuntimeError: The 'getitem' operation does not support the type [None, Int64].

Judging from the error log excerpt you provided, this problem is most likely not introduced by the __getitem__ method of your custom lmdbDataset class, but by the getitem operator receiving an unsupported input type. Our guess is that some code in your network construction uses syntax that does not behave the way you expect. We suggest that you:

  1. Run in MindSpore's PyNative mode and single-step through the code to locate where the problem is introduced:
import mindspore as ms
ms.set_context(mode=ms.PYNATIVE_MODE)
  2. Alternatively, unit-test the network-construction code to locate where the problem is introduced.
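As a framework-agnostic illustration of where such a None can come from (hypothetical code, not taken from the issue): a code path that falls through without a return implicitly returns None, and indexing that result is precisely a getitem on [None, Int64]. In plain Python the same mistake fails like this:

```python
# Hypothetical example: a branch that forgets to return a value hands back
# None; indexing that None is what a compiled graph reports as
# "getitem does not support [None, Int64]".
def make_features(x, training):
    if training:
        return [v * 2 for v in x]
    # no return on this branch -> the function implicitly returns None

feats = make_features([1, 2, 3], training=False)  # feats is None
try:
    first = feats[0]  # indexing None fails
except TypeError as err:
    msg = str(err)    # "'NoneType' object is not subscriptable"
```

Stepping through in PyNative mode makes this kind of mistake visible at the exact line where the None is produced, rather than at graph-compile time.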

In addition, if you suspect a problem in your custom dataset class, you can unit-test lmdbDataset to confirm that it behaves as expected. The MindOCR project also implements several custom dataset classes (including an LMDB dataset similar to yours) whose code you can use as a reference.
We also recommend trying the newer MindSpore r2.2.11. Since the behavior of some APIs may change when MindSpore versions are upgraded, please consult the technical documentation on the official website.