软件开发架构师

Graph Neural Network:GCN 算法原理,实现和应用-InfoQ

人工智能 177 2019-09-02 23:35

GCN 算法原理

首先,如果想要完整了解 GCN 的理论基础,我们还需要去了解空间域卷积,谱图卷积,傅里叶变换,Laplacian 算子这些,本文不涉及这些内容,感兴趣的同学可以自行查阅相关资料。

我们现在先记住一个结论,GCN 是谱图卷积的一阶局部近似,是一个多层的图卷积神经网络,每一个卷积层仅处理一阶邻域信息,通过叠加若干卷积层可以实现多阶邻域的信息传递。

每一个卷积层的传播规则如下:

Graph Neural Network:GCN 算法原理,实现和应用-InfoQ-1

其中:

  • ${\tilde A}=A+I_N$ 是无向图 G 的邻接矩阵加上自连接 ( 就是每个顶点和自身加一条边 ),$I_N$ 是单位矩阵。
  • ${\tilde D}$ 是 ${\tilde A}$ 的度矩阵,即 $D_{ii}=\sum_i{\tilde A}_{ij}$
  • $H^{(l)}$ 是第 $I$ 层的激活单元矩阵,$H^0=X$
  • $W^{(l)}$ 是每一层的参数矩阵

简单解释下,GCN 的每一层通过邻接矩阵 A 和特征矩阵 $H^{(l)}$ 相乘得到每个顶点邻居特征的汇总,然后再乘上一个参数矩阵 $W^{(l)}$ 加上激活函数 σ 做一次非线性变换得到聚合邻接顶点特征的矩阵 $H^{(l+1)}$。

之所以邻接矩阵 A 要加上一个单位矩阵 $I_N$,是因为我们希望在进行信息传播的时候顶点自身的特征信息也得到保留。

而对邻居矩阵 ${\tilde A}$ 进行归一化操作 ${\tilde D}^{-\frac{1}{2}}{\tilde A}{\tilde D}^{-\frac{1}{2}}$ 是为了信息传递的过程中保持特征矩阵 H 的原有分布,防止一些度数高的顶点和度数低的顶点在特征分布上产生较大的差异。

GCN 的实现

1. GCN 卷积层实现

复制代码
output = tf.matmul(tf.sparse_tensor_dense_matmul(A, features), self.kernel)
if self.bias:
output += self.bias
act = self.activation(output)

上述代码片段对应的就是

Graph Neural Network:GCN 算法原理,实现和应用-InfoQ-1

只不过多了一个偏置项。

2. GCN 的实现

复制代码
def GCN(adj_dim, num_class, feature_dim, dropout_rate=0.5, l2_reg=0, feature_less=True, ):
Adjs = [Input(shape=(None,), sparse=True)]
if feature_less:
X_in = Input(shape=(1,), )
emb = Embedding(adj_dim, feature_dim, embeddings_initializer=Identity(1.0), trainable=False)
X_emb = emb(X_in)
H = Reshape([X_emb.shape[-1]])(X_emb)
else:
X_in = Input(shape=(feature_dim,), )
H = X_in
H = GraphConvolution(16, activation='relu', dropout_rate=dropout_rate, l2_reg=l2_reg)(
[H] + Adjs)
Y = GraphConvolution(num_class, activation='softmax', dropout_rate=dropout_rate, l2_reg=0)(
[H] + Adjs)
model = Model(inputs=[X_in] + Adjs, outputs=Y)
return model

这里 feature_less 的作用是告诉模型我们是否有额外的顶点特征输入,当 feature_lessTrue 的时候,我们直接输入一个单位矩阵作为特征矩阵,相当于对每个顶点进行了 onehot 表示。

GCN 应用

本例中的训练,评测和可视化的完整代码在下面的 git 仓库中,后面还会陆续更新一些其他 GNN 算法:

https://github.com/shenweichen/GraphNeuralNetwork

使用论文引用网络数据集 Cora 进行测试,Cora 数据集包含 2708 个顶点, 5429 条边, 每个顶点包含 1433 个特征,共有 7 个类别。

按照论文的设置,从每个类别中选取 20 个共 140 个顶点作为训练,500 个顶点作为验证集合,1000 个顶点作为测试集。DeepWalk 从全体顶点集合中进行采样,最后使用同样的 140 个顶点训练一个 LR 模型进行分类。

顶点分类任务结果

Graph Neural Network:GCN 算法原理,实现和应用-InfoQ-3

从分类任务结果可以看到,在使用较少训练样本的条件下 GCN 的效果是高于 DeepWalk 的,而不含顶点特征的 GCN 的效果则会变差很多。

不含顶点特征的 GCN 相当于仅仅在学习图的拓扑结构,而对于图的拓扑结构的学习 Graph Embedding 方法也能做到,这也说明了 GCN 的优势在于能够同时融入了图的拓扑结构和顶点的特征进行学习。

顶点向量可视化

从对得到的顶点向量的可视化结果来看,GCN 得到的向量相比于 DeepWalk 产出的向量确实更加能够将同类的顶点聚集在一起,不同类的顶点区分开来。

1. DeepWalk 可视化

Graph Neural Network:GCN 算法原理,实现和应用-InfoQ-4

2. GCN 可视化

Graph Neural Network:GCN 算法原理,实现和应用-InfoQ-5

参考资料

Kipf T N, Welling M. Semi-supervised classification with graph convolutional networks[J]. arXiv preprint arXiv:1609.02907, 2016.

https://arxiv.org/pdf/1609.02907.pdf

作者介绍

沈伟臣,阿里巴巴算法工程师,硕士毕业于浙江大学计算机学院。对机器学习,强化学习技术及其在推荐系统领域内的应用具有浓厚兴趣。

本文来自 DataFun 社区

原文链接

https://mp.weixin.qq.com/s?__biz=MzU1NTMyOTI4Mw==&mid=2247493066&idx=1&sn=9e776b0e6661d052cce62e9136f39510&chksm=fbd757a6cca0deb0269c4b99ccf5c09763a0064db074c19aba1f028038cd31b88f0050fe2f0a&scene=27#wechat_redirect

文章评论