RyanCCC / DETR

DETR : End-to-End Object Detection with Transformers (Tensorflow)

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

DETR Tensorflow

DETR : End-to-End Object Detection with Transformers:将Transformer应用于目标检测。Pytorch版本的实现:facebookresearch/detr。本仓库基于Tensorflow实现DETR,包括训练代码、推理代码以及finetune代码。主要参考: detr-tensorflow。DETR网络结构图如下所示:

项目架构

├─data:数据集基本操作
├─detr:DETR网络实现
│  ├─loss:损失函数
│  └─networks:主要网络实现代码
├─logger:日志脚本
├─notebooks:介绍说明的Jupyter notebook
└─src:一些资源文件,如readme的图像

介绍说明

模型训练

训练coco数据集,数据文件架构如下:

  • data_dir:coco数据集根目录
  • img_dir:训练集和验证集图像文件夹
  • ann_file:训练集和验证集图像标注文件夹

执行命令:python train_coco.py --data_dir /path/to/COCO --batch_size 8 --target_batch 32 --log

预训练模型下载:

  1. pytorch预训练模型下载仓库:https://github.com/facebookresearch/detr

  2. tensorflow预训练模型下载仓库:https://github.com/Leonardo-Blanger/detr_tensorflow

模型微调

微调的基本流程:

# Load the pretrained model
detr = get_detr_model(config, include_top=False, nb_class=3, weights="detr", num_decoder_layers=6, num_encoder_layers=6)
detr.summary()

# Load your dataset
train_dt, class_names = load_tfcsv_dataset(config, config.batch_size, augmentation=True)

# Setup the optimziers and the trainable variables
optimzers = setup_optimizers(detr, config)

# Train the model
training.fit(detr, train_dt, optimzers, config, epoch_nb, class_names)

Pacal VOC数据集

目录结构如下:

  • data_dir:数据集根目录
  • img_dir:数据集的图像
  • ann_file:数据集标注文件

执行命令:python finetune_voc.py --data_dir /home/thibault/data/VOCdevkit/VOC2012 --img_dir JPEGImages --ann_dir Annotations --batch_size 8 --target_batch 32 --log

hardhatcsv数据集

目录结构如下:

  • data_dir:数据集根目录
  • img_dir:数据集的图像
  • ann_file:数据集标注文件

执行命令:python finetune_hardhat.py --data_dir /home/thibault/data/hardhat --batch_size 8 --target_batch 32 --log

模型评估

测试集数据目录结构如下:

  • data_dir:测试集根目录
  • img_dir:测试集的图像
  • ann_file:测试集Ground True

执行命令:python eval.py --data_dir /path/to/coco/dataset --img_dir val2017 --ann_file annotations/instances_val2017.json

About

DETR : End-to-End Object Detection with Transformers (Tensorflow)

License:Apache License 2.0


Languages

Language:Jupyter Notebook 90.2%Language:Python 9.8%