基于知识蒸馏的YOLOv3算法研究

2022-09-06 11:08李姜楠刘竞升王洪刚
计算机工程与应用 2022年17期
关键词:特征提取损失特征

李姜楠,伍 星,刘竞升,王洪刚

重庆大学 计算机学院,重庆 400000

人工智能的快速发展使深度学习技术广泛应用于目标检测领域,2014 年Girshick 等人首次将深度学习应用于目标检测,提出R-CNN[1]二阶段检测模型,并在此基础上引入感兴趣区域池化(RoI 池化)得到Fast R-CNN[2],然后使用区域提取网络结构(RPN)代替选择性搜索算法生成候选框进一步得到Faster R-CNN[3]。Lin等人提出特征金字塔网络(FPN[4]),通过融合低层特征信息和高层语义信息,提升了小目标检测效果。Cai等人提出Cascade R-CNN[5],通过不断提高并交比(IoU)阈值,在保证样本数不减少的情况下训练出高质量检测器。针对二阶段检测模型预测速度较慢的问题,2016 年Redmon 等人提出YOLO[6]一阶段检测模型,然后在此基础上将全连接层替换成卷积层,并利用聚类算法获得先验框而提出YOLOv2[7],接着使用更强大的特征提取网络并引入类FPN结构提出YOLOv3[8]。邹承明等人[9]在YOLOv3 基础上引入Focal loss 和GIoU loss,提高了YOLOv3对小目标的检测能力。Bochkovskiy等人将CSP 结构引入特征提取网络并在特征融合层中使用PAN结构而提出了YOLOv4[10]。Liu等人提出了可在多尺度特征图上检测的SSD[11]模型,弥补了YOLO在小物体检测上精度不佳的问题。Lin等人提出RetinaNet[12],通过引入Focal loss[12]消除了大量背景造成数据不平衡的影响,使一阶段检测模型获得了接近二阶段模型的精度。随着目标检测模型开始向移动端部署,研究人员将目光转向了参数少且内存占用低的轻量级模型。知识蒸馏作为一种有效的模型压缩方法,逐渐受到研究者的青睐。

自2015年Hinton等人提出知识蒸馏[13]后,大量研究者对图像分类领域的知识蒸馏展开研究,但在目标检测上的研究依旧较少。Chen 等人[14]打破了这个僵局,以Faster R-CNN为检测网络,对特征提取层,分类损失和回归损失同时展开蒸馏,提升了二阶段目标检测模型的精度。Wang等人[15]针对全局特征图的蒸馏算法会引入大量背景信息的问题,提出使用信号图(mask-map[15])将教师网络传递的知识限制在真实框附近,获得了更高的精度,但基于二阶段检测模型蒸馏出的网络因速度依然较慢难以在移动端部署。Mehta等人[16]将知识蒸馏应用于一阶段检测模型,通过类FPN 结构优化学生网络架构,加入无标签数据集进行训练,蒸馏时使用特征图非极大值抑制算法(FM-NMS)过滤冗余框,在tiny-yolov2基础上mAP 提升了14 个百分点,但蒸馏的提升效果有限,mAP只提升不到1个百分点。

2019年,管文杰等人[17]将知识蒸馏引入Cascade RCNN,将二、三阶段检测器回归分类的结果作为“软目标”加入损失函数中进行蒸馏,达到和四阶段检测相当的精度。同年,温静[18]将知识蒸馏的Attention转移算法与归一化后的损失函数应用于智能车驾驶环境理解中的目标检测任务,有效提高了小网络的准确性。

针对目前在一阶段目标检测器上蒸馏的研究较少,且蒸馏提升效果有限的情况,本文以YOLOv3为检测网络,提出将信息图作为监督信号在特征提取层和特征融合层上同时展开蒸馏。本文的贡献点如下:(1)将信息图作为监督信号对学生网络展开蒸馏。信息图是教师网络传递的知识重要性的分布图,不仅过滤掉了教师网络传递的背景信息,且强化了学生网络对教师网络重点知识的学习。(2)在特征提取层和特征融合层上同时展开蒸馏。在特征提取层上蒸馏后,特征融合层上的蒸馏对前者的蒸馏有一个校正作用,可以进一步提升蒸馏效果。

1 目标检测模型的蒸馏算法

Chen 等人[14]提出对二阶段目标检测器Faster R-CNN 进行蒸馏,教师网络和学生网络具有相同的检测框架,但是学生网络会选择轻量级的特征提取网络(Backbone)。如图1所示,图片分别输入至教师网络和学生网络,在特征提取层(hint)、分类结果(classification)、回归结果(regression)上产生不同的输出。通过度量教师网络和学生网络输出的差距,构建蒸馏损失。将蒸馏损失加上检测的真实损失(Detection loss)构成总损失,计算方式如式(1)所示:

式中,Lloss为总损失,Lhard为真实损失,Lsoft为蒸馏损失。s为学生网络,T为真实标签,t为教师网络,α为平衡真实损失和蒸馏损失的权重。通过反向传播算法更新学生网络的权重,不断降低损失值,可以使学生网络的输出逐渐接近教师网络。Lhard的计算方式如式(2)所示:

式中,Lcls为目标检测中的分类损失,Lreg为目标检测中的回归损失。Lsoft的计算方式如式(3)所示:

式中,Lhint为hint 损失(度量了学生网络和教师网络输出的特征层的差异,此处使用了平方差损失,即图1 中的L2 loss),Lsoft_cls为两者之间的分类损失(交叉熵损失,即图1 中的Cross Entropy Loss),Lsoft_reg为两者之间的回归损失(平方差损失)。在hint损失中,为了保持学生网络输出的特征层维度和教师网络一致,使用一个由1×1卷积组成的自适应层(adaption)调节学生网络输出的维度。

图1 二阶段检测模型蒸馏架构Fig.1 Two-stage detection model distillation architecture

Wang 等人[15]发现图1 中的hint 损失是针对特征层的全局信息展开蒸馏,会引入大量背景信息,对此提出了改进。引入信号图(mask-map[15])去掉教师网络传递的背景信息,只针对目标周围的特征层进行蒸馏(未对分类和回归信息进行蒸馏)。如图2 所示,教师网络对学生网络蒸馏时,使用了一张标识目标区域的maskmap做监督信号,虚线框描述了由教师网络输出的预测框和真实标签作为输入,获得mask-map 的过程,计算方式如式(4)所示:

图2 改进的二阶段检测模型蒸馏架构Fig.2 Improved distillation architecture of two-stage detection model

式中,Hxyz为相应位置mask-map 的取值,IoUxyz为相应位置IoU 的取值,φ为控制mask-map 范围的参数,IoU-map为教师网络输出的预测框和真实标签生成的IoU 的取值。通过max 函数获得每组IoU-map的最大值,并和φ的乘积作为阈值,小于该阈值的IoU 置为0。对每组IoU-map做或操作获得一张IoU-map,再对多组IoU-map做或操作,可获得过滤掉背景信息的maskmap。该蒸馏架构的蒸馏损失为基于IoU-map监督的教师网络和自适应层输出的平方差损失。

Mehta 等人[16]在一阶段目标检测器上进行蒸馏,针对教师网络将大量重复框传递给学生网络的问题,提出在教师网络的输出上做非极大值抑制的FM-NMS 算法。如图3 所示,图片输入至教师网络和学生网络,将教师网络的输出经过FM-NMS 过滤(对同类别预测框的得分进行排序,只留下同类别得分最高的预测框)。将过滤后预测框的置信度(Confidence)、分类(Classification)和回归(Regression)信息与学生网络输出的预测框信息计算蒸馏损失,加上学生网络的输出和真实标签产生的真实损失(Detection Loss)构成总损失,对该损失进行反向传播来训练学生网络。

图3 一阶段检测模型蒸馏架构Fig.3 One-stage detection model distillation architecture model

二阶段目标检测器中的区域提取网络可以去除大量冗余框,且蒸馏过程中,通过mask-map过滤掉了背景信息,可以达到较好的蒸馏效果。但二阶段目标检测器参数量较多,占用内存大,即使是蒸馏后的小网络,也难以在移动端部署,相比之下,蒸馏后的一阶段目标检测器可在移动端部署。因此,本文对图3中的知识蒸馏架构进行改进,并应用于一阶段目标检测器,提出基于信息图对特征提取层和特征融合层同时蒸馏的知识蒸馏架构。如图4 所示,图片分别输入教师网络和学生网络,在特征提取层(Backbone)和特征融合层(Neck)输出的特征层上同时展开蒸馏,计算损失时,引入信息图作为监督信号对蒸馏过程提供指导,最后通过反向传播算法更新学生网络的权重。

图4 基于信息图的知识蒸馏架构Fig.4 Knowledge distillation architecture based on information map

2 知识蒸馏算法改进

2.1 基于信息图的蒸馏

基于信息图的蒸馏过程分为3个步骤:(1)如图5所示,将学生网络输出的特征层输入一个由1×1卷积组成的自适应层,使学生网络输出特征层的维度和教师网络保持一致;(2)在信息图的指导下计算特征层之间的平方差损失,信息图对涵盖目标的区域均赋予了权重,非目标区域的值为0;(3)进行反向传播,只更新学生网络的权重。经过模型多次训练,学生网络的输出将越来越接近教师网络的输出。

图5 特征提取层的蒸馏Fig.5 Distillation of feature extraction laye

信息图的生成如图6所示,首先由教师网络的预测框和真实标签获得mask-map,mask-map 中包含圆圈的部分值为1,表示前景区域,未包含圆圈的部分值为0,表示背景区域。在教师网络传递的前景信息中,越靠近目标的关键部分,最终对目标的判断越具有决定性影响,但mask-map 中只有0 值和1 值,无法对不同前景信息的重要性进行区分。教师网络输出的置信度包含了前景信息的重要程度,越靠近目标核心部分的置信度越大,越远离的置信度越小。本文基于教师网络输出的置信度,对现有的mask-map进行改进,提出了可以在蒸馏过程中提供监督信号的信息图,信息图的计算方式如式(5)所示:

图6 信息图的生成Fig.6 Generation of Information-map

式中,Information-map表示信息图,confidenceteacher表示教师网络输出的置信度。将置信度和mask-map相乘获得信息图,可以在值为1的部分获得有区分度的权重,在计算损失时,权重更大的位置会受到更多关注,加强了学生网络对教师网络传递的重点知识的学习。

2.2 特征融合层的蒸馏

在网络结构中,特征提取层负责获得图片的特征信息,并以特征层的形式输出。特征融合层可以将不同尺度的特征层进行拼接或相加,获得来自不同感受野的信息,如图7所示。目前的知识蒸馏架构大多只针对特征提取层展开蒸馏,无法获得教师网络在特征融合层的知识。针对这个问题,本文提出同时对特征提取层和特征融合层展开蒸馏,相比于现有的知识蒸馏架构,可以进一步提升蒸馏效果,蒸馏中的损失函数如式(6)所示:

图7 特征融合层Fig.7 Feature fusion layer

式中,Lloss为总损失,Lhard为检测的真实损失,Lbackbone和Lneck分别为特征提取层和特征融合层产生的蒸馏损失,α和β是用来平衡三者之间的权重,这里均设置为1,Lbackbone和Lneck的计算方式如下所示:

式(7)中,N为信息图中不为0 的像素点个数的总和。W、H、C分别表示特征提取层输出的特征层的尺寸,Hxy为mask-map 的取值,conft为教师网络输出的置信度,两者相乘则为信息图,对特征层的蒸馏提供指导。fbackbone_adap(s)为学生网络特征提取层经过自适应层处理后输出的特征层,txyz为教师网络输出的特征层。基于特征融合层的蒸馏损失与特征提取层的蒸馏损失类似,如式(8)所示。

3 实验结果与分析

3.1 数据集和评价指标

本文使用的数据集为VOC[19]数据集,同时取VOC2012[19]和VOC2007[19]的训练集和验证集作为训练集,VOC2007 的测试集作为测试集。VOC 数据集共有20 个类别,本文选择的训练集有16 551 张图片,用来蒸馏出学生网络的权重参数,测试集有4 952张图片,可以测试出蒸馏的效果。评价指标为目标检测模型常用的平均类别精度(mAP),兼顾了模型的准确率和召回率,mAP的表达式为:

式中,IAP,C表示每类缺陷的平均精确度,C表示数据集的类别,N表示数据类别数目。

3.2 实验环境和训练细节

实验基于Pytorch 深度学习框架,运行在Ubuntu 18.04 系统环境下,中央处理器为4.7 GHz Intel Core™CPU i7-9700,内存为32 GB,显卡型号是NVIDIA Ge-Force RTX 2080ti,加速库为CUDA10.2和CUDNN7.6。

训练时,首先将教师网络和学生网络训练出合适的精度,然后将训练好的权重分别加载到知识蒸馏框架中,冻结教师网络的权重,只训练学生网络。蒸馏实验的训练过程中,动量和权重衰减系数分别为0.9 和0.000 5,批量大小设置为4,并交比(IoU)为0.5。采用SGD优化器进行优化,以余弦函数为周期,周期性地设置学习率,学习率的范围为[10-6,10-4]。

本文选择YOLOv3为教师网络,tiny_yolov3为学生网络。教师网络在YOLOv3 的基础上使用Focal loss、GIoU loss 和mix-up 数据增强的方式,平均精度提升至84.7%。学生网络在tiny_yolov3 的基础上加入了一个52×52的分支,便于蒸馏过程中和YOLOv3的三分支输出结构匹配。本文将Chen[14]和Wang[15]的知识蒸馏架构应用于YOLOv3,做了相关的对比实验。虽然Metha 等人[16]也将知识蒸馏应用于目标检测,在学生网络的基础上mAP提升了14个百分点,但蒸馏架构提升效果有限,mAP只提升了不到1个百分点,所以未加入对比实验。

3.3 结果与分析

实验结果如表1 所示,经过蒸馏后,学生网络mAP指标均有提升,验证了知识蒸馏可提升小网络的检测效果。(1)Chen[14]提出的知识蒸馏架构在tiny_yolov3 的基础上提升了6.2 个百分点的mAP,Wang[15]提升了6.8 个百分点的mAP,本文的知识蒸馏架构提升了9.3个百分点的mAP,验证了知识蒸馏的有效性;(2)相比现有的知识蒸馏架构,本文提出的知识蒸馏架构高出Chen[14]3.1 个百分点的mAP,且高出Wang[15]2.5 个百分点的mAP,充分验证了本文创新点的有效性,且优于现有的知识蒸馏架构;(3)教师网络和学生网络在鸟、猫、狗、沙发这四类精度的差距最大,蒸馏后这四类的精度相比其他类别均有较大提升,验证了知识蒸馏可以使学生网络获得更高的特征提取能力和检测能力,大幅度缩小学生网络和教师网络的差距。

表1 知识蒸馏实验结果Table 1 Knowledge distillation experiment results

图8 为训练时loss 的收敛曲线图,实验中使用余弦函数为周期调整学习率大小,震荡较为严重。但随着训练轮数的增加,loss 曲线逐渐趋于平缓,逐步缩小了学生网络与教师网络的差距。

图8 蒸馏训练曲线图Fig.8 Distillation training curve graph

为了验证蒸馏的有效性,图9对比了蒸馏前后的检测效果,左边一栏为学生网络的输出,右边一栏为教师网络的输出,中间一栏为蒸馏后的学生网络的输出。通过观察可发现,针对学生网络没有检测出来的框,蒸馏后的学生网络均可以检测出来,且更接近于教师网络的预测结果。

图9 蒸馏效果可视化Fig.9 Visualization of distillation results

3.4 消融实验

本文基于YOLOv3 做了相关的消融实验。为了验证该实验结果,将图1网络结构中的图片作为输入进行特征图的可视化,3.5 节也采用该图片作为输入进行蒸馏过程的可视化。

表2 显示了消融实验结果,可分为以下3 个步骤:(1)首先通过Wang[15]的知识蒸馏架构对学生网络的特征提取层展开蒸馏,mAP指标提升6.8个百分点。(2)加入信息图后,在过滤掉教师网络传递的背景信息的同时,强化了学生网络对重点知识的学习,mAP提升了1.5个百分点。如图10所示,相比于图(b3)、(c3)中的特征图强化了目标核心区域的特征,边缘区域至中心区域的重要性逐渐升高。(3)同时对特征融合层蒸馏,将教师网络在特征融合后的知识有效传递给学生网络,补充了特征提取层未传递的知识,且特征融合层中的卷积层对特征提取层的蒸馏有一个校正作用,mAP进一步提升1.0个百分点。如图10所示,相比于图(c1)中蒸馏后特征提取层的输出,图(d1)中特征融合层的输出在获得高层语义信息的同时,依旧保持了物体清晰的轮廓。综上所述,基于信息图对一阶段目标检测器的特征提取层和特征融合层同时展开蒸馏,可有效提高一阶段目标检测器的精度。

图10 蒸馏前后特征图对比Fig.10 Feature map comparison before and after distillation

表2 消融实验(mAP)Table 2 Ablation experimen(tmAP)%

3.5 蒸馏可视化

为了验证本文方法的有效性,从特征提取层、特征融合层、卷积层这3个方面对蒸馏过程进行可视化。

特征提取层的可视化。图11 显示了蒸馏前后,学生网络经自适应层输出和教师网络特征提取层输出的三张特征图。相比蒸馏前的(a1),蒸馏后的(b1)更准确地提取出目标的局部特征。相比蒸馏前的(a2),蒸馏后的(b2)更明显地划分了前景和背景区域之间的边界。相比蒸馏前的(a3),蒸馏后的(b3)与教师网络输出的特征图更接近,且更准确地突出了图片中的前景区域。

图11 特征提取层的输出Fig.11 Output of feature extraction layer

特征融合层的可视化。图12分别显示了蒸馏前后的学生网络特征融合层输出和教师网络特征融合层输出的三张特征图。相比蒸馏前的(a1)和(b3),蒸馏后(b1)和(b3)中目标的轮廓更加明显,局部特征更加清晰,更有利于小目标的检测。

图12 特征融合层的输出Fig.12 Output of feature fusion layer

卷积层的可视化。图13显示了蒸馏前后学生网络同一层卷积层输出的三张特征图,如图中所示,蒸馏后输出的特征图不仅更好地区分了目标的前景和背景,且更细致地保留了目标的局部特征,有利于提高检测精度。

图13 蒸馏前后学生网络特征图对比Fig.13 Comparison of feature maps of student model before and after distillation

4 结束语

本文针对现有的目标检测知识蒸馏架构进行改进,提出基于信息图对一阶段目标检测器的特征提取层和特征融合层同时展开蒸馏。信息图标识了教师网络传递知识的重要性,加强了学生网络对特征层关键区域的学习;对特征融合层展开蒸馏,使学生网络在获得特征提取层知识的同时,也获得了来自教师网络特征融合后的知识。实验结果表明,改进后的知识蒸馏架构,在不改变学生网络结构的基础上,提升了更高的精度,为小模型在移动设备上部署奠定了基础。

猜你喜欢
特征提取损失特征
根据方程特征选解法
离散型随机变量的分布列与数字特征
胖胖损失了多少元
空间目标的ISAR成像及轮廓特征提取
基于Gazebo仿真环境的ORB特征提取与比对的研究
基于特征提取的绘本阅读机器人设计方案
不忠诚的四个特征
玉米抽穗前倒伏怎么办?怎么减少损失?
基于Daubechies(dbN)的飞行器音频特征提取
菜烧好了应该尽量马上吃