这是论文Unsupervised Domain Adaptation by Backpropagation的复现代码,并完成了MNIST与MNIST-M数据集之间的迁移训练
- tensorflow 1.14.0
- opencv 3.4.5.20
- numpy 1.18.1
首先下载MNIST数据集,放在项目文件的/dataset/mnist子文件夹下。之后下载BSDS500数据集放在项目文件的/data/BSR_bsds500.tgz路径下。
之后为了生成MNSIT-M数据集,运行create_mnistm.py脚本,命令如下:
python create_mnistm.py
脚本运行结束后,MNIST-M数据集将保存在项目文件的/dataset/mnistm子文件夹下。 之后运行模型训练脚本train.py即可,运行命令为为:
python train.py
下面是训练过程中的相关tensorboard的相关指标在训练过程中的走势图。首先是训练误差的走势图,主要包括训练域分类误差、训练图像分类误差和训练总误差。
接下来是验证误差的走势图,主要包括验证域分类误差、验证图像分类误差和验证练总误差。
最后是精度走势图,主要包括训练精度和测试精度。其中训练精度是在源域数据集即MNIST数据集上的统计结果,验证精度是在目标域数据集即MNIST-M数据集上的统计结果。从图中可以看出,DANN在训练MNIST-M数据集时没有使用对应的标签,MNSIT-M数据集上的精度最终收敛到75.4%,效果相比于81.49%还有一定距离,但鉴于没有使用任何数据增强和dropout,这个结果可以接受。