关于pytorch加载imageNet
zhjpqq opened this issue · comments
zhjpqq commented
您好,我用pytorch训练resnet,但是加载ImageNet太慢了,想用hdf5来替换原有的DataSet。
但是看了你的博客,仍然一头雾水,不知用hdf5之后,速度能否提升,以及该如何写代码。
我参考的博文:
https://www.jianshu.com/p/19f3ca564644
https://www.cnblogs.com/nwpuxuezha/p/6537307.html
不知能否指点一二,可酬劳,
这个是我现在的加载代码,基本按官方的进行的。
# data
traindir = os.path.join(cfg.data_root, 'train')
valdir = os.path.join(cfg.data_root, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
train_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(traindir, transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])),
batch_size=cfg.batch_size, shuffle=True,
num_workers=cfg.data_workers, pin_memory=True)
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])),
batch_size=cfg.batch_size_val, shuffle=False,
num_workers=cfg.data_workers, pin_memory=False)
JingYu Ji commented
不一定有必要。关于加载 imagenet 太慢的问题,只看这一段代码问题不大。开了多少进程?io,cpu 负载,gpu 负载如何?建议用 profile 工具分析一下性能瓶颈。