yqyao / flag

train own dataset with SSD

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

title date tags categories
flag
2018-02-28 03:13:17 -0800
DL
个人学习

SSD 训练自己的数据集

因为项目需要,需要检测一些不在公开数据集中的物体,比如旗帜,横幅。这些在VOC以及COCO,Imagenet中都没有对应的类别,因此我们需要自己整理数据,标注数据。SSD作为一个快速的检测框架,非常适合实际使用,因为之前已经有使用SSD框架的行人检测,对应的SDK已经其他一些API调用工具都已经编写完成。

数据准备

旗帜数据非常多,目前的搜索引擎有好几个,因此需要从不同的引擎上去搜取图片,需要注意的是,需要下载原图,有些引擎可能稍不注意下载的都是缩略图,尺寸非常小,不适合训练。标注工具使用Labelimg https://github.com/tzutalin/labelImg ,自动生成xml格式的label,非常好用。网上对应的资料非常多,非常好用的一个工具,只是暂时不支持mac,这个是个缺点。

网络结构设计

因为实际使用时不需要检测太小的旗帜,因此网络可以稍微小一点,同时可以对原始的网络结构进行一些修改方便输入不同尺寸的图片。这次使用的网络结构是resnet18,相比vgg,计算量已经小了很多,运算速度也还算是比较快,在显卡为 GTX 1080 Ti, 输入300x300可以达到80fps,这个实际使用完全够用。

github 地址为:

遇到的坑

Check failed:std::equal(top_shape.begin() + 1, top_shape.begin() + 4, shape.begin() + 1)

or

OpenCV Error: Assertion failed ((scn == 3 || scn == 4) && (depth == CV_8U || depth == CV_32F)) in cvtColor

这个错误的原因有多个:

  • 训练图像通道不为3,搜集的图片可能包含灰度图,以及4通道的图片
  • 训练图像最小边过小,比如小于300

注意python的cv2模块与c++直接调用C++是有区别的

python 的cv2 是无法识别灰度以及4通道图片的,以默认的方式读图片都会以3通道的方式读取,如果事先不知道图片是否是灰度图是无法通过python cv2直接判断出图片是不是灰度图或者4通道图片,而caffe中C++是直接可以读出来的。

对于灰度图python 可以通过非常暴力的判断rgb3个通道是不是相同去判断是不是灰度图,而4通道图就比较麻烦,我没有去找图片,而是更加暴力去重新读取图片,然后使用python cv2去写入图片保证最后的图片都是3通道。

另一种解决办法

将P.Resize.warp 改为 P.Resize.FIT_SMALL_SIZE,但是这种方式就没办法批量训练,batch size只能设置为1。实际上不是一个非常好的训练方式。

About

train own dataset with SSD