结合结构学习的多头密集连接图池化模型

2022-05-16 08:48叶海良曹飞龙
中国计量大学学报 2022年1期
关键词:集上卷积分类

顾 昕,叶海良,曹飞龙

(中国计量大学 理学院,浙江 杭州 310018)

现如今带有卷积层和池化层的深度学习已经在多种任务上展开研究,如图像分类[1]、自然语言处理[2]、图像边缘检测[3]等,这些任务中的数据通常在欧几里德空间中表示。然而,当下海量数据如社交网络图、生物化学结构图及学术研究中的论文引用图等,都是位于非欧几里德空间中以图数据的方式存在,不满足平移不变性,传统的卷积神经网络无法直接应用,因此需要开展对图数据卷积层和池化层的研究。

如今图数据处理任务分为两大类:节点级别任务和图级别任务。开展卷积层研究,主要是为了解决节点级别任务如节点分类[4]、链接预测[5]和推荐系统[6]等。如今,有许多将卷积运算推广到图数据上的研究,我们称为图神经网络(graph neural networks,GNNs)。这些研究可以分为两大类:谱方法和空间方法。对于谱方法,它们通常基于图傅里叶变换定义图卷积运算[4],解决了图数据不满足平移不变性所造成的卷积定义困难问题。对于空间方法,图卷积运算是直接从其邻域聚合节点表示得到[7-8]。它们比谱方法更快,并且容易推广到其他图。

设计处理图数据的池化层主要是为了解决图级别任务如图分类、图生成等,池化层可以扩大感受野,降低参数,获得输入数据的层次结构。现有的图池化算子的构建主要分为两类:节点聚类[9-10]和节点采样[11-12]。

节点聚类方法:首先,将输入的大图分成几张小子图。其次,将每个小子图中的节点聚合成一个节点,重复多次。最后,将输入图聚合成一个超级节点放到分类器,完成图分类。在节点聚类方法中,所学的邻接矩阵是软连通的,容易出现过拟合问题,需要辅助的链接预测任务[9]来稳定训练,运算复杂度较高。节点采样方法如gPool[11]、SAGPool[12]等首先学习图中节点的重要性得分,然后根据得分对节点进行排序,通过设置池化率,对图中节点进行采样实现池化操作得到图表示。它使用少量的可训练参数,并且得到较好的结果。然而在SAGPool[12]中学习节点重要性得分时只是简单的通过一层GCN,造成节点得分的学习不充分、不全面,易出现误差,同时网络层数较浅,不能提取高阶特征,以及丢弃节点过程中,容易造成一些关键点被丢弃,影响整个图结构的连通性。

对于以上问题我们做了以下工作:1) 为了更全面的学习节点重要性得分,我们提出多头网络学习节点重要性得分,并且每一头卷积权重不同,实现不同的滤波操作;2) 在每一头中我们权重共享,这样在增加网络层数提取高阶特征时不会增加额外的复杂度,并且使用了密集连接,将每一层的输出密集连接传递到下一层,同时将初始特征加到每一层中,加强了特征传递以及有效利用特征,充分且全面的学习节点重要性得分;3) 引入图结构学习模块,通过节点采样得到池化图之后,我们对池化图中的每对节点学习出一个相关度,若两节点之间相关度较高,则将其相连,形成一条新的边,进而保证整个图的连通性。

1 结合结构学习的多头密集连接图池化模型

我们主要研究基于图采样的图池化方法,而图采样就需要在学习节点重要性得分时充分全面,并且通过图采样得到池化图之后还需要考虑图结构的连通性是否受到影响,因此在这项工作中,我们提出了结合结构学习的多头密集连接图池化模型(multi-head densely connected graph pooling model combined with structure learning,MuhePool)。利用多头滤波,结合密集连接学习节点重要性得分,并且对每一头计算出的节点重要性得分进行加权求和得到图中每个节点的得分,据此进行图采样,同时设置结构学习模块,以保证采样之后图结构的连通性。

1.1 提取初始特征

整个图池化模型如图1,对于输入的图数据,需要先对其提取初始特征,我们采取大家广泛使用的初始特征提取方式[4]公式表示如下:

图1 结合结构学习的多头密集连接图池化模型Figure 1 Multi-head densely connected graph pooling model combined with structure learning

1.2 多头密集连接学习节点重要性得分

图池化操作的关键是节点采样。为了进行节点采样,我们先计算每个节点的重要性得分。通常,如果一个节点的重要性得分越低,则意味着该节点可以在池化图中丢弃,而整个图几乎没有信息丢失。如图2所示,我们建立一个多头密集连接网络学习图中节点的重要性得分,公式表示如下:

图2 使用多头密集连接学习节点重要性得分示意图Figure 2 An illustration of using a multi-head densely connected to learn the importance score of nodes

在获得节点重要性得分后,我们可以选择在池化过程中保留哪些节点。为了尽量减少图信息的丢失,我们选择保留那些重要性得分较高的节点,因为它们可以提供更多的信息,用于之后的图分类任务。具体地说,我们首先根据节点重要性得分对图中的节点重新排序,然后选择保留排名靠前的节点,公式表示如下:

idx=top-rank(Z,|r×n|),

X′=X(idx,:),

A′=A(idx,idx)。

其中,r表示池化率,top-rank(·)为根据学到的节点重要性得分选择前|r×n|个节点索引的函数。X(idx,:),A(idx,idx)表示执行行或(和)列提取,以形成池化图的节点特征矩阵和邻接矩阵。

1.3 结构学习

在这一小节中,我们将介绍我们提出的结构学习模块如何在池化图中学习一个重新定义的图结构。如图1所示,对图中节点采样之后,可能会导致图中相关度较大的节点断开连接,影响图结构的连通性,阻碍信息传递过程。

我们的目标是学习一个重新构造的图结构,首先是计算每个节点之间的相关度,由于余弦相似度可以很好的度量两个向量之间的相似度,因此我们通过余弦函数计算图中每个节点特征之间的相关度,公式表达如下:

E(p,q)=cos(X(p,:),X(q,:))+β·A′(p,q)。

其中X(p,:)和X(q,:)为节点表示矩阵的第p和q行,即节点vp和vq的表示向量,A′表示采样后的图的邻接矩阵,若节点vp和vq没有边相连则A′(p,q)=0,我们加入β·A′(p,q)项,可以使已经相连的节点之间学习到相对较大的相关度得分,进而更好的保留池化图的结构信息,β是它们之间可学习的平衡参数。

为了使相关度得分在不同节点之间易于比较,我们可以使用softmax函数在节点之间对其进行归一化,然而传统的softmax函数,会在每个节点之间都计算出一个非零的相关度,形成一个全连接的图,造成学习出的重构图带有大量的噪声;这里我们使用sparsemax(·)函数[13],该函数保留了softmax函数的归一化属性,并且具有生成稀疏分布的能力,公式如下:

S(p,q)=sparsemax(E(p,q))=
[E(p,q)-γ(E(p,q))]+。

其中[x]+=max{0,x},γ(·)是一个阈值函数它根据算法1中所示的过程返回一个阈值。因此,sparsemax(·)保留阈值以上的值,其他值将被截断为零,从而产生稀疏图结构。

算法1 阈值函数γ(·)的计算流程输入:向量Z∈Rn×1.1.将Z中元素排序:z1≥z2≥…≥zn,2.找出:φ=max1≤j≤n:zj+1j1-∑ji=1zj()>0{},输出:γ(Z)=1φ∑φi=1zj-1().

1.4 模型架构

我们设计了一种分层池化结构用于图分类任务。如图3所示,该结构是由一些相同的模块组合而成。对于每个模块我们都用一个读出层来汇总输出。

图3 图池化模型架构Figure 3 Graph pooling model architecture

读出层的公式表示如下:

最后,将所有读出层的输出求和放入线性层完成图分类任务。

2 实 验

本节首先阐述实验的具体细节,包括数据集、对比方法和实验设置的介绍。其次,展示了MuhePool的图分类结果。最后,介绍了消融实验和超参数分析以及复杂度分析。

2.1 数据集

我们使用来自生物信息学和社交网络两个领域共七种真实数据集评估我们图池化模型的性能。表1总结了七个数据集的统计数据,更多描述如下:

表1 七个图分类数据集的基本统计数据Table 1 Basic statistics of seven graphclassification datasets

PROTEINS和D&D[14]是两个蛋白质图数据集,其中节点代表氨基酸。标签表示蛋白质是酶还是非酶。NCI1和NCI109[15]是两个针对非小细胞肺癌和卵巢癌细胞系活性筛选的生物数据集,其中每个图是一种化合物,节点和边分别表示原子和化学键。COLLAB[16]是科学家合作数据集,该数据集中的每个图表示来自某一个领域的科学家的自我网络,数据集一共包含三个领域的科学家:分别是高能物理、凝聚态物理和天体物理。IMDB-BINARY[17]是一个电影合作数据集,包含1000个代表演员自我网络的图。该数据集来源于动作和浪漫类型的电影合作图。在每个图中,每一个节点表示一位演员,若两个演员合作拍摄同一部电影则他们在图上有边相连。IMDB-MULTI是IMDB-BINARY的多类别版本,它包含1 500个图,共有三类:喜剧,浪漫和科幻。

2.2 对比方法

我们将我们的方法与各种图分类模型进行比较,包括基于核的方法和基于图神经网络的方法。

基于核的方法有图核(GK)[18],深度图核(DGK)[19],最短路径核(SP)[20]和匿名游走嵌入(AWE)[21]。

基于图神经网络的方法有DIFFPool[9],gPool[11],SAGPool[12],EigenPool[22],iPool[23],具有结构学习的分层图池化(SLPool)[24],ARMA[25],用于图学习的瓦瑟斯坦嵌入(WEGL)[26],图多集池化(GMT)[27]和空间卷积神经网络(SCNN)[28]。

2.3 实验设置

我们所有的实验均在Windows 10系统,Python 3.7版本的PyTorch框架下来实现,显卡为NVIDIA RTX 2080Ti。在训练过程中随机打散数据集,然后使用80%的数据集进行训练,10%用于验证,其余用于测试。我们使用Adam优化器[29],并且设置提前停止,如果50轮之后验证损失没有改善,我们将提前停止训练。对于所有数据集,节点表示的维度设置为128。采用10倍交叉验证评估模型的性能,并将10倍交叉验证精度的平均值和标准差取作结果。学习率取0.005,权重衰减取0.000 1,批量大小取较小的值16或32。MLP由三个全连接层组成,每层中的神经元数量设置为256、128、64。

2.4 图分类实验结果

我们在图分类精度方面将我们的MuhePool与其他模型进行了比较,比较结果如表2和表3所示,最佳模型以粗体突出显示,次优模型以下划线显示。表2是4个生物信息学数据集(PROTEINS、DD、NCI1、NCI109)的比较结果,表3是3个社交网络数据集(IMDB-MULTI,IMDB-BINARY,COLLAB)的比较结果。从结果可以看出,在生物信息学数据集上我们的MuhePool在PROTEINS、DD、NCI109数据集上取得了比其他模型更好的结果。对于NCI1数据集,MuhePool是次优模型。在社交网络数据集中,IMDB-MULTI,IMDB-BINARY这两个数据集每个图中的节点较少,我们的MuhePool在这两个较小的数据集上达到最优的结果,可以证明我们模型具有良好的泛化性。

表2 本文方法与对比方法在生物信息数据集上的比较结果Table 2 Comparison results between our method and comparison methods on biological information datasets

表3 本文方法与对比方法在社交网络信息数据集上的比较结果Table 3 Comparison results between our methodand comparison methods on the socialnetwork datasets

2.5 消融实验

对节点重要性得分学习方式的分析。我们对传统的利用GCN学习节点重要性得分与我们利用多头密集连接(Multi-head densely connected,Muhe)学习节点重要性得分在多个图分类数据集上进行对比实验,具体实验结果见表4,结果表明我们的多头密集连接学习出的节点重要性得分更全面,准确,图分类实验的精度更高。

对结构学习的分析。为了证明结构学习(Structure learning,SL)的作用,我们在图分类数据集上分别训练了没有结构学习的模型和带有结构学习的模型。实验结果见表4,结果表明我们在节点采样之后加上结构学习,可以有效地将由节点采样丢失的图结构信息进行一定的恢复,保证图结构的连通性,图分类实验的准确率更高。

表4 不同方式学习节点重要性得分的MuhePool和无结构学习的MuhePool在图分类数据集上的消融实验结果Table 4 Ablation experiment results of MuhePool withdifferent methods of learning node importancescore and MuhePool without structure learningon graph classification datasets

2.6 超参数讨论

我们进一步研究了几个关键超参数取不同值时对实验效果的影响。具体来说,我们选择三个有代表性并且大小不同的数据集(DD,IMDB-MULTI和PROTEINS),在这三个数据集上研究了学习节点重要性得分时滤波器头数K,网络层数T以及节点采样时池化率r对图分类性能的影响,实验结果如图4。

从图4中的实验结果可以看出对于池化率r,我们发现所有数据集的池化率在取到较大的值0.8或0.9时才能达到较好的实验结果,这也就是说明节点采样时丢弃的节点不能太多,丢弃太多易造成图信息的大量丢失。

图4 超参数分析Figure 4 Hyperparameter analysis

对于滤波器头数K和网络层数T这两个参数,在节点数较少的数据集IMDB-MULTI上K=2,T=3即取较小的值,就达到较好的实验结果,而对于DD及PROTEINS这些单个图上节点数较多的数据集,我们发现K和T这两个参数要取到较大的值4或者5时图分类的实验效果较好,即增多滤波器头数K,加深网络层数T才能充分,全面的学习那些较大的数据集中节点的重要性得分,进而更好的进行图采样,完成图池化。

2.7 复杂度分析

为了进一步验证所提出模型的有效性,我们比较了不同模型的复杂度,如表5。可以看出,我们提出的MuhePool的参数量为7.68×104,达到第二低的水平。SAGPool的参数量最低,但是,该方法的图分类性能不如我们的模型。这表明MuhePool复杂度较低,并且图分类效果更优。

表5 不同模型的参数量的比较结果Table 5 Comparison results of parameters of different models

3 结 语

针对现有的图分类模型在学习节点重要性得分时只是单头学习,并且层数较浅节点重要性得分学习不充分,以及节点采样之后图结构的连通性易受到影响这两个问题。本文提出了一种带有结构学习的多头密集连接图池化模型MuhePool。该模型在学习图中节点重要性得分时,使用多头滤波,从而可以实现不同的滤波操作。在每一头中滤波器的权重共享且实现增加网络层数提取高阶特征时,不会增加复杂度,并且使用了密集连接,将每一层的输出密集连接传递到下一层,加强了特征传递以及有效利用特征,充分且全面的学习节点重要性得分。针对节点采样之后,可能会造成关键点的丢失,整个图结构的连通性受到影响这一问题,我们设置结构学习模块,重新学习出一个图结构,以保证图结构的连通性。在多个数据集上的图分类实验结果证明了我们提出的MuhePool的先进性。

猜你喜欢
集上卷积分类
关于短文本匹配的泛化性和迁移性的研究分析
基于3D-Winograd的快速卷积算法设计及FPGA实现
分类算一算
基于互信息的多级特征选择算法
卷积神经网络的分析与设计
从滤波器理解卷积
基于傅里叶域卷积表示的目标跟踪算法
教你一招:数的分类
说说分类那些事
师如明灯,清凉温润