结合对抗网络与条件均值的多源适应分类方法

2022-03-21 10:33谭茜成邹俊颖
计算机工程与设计 2022年3期
关键词:源域分类器分类

夏 青,郭 涛,谭茜成,邹俊颖

(四川师范大学 计算机科学学院,四川 成都 610101)

0 引 言

深度神经网络在各种机器学习问题和应用方面取得了重大进展。然而这一重大进展归功于大规模标记数据的可用性[1]。但人工标注数据耗时费力、代价昂贵。相比于传统的监督学习,无监督学习[2]利用没有标注的数据进行模型训练,以解决机器学习中标注数据缺乏的问题。此外,传统机器学习通常假设训练数据和测试数据来自于同一概率分布[3]。而在实际中,由于训练数据和测试数据往往来源于不同的数据分布,这就导致了在很多实用场景中不能正常使用传统机器学习算法下训练出来的模型,学习到的模型在新领域使用时性能会大幅度衰减[4]。与机器学习不同,迁移学习[5,6](transfer learning,TL)借助于源域数据训练过程中学到的知识,完成对目标域的识别[7]。但是不同域之间存在的间隙使得源域训练的模型在目标域进行识别的时候学习效果会受到影响,域适应[8]学习作为迁移学习中的一种代表性方法,通过建立从有标签源域到无标签目标域的知识迁移,学习域间共享信息,实现模型在目标域上的正确分类[9]。针对域适应的研究,SankaraNara-yanan等[10]提出的生成适应网络(generate-to-adapt,GTA),通过学习单个源和单个目标之间的共享特征嵌入和生成对抗网络[11](generative adversarial network,GAN)之间的共生关系来减小域差异,进一步利用源域中学习到的知识对目标域进行预测,但仅使用单源域学习提取到的数据特征有限,且对抗机制不足以减少域差异,当样本来自于多个不同概率分布的时候,模型会出现负迁移,使得模型的分类性能受到影响。目前大多数域适应算法和理论假设源样本仅从单个源域进行采样。而在实际应用中,会在多个不同设备上采集到源样本数据用于训练,但是这些数据不但和目标域概率分布不同,而且互相之间概率分布也不同[12]。基于此,杨强等提出了多源迁移学习[13],将一个源域扩展为多个源域,利用多个源域中丰富的监督信息能够更有效辅助目标域的学习[14]。近来,朱勇椿等研究学者提出了多特征空间适应网络[15](multiple feature spaces adaptation network,MFSAN),通过提取多个源域和目标域之间的共享特征,使用最大均值差异[16]优化每一对源和目标的距离。受MFSAN思想启发,并针对GTA模型的不足,本文提出一种结合对抗网络与条件均值的多源适应分类方法,该方法通过对特征提取网络的训练,提取多个源域和目标域之间的域不变信息。将学习到的源域和目标域特征信息送入特定域的生成对抗网络,同时使用条件最大均值差异[17]最小化域间距离,利用无监督对抗训练辅助分类网络对目标域特征进行识别。由于不同源训练的分类器具有差异性,因此使用差异度量准则对每一个分类器的输出进行约束,并回传各个类别的梯度信号,以提高网络的分类性能。在具有4个源域和一个目标域的实验环境中的实验结果表明了MSDACG模型利用多源域的监督信息来提高目标域学习的有效性,且分类精度有明显提升。

1 理论基础

1.1 单源无监督域适应

域适应由著名学者杨强提出,能够有效地解决训练样本和测试样本概率分布不一致的学习问题[18],是当前机器学习的热点研究领域。以下给出相关的定义:

定义3 单源无监督域适应[19]:首先给定一个有标记的源域数据DS=(XS,YS) 和一个无标记的目标域数据集DT=(XT), 定义源域输出函数FS:XS→YS, 即两者构成源域(DS,FS),目标域即为(DT,FT),单源域适应的目标则是通过解决单个源域和单个目标域之间分布不同的问题,将在源域中学习到的知识对目标域的输出函数FT进行学习[20]。图1显示了单个源和目标域之间的学习过程。

图1 单源->目标域

1.2 多源无监督域适应

单源无监督域适应关注的是一个域的场景,而多源无监督域适应[21]作为单源无监督域适应的一种扩展,首先给定DS={DS1,DS2,…,DSN}, 即多源无监督域适应方法假定样本是从N个不同的源域 {DS1,DS2,…,DSN} 中进行收集的,给定XS1,XS2,…,XSN是分别来自N个源域DS1,DS2,…,DSN的样本,目标域数据记为DT。多源域适应的目标旨在解决多个源域和目标域之间分布差异的问题。利用在多个源域中学习到的知识对目标域进行预测。图2显示了多个源域和目标域之间的学习过程。

图2 多源->目标域

1.3 GTA模型

GTA模型的学习过程分为两个阶段:①第一个阶段:学习一个特征提取网络并提取单个源和目标之间的共享特征嵌入作为生成器的输入,以生成类似源域的数据,而判别器作为二分类器通过分辨真实的数据和生成数据之间的真假信息,并将学到的信息进行回传;同时判别器作为多分类器仅使用源域的标签信息进行监督学习。②第二个阶段:学习一个分类器,利用源域的共享特征嵌入作为分类器的输入并且实现在目标域上的预测。图3是GTA模型的结构。

图3 GTA模型结构

2 MSDACG模型

2.1 问题描述

通过对GTA模型的结构进行深入研究发现存在以下问题:一是特征提取器作为学习源和目标的共享特征嵌入,其仅使用对抗训练的方式不足以拉近域间距离,缺少距离度量准则来约束域间距离,使得域适应效果受到影响。二是使用单源学习到的知识有限,导致难以识别来自分布不同的判别特征。三是如果特征提取器学习来自多个不同概率分布的源集合和目标域的共享嵌入特征,会使得模型产生负面影响,从而影响模型的分类效果。

2.2 结构描述

受GTA模型框架的启发,并针对其存在的不足,本文提出结合条件均值与对抗机制的多源适应分类方法。其流程如图4所示。模型首先利用特征提取器(F)提取所有域的共享嵌入表示,进而通过特定域的生成器分支(G)和特定域的判别器分支(D)学习不同源域和目标域间的特征,使用CMMD减小不同域间的条件分布差异,以辅助利用多个源域的监督信息对无标记的目标域数据进行识别。由于特定域的分类器(C)之间可能会出现差异,因此采用差异损失来约束不同分类器的输出,以使得分类器的预测尽可能一致。下面分别对MSDACG模型的4个网络结构流程及特点分别进行详细介绍:

首先设定size为多个源域和目标域输入数据的尺寸大小,N为源域的个数。

(1)共享特征提取网络(F):

3)嵌入空间E服从于标准的高斯分布,并随机产生一定维度的噪声数据Z[size]。

(2)特定域的生成器网络(G):

1)G使用反卷积神经网络结构,并且G是由 {G1,G2,…,GN}N个特定域的生成器组成。

(3)特定域的判别器网络(D):

(4)特定域的分类器网络(C):

1)C使用全连接层结构,并且C是由 {C1,C2,…,CN}N个特定域的分类器组成。

图4 MSDACG模型总体结构

2.3 评估方法

本实验采用CMMD距离度量、交叉熵损失函数,对抗性损失函数以及差异损失函数作为评估MSDACG模型的方法。

CMMD距离度量:CMMD是MMD的延伸概念。MMD是被用于计算不同数据的边缘概率分布之间的差异,而CMMD是用于计算不同数据的条件概率分布P(XS|YS=C) 和Q(XT|YT=C) 之间的差异,其中C表示样本的类别数量。由于在无监督学习中,目标域数据是没有标签的。因此,需要使用深度神经网络的输出y′=f(XT) 作为目标域上的伪标签。则CMMD的计算公式可以表示为

(1)

依据式(1),可得知CMMD在D上的损失函数Lcmmd的表达式见式(2),其中μ为动态平衡因子,用于对CMMD减小条件分布距离的程度作出衡量

(2)

通过最小化等式(2)可以有效拉近源域和目标域之间的条件概率分布。

(3)

(4)

(5)

(6)

(7)

D输出两个分布,来辨别输入图像的真伪性。根据GAN算法原理,对抗性损失函数的目的在于G利用D回传的对抗性特征信息,在不断地迭代优化过程中,G能够生产出越来越类似源域类别空间的图像,从而使得D难以分辩图像的真假,最后达到一个纳什均衡状态。

差异损失函数:差异损失函数Ldisc的作用是为了解决引入多源时产生的各分类器差异的问题。在训练过程中,分类器是由不同的源域监督信息进行训练的,因此导致在对目标域预测的时候会出现分歧,特别是类边界附近的目标样本。正确的方式是不同分类器预测相同的目标样本应该得到相同的预测。因此通过最小化所有分类器之间的距离以解决样本观测不平衡的问题。本文利用目标域数据的所有分类器的概率输出之间差异的绝对值作为差异损失,计算表达式如式(8)所示

(8)

通过最小化方程(8),所有分类器的概率输出是相似的。最后,预测目标样本的标签为计算所有分类器输出的平均值。

2.4 算法流程

MSDACG模型的整体算法流程如算法1所示。

算法1: MSDACG模型训练

Input:N个源域数据集DS={DS1=(XS1,YS1),DS2=(XS2,YS2),…,DSN=(XSN,YSN)}, 目标域数据集DT=(XT,YT), 训练迭代次数T, 批量大小size, 权衡系数λ。

Output: MSDACG模型∑

(1) 随机初始化模型∑中所有的网络层参数。

(2)fortin 1:Tdo

(5) 随机产生size个噪声数据, 记为z[size];

(7) 根据式(3)、 式(4)及式(8)计算分类器上的损失函数

LC=Lcls+Lcls,d+λLdisc;

(8) 根据式(2)和式(4)~式(7)计算判别器上的损失函数

LD=Lsrc+Lcls,d+Ladv,src+Ladv,tgt+Lcmmd;

(9) 根据式(2)、 式(4)、 式(6)计算生成器上的损失函数

LG=Lcls,d+Ladv,src+Lcmmd;

(10) 根据式(3)、 式(4)、 式(7)、 式(8)计算共享特征提取网络的损失函数

LF=Lcls+Lcls,d+Ladv,tgt+λLdisc;

(11) 使用梯度下降法进行反向传播各个网络的梯度信号;

(12)endfor

(13)输出模型∑, 算法终止。

3 实验结果与分析

3.1 数据集

实验中使用的5个数据集分别是从以下的公开数据集中进行采样,即:mt(MNIST)、mm(MNIST-M)[22]、sv(SVHN)、up(USPS)和sy(Synthetic Digits)。使用与文献[23]相同设置,实验从训练集中采样25 000幅图像用于训练,从MNIST、MINST-M、SVHN和Synthetic Digits中的测试集中采样9000幅图像用于测试。而对于USPS数据集总共仅包含9298幅图像,所以选择整个数据集作为一个域。实验中轮流选择一个域作为目标域,记为DT,其余的分别作为源域D1,D2,D3,D4。

3.2 实验配置

实验环境配置为:NVIDIA TESLA SXM2 V100 32 GB GPU服务器,Ubuntu16.04操作系统,Intel至强E5-2698v4处理器20核心,40线程。32 GB DDR4 LRDIMM 2133 MHz内存,480G Intel S3610 6 Gb/s SATA 3.0 SSD系统硬盘,平台为pytorch。模型采用小批量Adam优化器进行训练,学习率统一设置为0.0005,学习率衰减参数为0.0001,指数参数设置为β1为0.8,β2为0.999,批量大小统一设置为100。

3.3 分类精度对比实验

实验过程:

(1)根据3.1节描述选取一个数据集的训练数据作为目标域数据集DT,其余数据集中的训练数据分别作为源域数据集 {D1,D2,D3,D4}。

(2)参照算法1对多个源域与目标域进行训练,获得相应的4个模型Φ1,Φ2,Φ3,Φ4。

(3)冻结Φ1模型中的F网络和C1网络,并且记为测试模型Ω1,对Φ2,Φ3,Φ4采用一致的步骤处理,获得相应的测试模型Ω2,Ω3,Ω4。

(4)采用Ω1,Ω2,Ω3,Ω4分别对DT进行预测,计算出在每个测试模型下的分类准确率。

(5)最后取这4个分类精度的平均值作为最终分类精度。在5种不同的域适应情况下进行验证,并重复以上过程,计算出每种域适应情况下的分类精度。

结果分析:如表1所示,在5种多源域适应任务下,将MSDACG方法与当前多源域适应方法进行分类精度的比较。其中,粗体表示分类精度最高的值。可以看出MSDACG方法平均值达到了90.56%,相较于M3SDA的平均分类精度提高了2.91%。在目标域数据集为MNIST-M的迁移任务上,其精度可以达到80.86%,与M3SDA相比,其分类精度提高了8.04%;相较于DCTN提高了10.33%。而在其它多源域适应任务下,其分类精度也提高了0.35%~3.16%左右。图5展示了在MSDACG模型下, mm,mt,sv,sy→up这一组迁移任务的分类损失比较,横轴代表模型训练的迭代次数,而竖轴代表训练过程中产生的分类损失函数值。可以看出随着迭代次数的增加其分类损失呈现不断递减的趋势,且越来越接近于x坐标轴,验证了学习来自不同域的信息对分类器效果有一定的提升。

表1 MSDACG与当前主流的多源域适应方法的分类精度比较

图5 mm,mt,sv,sy→up分类损失折线

3.4 图像生成对比实验

实验过程:

(1)对4个源域 {D1,D2,D3,D4} 分别随机选取100幅图像作为测试数据集D1*,D2*,D3*,D4*。

(2)按照算法1的步骤,固定F网络以及G1,G2,G3,G4。 产生对应的模型Ψ1,Ψ2,Ψ3,Ψ4分别作为测试模型。

(3)使用测试数据D1*,D2*,D3*,D4*经过对应的模型Ψ1,Ψ2,Ψ3,Ψ4分别生成对应的生成数据。

结果分析:图6展示了在MSDACG模型下, mm,mt,sv,sy→up迁移任务的图像效果。将每一组生成图像和真实图像进行可视化对比分析,模型根据来自不同概率分布的源域样本生成了类似源域效果的图像,对于像MNIST-M、SVHN、Synthetic Digits这样具有彩色数字的图像,模型也能够根据其特点生成纹理以及边缘构造清晰的数字。而对于Synthetic Digits生成图像的效果相较于MNIST-M、SVHN不是特别好的原因可能是与原先真实图像的清晰度有关。而对于具有黑白手写数字样式的MNIST数据集来说,生成的数字图像边缘以及轮廓也具有良好的可视化效果。因此能够验证生成器是可以学习到来自不同源域的数据特征进而生成类似源域分布的生成图像,并且对于学习彩色图像也具有相对优秀的能力。图7展示了MSDACG模型在训练过程中CMMD值随迭代次数变化的趋势图,这里纵轴代表CMMD的值,表示源域特征与目标域特征的像素矩阵经过RKHS空间中使用具有衡量条件概率分布的CMMD进行计算所得出的值。即随着迭代次数的增加CMMD值在不断减小,该结果验证了随着模型的不断迭代,每一组源域和目标域之间的条件概率差异在不断地减小。

图6 mm,mt,sv,sy→up生成图像可视化

图7 mm,mt,sv,sy→up CMMD值折线

3.5 t-SNE特征嵌入可视化分析实验

实验过程:

(1)分别从3.1节中D1,D2,D3,D4,DT的每个类别里随机选取50个数据及相应的标签,分别组成总大小为500的测试数据D1**,D2**,D3**,D4**,DT**。

(2)将D1**和DT**进行数据归一化操作,且计算出对应tsne值tsneS1和tsneT。

(3)使用tsneS1和tsneT以及D1**和DT**和对应的标签绘制适应前的tsne可视化特征。其次通过固定3.3节中模型Φ1的F网络、G1网络和D1网络作为特征可视化的测试模型Γ1,在模型上使用D1**和DT**产生适应后的源域特征数据featureS1和目标域特征数据featureT。

(4)同理,D2**,D3**,D4**重复以上过程,即可得到适应后的tsne嵌入判别器最后一层卷积且经过CMMD度量方法适应后的特征可视化效果。

结果分析:图8中(a)~(d)分别展示了MSDACG模型在mm,mt,sv,sy→up这一组迁移任务情况下多个源域和目标域之间的适应前后的效果。在MNIST-M→USPS这一组迁移任务的tsne图中,左边表示的未适应前的可视化分布图,可以看出未适应前的特征分布散乱,且随机分布在空间中,域间隙较大,分类信息难以识别。而在使用MSDACG模型进行域适应之后,域间距离开始聚拢,且分类信息更加明显。在其它3组任务中,源域为SVHN的这一组任务中,适应相对较弱,通过对左图SVHN的真实数据分布进行分析,初步推断是由于真实数据分布过于散乱,且类别难以区分使得聚拢效果相比其它3组较弱。源域为MNIST和Synthetic Digits的这两组任务中,可以看出未适应前同一种颜色的数据中,域间距离较大,且存在多数类别错分的情况,在适应后之后,同一种颜色和数字标签从不同的位置开始朝着与自己具有相同的特征方向靠近。从而验证了条件概率度量准则以及对抗训练的加入对模型的域适应能力以及分类性能都有一定的提升。

图8 mm,mt,sv,sy→up下适应前后tsne可视化对比

4 结束语

为解决当前大多数域适应方法仅假设样本来自单个域的情况而未考虑到多源任务的迁移,并且针对GTA模型仅使用单源学习到的特征有限以及使用对抗训练拉近域间隙能力较弱的问题,本文提出一种结合对抗网络与条件均值的多源适应分类方法MSDACG,该方法学习多个源域和目标域之间的共享特征嵌入,并且考虑到每个源域之间不同的决策边界,使用特定域的生成器和判别器之间的对抗训练联合条件最大均值差异来减小每一组源域和目标域之间的间隙,加强类与类间的约束。同时,该模型还利用特定域的分类器网络训练来自不同源域的数据,并且对不同分类器的预测输出进行约束,从而以更优的预测能力来识别目标域中的数据。在4种源域下的实验结果表明MSDACG模型在多源域适应分类中具有良好的效果。

猜你喜欢
源域分类器分类
分类算一算
基于朴素Bayes组合的简易集成分类器①
基于参数字典的多源域自适应学习算法
基于特征选择的SVM选择性集成学习方法
分类讨论求坐标
教你一招:数的分类
基于差异性测度的遥感自适应分类器选择
说说分类那些事
从映射理论视角分析《麦田里的守望者》的成长主题
基于层次化分类器的遥感图像飞机目标检测