RuntimeError: The 'getitem' operation does not support the type [None, Int64].
sevennotmouse opened this issue
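2023 Ascend AI Innovation Contest - Algorithm Innovation - VisionLAN model migration and reproduction
Background: We are migrating train_LF_1.py from the PyTorch source code to MindSpore. Following MindSpore's training paradigm, we define the model, dataset, optimizer, and loss function separately, construct a WithLossCell, instantiate the Model class, and train with model.train.
Problem: Training fails with the following error, which concerns the getitem method: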
```
RuntimeError: The 'getitem' operation does not support the type [None, Int64].The supported types of overload function getitem is: [Tuple, Slice], [List, Slice], [Tensor, Ellipsis], [Tuple, Tensor], [List, Number], [Tensor, Slice], [Dictionary, String], [Tensor, Tensor], [String, Number], [Tensor, Tuple], [Tensor, None], [Tuple, Number], [Tensor, Number], [Tensor, List].
```
Note: debugging environment: a notebook on the Huawei Cloud ModelArts platform; image: mindspore_1.10.0-cann_6.0.1-py_3.7-euler_2.8.3; flavor: Ascend: 1*Ascend910 | ARM: 24 cores, 96 GB.
getitem is a method defined in our custom lmdbDataset class:
```python
import random
import re

import six
from PIL import Image


class lmdbDataset():
    def __init__(xx):
        xxx

    def __fromwhich__(xx):
        xxx

    def keepratio_resize(xx):
        xxx

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        fromwhich = self.__fromwhich__()
        if self.global_state == 'Train':
            index = random.randint(0, self.maxlen - 1)
        index = index % self.lengths[fromwhich]
        assert index <= len(self), 'index range error'
        index += 1
        with self.envs[fromwhich].begin(write=False) as txn:
            img_key = 'image-%09d' % index
            try:
                imgbuf = txn.get(img_key.encode())
                buf = six.BytesIO()
                buf.write(imgbuf)
                buf.seek(0)
                img = Image.open(buf).convert('RGB')
            except Exception:
                print('Corrupted image for %d' % index)
                return self[index + 1]
            label_key = 'label-%09d' % index
            # label = str(txn.get(label_key.encode()))
            # if python3
            label = str(txn.get(label_key.encode()), 'utf-8')
            label = re.sub('[^0-9a-zA-Z]+', '', label)
            if (len(label) > 25 or len(label) <= 0) and self.global_state == 'Train':
                print(len(label))
                print(label)
                print('sample too long')
                print(self.global_state)
                return self[index + 1]
            img = self.keepratio_resize(img, self.global_state)
            if self.transform:
                img = self.transform(img)
            # generate masked_id masked_character remain_string
            label_res, label_sub, label_id = des_orderlabel(label)
            sample = {'image': img, 'label': label, 'label_res': label_res, 'label_sub': label_sub, 'label_id': label_id}
            # return sample
            return (img, label, label_res, label_sub, label_id)  # return a tuple
```
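We load the datasets with a load_dataset function, which performs three steps: instantiating the lmdbDataset class, wrapping it with MindSpore's GeneratorDataset, and splitting it into batches of batch_size: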
```python
import mindspore.dataset as ds


def load_dataset():
    # Instantiate the lmdbDataset class
    train_data_set = cfgs.dataset_cfgs['dataset_train'](**cfgs.dataset_cfgs['dataset_train_args'])
    # i.e. train_data_set = lmdbDataset(roots=['./datasets/train/SynthText', './datasets/train/MJSynth',],
    #     img_height=64, img_width=256, transform=dataset.transforms.Compose([vision.ToTensor()]), global_state='Train')
    # Load the dataset with GeneratorDataset
    train_loader = ds.GeneratorDataset(train_data_set, column_names=["image", "label", "label_res", "label_sub", "label_id"],
                                       num_parallel_workers=32, shuffle=True)
    # Split into batches of batch_size
    train_loader = train_loader.batch(batch_size=384)
    test_data_set = cfgs.dataset_cfgs['dataset_test'](**cfgs.dataset_cfgs['dataset_test_args'])
    test_loader = ds.GeneratorDataset(test_data_set, column_names=["image", "label"], num_parallel_workers=16, shuffle=False)
    test_loader = test_loader.batch(batch_size=64)
    return train_data_set, train_loader, test_data_set, test_loader
```
We then load the training set and pass it to model.train:
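```python
# Load the training and test sets
train_data_set, train_loader, test_data_set, test_loader = load_dataset()
# Define the multi-label loss function
loss = VisionLAN_Loss()  # custom
# loss = nn.SoftmaxCrossEntropyWithLogits()  # equivalent
# Define the loss network, connecting the forward network and the multi-label loss function
loss_net = CustomWithLossCell(net, loss)
# Define the Model; in the multi-label case, Model does not need a loss function
model = Model(network=loss_net, optimizer=optimizer)
# Train the model
model.train(epoch=8, train_dataset=train_loader, callbacks=[LossMonitor()])
```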
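We are not sure whether the problem lies in how the getitem method itself is written or whether an unexpected data type is passed in wherever it is called. In particular, we are puzzled about where the None in [None, Int64] comes from. We would be grateful for any help from experts and fellow developers. Thank you very much!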
@sevennotmouse
Hello, thank you for your feedback. Sorry for the late reply; have you already resolved the issue above?
```
The supported types of overload function getitem is: [Tuple, Slice], [List, Slice], [Tensor, Ellipsis], [Tuple, Tensor], [List, Number], [Tensor, Slice], [Dictionary, String], [Tensor, Tensor], [String, Number], [Tensor, Tuple], [Tensor, None], [Tuple, Number], [Tensor, Number], [Tensor, List]. RuntimeError: The 'getitem' operation does not support the type [None, Int64].
```
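Judging from the part of the error log you provided, the problem is probably not introduced by the __getitem__ method of the custom dataset class lmdbDataset. Rather, the getitem operator is receiving a data type it does not support. We suspect that some syntax in your network code does not behave as expected. We suggest the following: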
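- Use MindSpore's PyNative mode to debug step by step and locate the code that introduces the problem:
  ```python
  import mindspore as ms
  ms.set_context(mode=ms.PYNATIVE_MODE)
  ```
- Alternatively, unit-test the network code to locate the code that introduces the problem.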
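The type pair [None, Int64] indicates that a value which is None is being indexed with an integer, as in x[0] where x is None. The plain-Python sketch below (no MindSpore required) reproduces this class of mistake under an assumption: a hypothetical Backbone whose forward method forgets to return its output on the training path, and a hypothetical WithLoss wrapper that indexes that output. It is one possible cause for illustration, not a diagnosis of the code in this issue.

```python
class Backbone:
    """Hypothetical stand-in for a network's forward pass."""

    def __init__(self, training):
        self.training = training

    def forward(self, x):
        features = [v * 2 for v in x]
        if self.training:
            # Bug: this branch builds the outputs but never returns them,
            # so the method implicitly returns None.
            logits = (features, sum(features))
        else:
            return (features, sum(features))


class WithLoss:
    """Hypothetical stand-in for a WithLossCell that indexes the network output."""

    def __init__(self, backbone):
        self.backbone = backbone

    def forward(self, x):
        out = self.backbone.forward(x)
        # When out is None, out[0] is the Python analogue of getitem(None, Int64).
        return out[0]


def locate(x, training):
    """Run eagerly and report where the None first appears."""
    out = Backbone(training).forward(x)
    if out is None:
        return "Backbone.forward returned None"
    return WithLoss(Backbone(training)).forward(x)


print(locate([1, 2, 3], training=False))  # [2, 4, 6]
print(locate([1, 2, 3], training=True))   # Backbone.forward returned None
```

Checking intermediate outputs one step at a time like this is what PyNative mode makes possible in MindSpore, since the network executes eagerly and ordinary print statements or breakpoints placed right before the failing index can show which value is None.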
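In addition, if you suspect problems while developing a custom dataset class, you can unit-test lmdbDataset to confirm that it behaves as expected. The MindOCR project also implements several custom dataset classes (including a similar lmdbDataset class) whose code you can use as a reference.
We also suggest trying the newer MindSpore r2.2.11. Because the behavior of some APIs may change between MindSpore versions, please refer to the technical documentation on the official website.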
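As a starting point for such a unit test, the sketch below checks, for a few indices, that each item is a tuple with one field per GeneratorDataset column and that no field is None. To run without LMDB files or MindSpore, it uses a hypothetical in-memory FakeLmdbDataset and placeholder values for the des_orderlabel outputs; to test the real class, you would substitute an lmdbDataset instance built with your own roots and transform.

```python
import unittest

# Column names passed to GeneratorDataset for the training set in load_dataset().
TRAIN_COLUMNS = ["image", "label", "label_res", "label_sub", "label_id"]


class FakeLmdbDataset:
    """Hypothetical stand-in that returns the same 5-tuple layout as lmdbDataset."""

    def __init__(self, labels):
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        label = self.labels[index]
        image = [[0.0] * 4 for _ in range(2)]  # placeholder "image"
        # Placeholder values standing in for des_orderlabel(label).
        label_res, label_sub, label_id = label[1:], label[:1], 0
        return (image, label, label_res, label_sub, label_id)


def check_item(item, columns):
    """Return a list of problems found in one dataset item (empty if none)."""
    problems = []
    if not isinstance(item, tuple):
        problems.append(f"expected tuple, got {type(item).__name__}")
        return problems
    if len(item) != len(columns):
        problems.append(f"expected {len(columns)} fields, got {len(item)}")
    for name, value in zip(columns, item):
        if value is None:
            problems.append(f"column '{name}' is None")
    return problems


class LmdbDatasetTest(unittest.TestCase):
    def setUp(self):
        self.dataset = FakeLmdbDataset(["hello", "world", "abc"])

    def test_items_match_columns(self):
        for i in range(len(self.dataset)):
            with self.subTest(index=i):
                self.assertEqual(check_item(self.dataset[i], TRAIN_COLUMNS), [])


if __name__ == "__main__":
    # exit=False keeps the interpreter running after the tests finish.
    unittest.main(argv=["lmdb-dataset-test"], exit=False)
```

The same check_item helper can be reused for the test set by passing ["image", "label"], matching the column_names given to its GeneratorDataset.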