HighwayWu / ImageForensicsOSN


About train.py

linohzz opened this issue


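Hello, to test whether an npy file I wrote matches the expected format, I used the code below to generate train.npy from 20 images and their corresponding masks, but I ran into an error.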
Here is my complete code for generating the npy file:
```python
import numpy as np

# Hypothetical lists of input-image and ground-truth-mask paths
# (example paths: train/1.jpg ... train/20.jpg and train_mask/1.png ... train_mask/20.png)
input_paths = [f'/home/linux/code/lql/dataset/season3_data/train/{i}.jpg' for i in range(1, 21)]
ground_truth_paths = [f'/home/linux/code/lql/dataset/season3_data/train_mask/{i}.png' for i in range(1, 21)]

# Build a list in the format [(input1, ground_truth1), (input2, ground_truth2), ...]
data = list(zip(input_paths, ground_truth_paths))

# Save as a .npy file; np.save converts the list of tuples into a (20, 2) array of strings
np.save('train.npy', data)
```
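Before looking at train.py itself, it can help to rule out problems in the generated file. The sketch below is not the repository's code and does not know the exact format train.py expects; it only assumes that an (N, 2) array of (image_path, mask_path) strings is the target, and that each mask shares its image's file stem (e.g. `1.jpg` pairs with `1.png`), as in the paths above. The helper names `build_pairs` and `save_and_check` are hypothetical. Pairing by stem avoids silently mismatched order, and the round trip confirms the saved array has the expected shape and that every path exists on disk.

```python
import os
import tempfile

import numpy as np


def build_pairs(image_dir, mask_dir, image_ext=".jpg", mask_ext=".png"):
    """Pair each image with the mask that has the same stem (hypothetical helper).

    Assumption: 1.jpg in image_dir corresponds to 1.png in mask_dir.
    Returns the list of (image_path, mask_path) pairs and the image names
    for which no mask was found.
    """
    pairs, missing = [], []
    for name in sorted(f for f in os.listdir(image_dir) if f.endswith(image_ext)):
        stem = os.path.splitext(name)[0]
        mask_path = os.path.join(mask_dir, stem + mask_ext)
        if os.path.isfile(mask_path):
            pairs.append((os.path.join(image_dir, name), mask_path))
        else:
            missing.append(name)
    return pairs, missing


def save_and_check(pairs, out_path):
    """Save pairs as an (N, 2) string array, reload it, and verify it (hypothetical helper)."""
    # reshape keeps the (N, 2) layout even when pairs is empty
    arr = np.array(pairs, dtype=str).reshape(-1, 2)
    np.save(out_path, arr)
    # A plain unicode array loads without allow_pickle=True
    loaded = np.load(out_path)
    assert loaded.shape == (len(pairs), 2), loaded.shape
    for img, mask in loaded:
        assert os.path.isfile(img), f"missing image: {img}"
        assert os.path.isfile(mask), f"missing mask: {mask}"
    return loaded
```

With the directories from the original code this would be called as `build_pairs('/home/linux/code/lql/dataset/season3_data/train', '/home/linux/code/lql/dataset/season3_data/train_mask')`; a non-empty `missing` list or a failed assertion points to a data problem rather than a problem in train.py.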