SiamRPN 复现

基本介绍

我我我,本来已经写一半了,然后一通瞎逼操作,现在重新写。

官方的siamrpn,siamrpn++,dasiamrpn,siammask都在这个github项目中:https://github.com/STVIR/pysot

siamrpn就是siamese网络加RPN网络,siamese做特征提取,然后模板特征与搜索区域的特征做相关操作,得到的score map送入RPN网络。

RPN网络有两个分支两个一个用来做分类,一个用来做位置和大小的逻辑回归。可以看我的那一篇siamRPN论文阅读。

复现包含:数据处理,模型搭建,loss定义,训练,预测,评价。

最后我会把完整的带注释的复现项目放到github上。

github项目地址:https://github.com/Etherwave/SiamRPN_rebuild

1 数据处理

我们这里使用VOT2019的数据来进行训练。

VOT2019的gt是4个坐标的四边形,我们需要转化成需要的矩形框。

如博客dataset/vot中提到的方式。

将红色的四边形,先转成蓝色的包围框,然后根据面积编程绿色的一个更加合适的矩形框。(没有为什么,siamrpn代码里是这样搞的)

如图为siamrpn的网络图,可知我们需要一个255 * 255的搜索区域图像和127 * 127大小的模板图像,但是我们的gt是矩形,如果直接强行转换为255的正方形,势必会使模板的长宽比变化,不利于我们特征网络的训练,所以我们需要将矩形的gt先转化为合适的正方形,然后再将这个正方形放缩到255大小,这样可以保证目标长宽比不变。

原作者代码先是将数据集的图像裁处511大小的一个正方形图存起来,然后训练的时候读取这个511大小图,然后再截取255大小的搜索区域,和127大小的目标,因为获取这两个图像的时候加入了图像增强的平移和放缩,所以每次都要动态的裁剪,会使训练的更好,但是这样我们就不能先将图像裁剪好放在那里等着用了。

这个操作我感觉就有点迷,你何必先裁剪成511大小的呢,大部分的图像都不大,你都存成511的正方形即没有说裁剪的次数少了,也没有说新处理的数据集占用的空间少了。那么何必不直接在原图像上裁剪呢??

还有关于正负样本的设定,论文中定义iou > 0.6的anchor定义为正样本,iou < 0.3的anchor定义为负样本,一个搜索图像最多找64个样本,最多16个正样本,即正负样本的比是1:3,但是原作者的代码中还定义了另一个负样本,就是当模板图像和搜索图像不是同一个视频中的图像时,也是负样本,这个比例是0.2,就是说,10个数据对(模板图像和搜索图像),有8个是从同一个视频中挑的两帧,还有2个是从不同视频中挑的两帧。一个负样本的搜索图像也最多选择16个负样本,并且原作者是集中在图像中心选的。

挑正样本他也的确按照论文中所说的,两帧之差不超过FRAME_RANGE帧,有的是3帧,也有的是100帧,这个感觉没啥用,他用的数据集都是分类的数据集训练的,训练处理的RPN的cls那个分支学会的就是模板图像和搜索图像中的是不是一个类别而已,而不是目标跟踪所真正需要的判断两个目标的相似度,所以我取消了这个帧数限定,只要是同一个视频中的两帧我都算正样本。(我在毕设中将score的定义从二分类问题,改为一个预测两个目标的相似度,也的确提升了算法性能)

负样本的选取作者也是很佛系,先由那个dataloader不是传过来一个index嘛,找到一个这个index对应的视频,然后再随机random一个视频,再在这个视频中随机random一帧,对无法保证random到的不是一个视频的,甚至无法保证不是同一帧。所以我这里按照毕设那样加一个used_out标签。

整体思路就确定下来了,

1 写一个video类,管理一个视频的每一帧,帧的数量,每一帧是否用过了,gt,每一帧的路径,(类中不读取图片,仅仅存放图像的路径,到dataloader要给模型传数据的时候,再由dataloader去读取图片),

2 再写一个subdataset类,管理所要用到的所有数据集,如VOT2019,VOT2018等,里边实现读取所有数据集下的所有视频,实现一些功能如读取一对正样本,读取一对负样本等.

3 写一个mydataset类,里边存放要用到的数据集的路径,给subdataset让subdataset获取所有视频,最后要把mydataset给dataloader,所以要实现__len__函数,__getitem__函数。并且还要将subdataset传过来的一对图像路径和对应的gt进行处理。传过来的图像路径要进行图像读取,然后根据gt裁剪成合适的大小,进行图像增强,并且还要生成label。分类的cls label,回归的xywh。

为了测试代码正确性,还要写一个画图函数,验证没有裁剪错地方,label没有生成错误。

mydataset中的__getitem__函数的思路即:读取图片,图像增强,构建label。

详细细节见上边的github代码

2 模型搭建

very simple 见代码

3 loss定义

loss分为分类loss和逻辑回归loss

cls_loss 还有 loc_loss

在原版中cls_loss用的是交叉熵损失,我个人感觉不好,所以自定义了loss

$$ cls\_loss=y(predict\_y-0)^2+(1-y)(predict\_y)^2 $$

逻辑回归的loss就是差的绝对值,乘以delta_weight的权重

详细见代码

4 训练

emmm,常规代码,没啥技巧

见代码

5 预测

预测一般是写一个tracker类,来将第一步的模板提取特征和后边的跟踪分开,这样模板的特征提取一次存储起来就好了,快一点

将预测出来的cls_score还要处理一下,加上面积变化惩罚和长宽比惩罚,从而确定最好的anchor。

所有预测出的anchor的xywh要还原为真实大小的xywh。

根据最好的anchor的位置,算出来一个新的目标偏移,将这个偏移和上一帧的位置计算一个新位置。

详细见代码

文章目录