Fast RCNN 训练自己的数据集(3训练和检测)

整理文档很辛苦,赏杯茶钱您下走!

免费阅读已结束,点击下载阅读编辑剩下 ...

阅读已结束,您可以下载文档离线阅读编辑

资源描述

FastRCNN训练自己的数据集(3训练和检测)转载请注明出处,楼燚(yì)航的blog,://github.com/YihangLou/fast-rcnn-train-another-dataset这是我在github上修改的几个文件的链接,求星星啊,求星星啊(原谅我那么不要脸~~)在之前两篇文章中我介绍了怎么编译FastRCNN,和怎么修改FastRCNN的读取数据接口,接下来我来说明一下怎么来训练网络和之后的检测过程先给看一下极好的检测效果预训练模型介绍首先在data目录下,有两个目录就是之前在1中解压好fast_rcnn_models/imagenet_models/fast_rcnn_model文件夹下面是作者用fastrcnn训练好的三个网络,分别对应着小、中、大型网络,大家可以试用一下这几个网络,看一些检测效果,他们训练都迭代了40000次,数据集都是pascal_voc的数据集。caffenet_fast_rcnn_iter_40000.caffemodelvgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodelvgg16_fast_rcnn_iter_40000.caffemodelimagenet_model文件夹下面是在Imagenet上训练好的通用模型,在这里用来初始化网络的参数CaffeNet.v2.caffemodelVGG_CNN_M_1024.v2.caffemodelVGG16.v2.caffemodel在这里我比较推荐先用中型网络训练,中型网络训练和检测的速度都比较快,效果也都比较理想,大型网络的话训练速度比较慢,我当时是5000多个标注信息,网络配置默认,中型网络训练大概两三个小时,大型网络的话用十几个小时,需要注意的是网络训练最好用GPU,CPU的话太慢了,我当时用的实验室的服务器,有16块TeslaK80,用起来真的是灰常爽!2.修改模型文件配置模型文件在models下面对应的网络文件夹下,在这里我用中型网络的配置文件修改为例子比如:我的检测目标物是car,那么我的类别就有两个类别即background和car因此,首先打开网络的模型文件夹,打开train.prototxt修改的地方重要有三个分别是个地方首先在data层把num_classes从原来的21类20类+背景,改成2类车+背景接在在cls_score层把num_output从原来的21改成2在bbox_pred层把num_output从原来的84改成8,为检测类别个数乘以4,比如这里是2类那就是2*4=8OK,如果你要进一步修改网络训练中的学习速率,步长,gamma值,以及输出模型的名字,需要在同目录下的solver.prototxt中修改。如下图:train_net:models/VGG_CNN_M_1024/train.prototxtbase_lr:0.001lr_policy:stepgamma:0.1stepsize:30000display:20average_loss:100momentum:0.9weight_decay:0.0005#Wedisablestandardcaffesolversnapshottingandimplementourownsnapshot#functionsnapshot:0#Westillusethesnapshotprefix,thoughsnapshot_prefix:vgg_cnn_m_1024_fast_rcnn#debug_info:true3.启动FastRCNN网络训练启动训练:./tools/train_net.py--gpu11--solvermodels/VGG_CNN_M_1024_LOUYIHANG/solver.prototxt--weightsdata/imagenet_models/VGG_CNN_M_1024.v2.caffemodel--imdbKakouTrain参数讲解:这里的--是两个-,markdown写的,大家不要输错train_net.py是网络的训练文件,之后的参数都是附带的输入参数--gpu代表机器上的GPU编号,如果是nvidia系列的tesla显卡,可以在终端中输入nvidia-smi来查看当前的显卡负荷,选择合适的显卡--solver代表模型的配置文件,train.prototxt的文件路径已经包含在这个文件之中--weights代表初始化的权重文件,这里用的是Imagenet上预训练好的模型,中型的网络我们选择用VGG_CNN_M_1024.v2.caffemodel--imdb这里给出的训练的数据库名字需要在factory.py的__sets中,我在文件里面有__sets['KakouTrain'],train_net.py这个文件会调用factory.py再生成kakou这个类,来读取数据4.启动FastRCNN网络检测我修改了tools下面的demo.py这个文件,用来做检测,并且将检测的坐标结果输出到相应的txt文件中可以看到原始的demo.py是用网络测试了两张图像,并做可视化输出,有具体的检测效果,但是我是在Linux服务器的终端下,没有displaydevice,因此部分代码要少做修改下面是原始的demo.py:#!/usr/bin/envpython#--------------------------------------------------------#FastR-CNN#Copyright(c)2015Microsoft#LicensedunderTheMITLicense[seeLICENSEfordetails]#WrittenbyRossGirshick#--------------------------------------------------------Demoscriptshowingdetectionsinsampleimages.SeeREADME.mdforinstallationinstructionsbeforerunning.import_init_pathsfromfast_rcnn.configimportcfgfromfast_rcnn.testimportim_detectfromutils.cython_nmsimportnmsfromutils.timerimportTimerimportmatplotlib.pyplotaspltimportnumpyasnpimportscipy.ioassioimportcaffe,os,sys,cv2importargparseCLASSES=('__background__','aeroplane','bicycle','bird','boat','bottle','bus','car','cat','chair','cow','diningtable','dog','horse','motorbike','person','pottedplant','sheep','sofa','train','tvmonitor')NETS={'vgg16':('VGG16','vgg16_fast_rcnn_iter_40000.caffemodel'),'vgg_cnn_m_1024':('VGG_CNN_M_1024','vgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodel'),'caffenet':('CaffeNet','caffenet_fast_rcnn_iter_40000.caffemodel')}defvis_detections(im,class_name,dets,thresh=0.5):Drawdetectedboundingboxes.inds=np.where(dets[:,-1]=thresh)[0]iflen(inds)==0:returnim=im[:,:,(2,1,0)]fig,ax=plt.subplots(figsize=(12,12))ax.imshow(im,aspect='equal')foriininds:bbox=dets[i,:4]score=dets[i,-1]ax.add_patch(plt.Rectangle((bbox[0],bbox[1]),bbox[2]-bbox[0],bbox[3]-bbox[1],fill=False,edgecolor='red',linewidth=3.5))ax.text(bbox[0],bbox[1]-2,'{:s}{:.3f}'.format(class_name,score),bbox=dict(facecolor='blue',alpha=0.5),fontsize=14,color='white')ax.set_title(('{}detectionswith''p({}|box)={:.1f}').format(class_name,class_name,thresh),fontsize=14)plt.axis('off')plt.tight_layout()plt.draw()defdemo(net,image_name,classes):Detectobjectclassesinanimageusingpre-computedobjectproposals.#Loadpre-computedSelectedSearchobjectproposalsbox_file=os.path.join(cfg.ROOT_DIR,'data','demo',image_name+'_boxes.mat')obj_proposals=sio.loadmat(box_file)['boxes']#Loadthedemoimageim_file=os.path.join(cfg.ROOT_DIR,'data','demo',image_name+'.jpg')im=cv2.imread(im_file)#Detectallobjectclassesandregressobjectboundstimer=Timer()timer.tic()scores,boxes=im_detect(net,im,obj_proposals)timer.toc()print('Detectiontook{:.3f}sfor''{:d}objectproposals').format(timer.total_time,boxes.shape[0])#VisualizedetectionsforeachclassCONF_THRESH=0.8NMS_THRESH=0.3forclsinclasses:cls_ind=CLASSES.index(cls)cls_boxes=boxes[:,4*cls_ind:4*(cls_ind+1)]cls_scores=scores[:,cls_ind]dets=np.hstack((cls_boxes,cls_scores[:,np.newaxis])).astype(np.float32)keep=nms(dets,NMS_THRESH)dets=dets[keep,:]print'All{}detectionswithp({}|box)={:.1f}'.format(cls,cls,CONF_THRESH)vis_detections(im,cls,dets,thresh=CONF_THRESH)defparse_args():Parseinputarguments.parser=argpar

1 / 28
下载文档,编辑使用

©2015-2020 m.777doc.com 三七文档.

备案号:鲁ICP备2024069028号-1 客服联系 QQ:2149211541

×
保存成功