0
点赞
收藏
分享

微信扫一扫

PyGCN 源码阅读

PyGCN 源码阅读

代码地址: ​​https://github.com/tkipf/pygcn​​

GCN 论文 ​​Semi-Supervised Classification with Graph Convolutional Networks​​

广而告之

可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号;另外可以看看知乎专栏 ​​PoorMemory-机器学习​​, 以后文章也会发在知乎专栏中;

cora 数据介绍

该数据集包含有关 Machine Learing 主题的 paper, 主要由 ​​cora.content​​​ 以及 ​​cora.sites​​​ 两个文件组成, 其中 ​​cora.content​​ 的格式为:

The .content file contains descriptions of the papers in the following format:

+

第一行是 <paper_id>, 总共有 2708 篇 paper

最后一行是该 paper 属于哪个领域, 总共有 7 个领域

从这些 paper 取出 unique words, 总共 1433 个,

中间的 ​​<word_attributes>​​ 长度为 1433, 每个位置上的值为 0/1, 表示 words 之间的连接, 该属性表示每篇 paper 的特征.

而 ​​cora.sites​​ 的格式为:

The .cites file contains the citation graph of the corpus. Each line describes a link in the following format:

对于 ​​paper_id1 paper_id2​​​, 含义为 ​​paper2​​​ 引用了 ​​paper1​​.

cora 图数据加载

  • 获取节点特征, 并进行归一化
  • 获取 labels, 并进行 one-hot encode
  • 获取邻接矩阵, 对称化, 归一化
  • 转换为 torch Tensor. (尤其是 scipy sparse matrix 转换为 torch sparse tensor)

图数据加载的代码位于 utils.py 中, 其中 ​​load_data​​​ 这个函数用于 cora 数据集的加载, 分别读取 ​​cora.content​​​ 文件获得节点 id, features 与 labels, 以及读取 ​​cora.sites​​​ 来获得边的信息来构建 paper 之间的引用关系图. 获取图的邻接矩阵 ​​adj​​ 代码写的很简洁:

# 读取 cora.sites 文件, 获取 paper_id, 利用 idx_map 将 paper_id 转换为 Graph 中的节点 id.
edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset),
dtype=np.int32)
edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
dtype=np.int32).reshape(edges_unordered.shape)
adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
shape=(labels.shape[0], labels.shape[0]),
dtype=np.float32)

其中 ​​edges[:, 0]​​​ 表示 ​​paper_id1​​​, ​​edges[:, 1]​​​ 表示 ​​paper_id2​​​. 利用 ​​sp.coo_matrix​​​ 来构建 “COOrdinate” 类型的稀疏矩阵. 如果两个节点之间有连接, 相应的位于 ​​(edges[:, 0], edges[:, 1])​​ 处的值就是 1.

邻接矩阵预处理

GCN 论文 ​​Semi-Supervised Classification with Graph Convolutional Networks​​​ 中对邻接矩阵 PyGCN 源码阅读_图嵌入 的处理是, 先加上 self-connections, 之后再用度矩阵 PyGCN 源码阅读_图嵌入_02PyGCN 源码阅读_Graph Embedding_03 进行归一化:
PyGCN 源码阅读_Graph Embedding_04
其中 PyGCN 源码阅读_PyTorch_05, PyGCN 源码阅读_PyTorch_06, 注意 PyGCN 源码阅读_图嵌入无向图 PyGCN 源码阅读_PyTorch_08 的邻接矩阵.

代码实现中, 由于 ​​cora.sites​​​ 读取完得到的是有向图, 故作者先将有向图的 ​​adj​​ 邻接矩阵转换为对称的无向图邻接矩阵:

# build symmetric adjacency matrix
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)

该代码写成公式如下:

PyGCN 源码阅读_PyTorch_09

首先明确 PyGCN 源码阅读_PyTorch_10 的意思: PyGCN 源码阅读_GCN_11 表示节点 i 指向节点 j 的权重, 那么 PyGCN 源码阅读_PyTorch_12 就表示节点 j 指向 i 的权重.

那么 PyGCN 源码阅读_PyTorch_09 就能将得到无向图的权重 (因为 PyGCN 源码阅读_PyTorch_14 中只有 0 值和正数, PyGCN 源码阅读_GCN_15 将获取所有节点 i 指向节点 j 但是节点 j 却不指向节点 i 的位置, PyGCN 源码阅读_PyTorch_16 将得到节点 j 指向节点 i 的权重, 由于 PyGCN 源码阅读_PyTorch_14 中只保留了节点 i 指向 j 的权重, 那么两者相加就可以获得无向图 PyGCN 源码阅读_Graph Embedding_18 的权重).

(2020.04.22 补充: 这里再补充点细节. 首先要明确, 上面公式的目的是为了将有向图的邻接矩阵 PyGCN 源码阅读_图嵌入 转换为无向图的邻接矩阵. 原本情况可以很简单, 如果 PyGCN 源码阅读_图嵌入 是对角线全为 0 的上三角矩阵, 那么直接使用 PyGCN 源码阅读_图嵌入_21 就能得到无向图的邻接矩阵. 但意外情况是, 有向图可能存在 PyGCN 源码阅读_PyTorch_22PyGCN 源码阅读_GCN_23 的权重不相等的状况, 上面公式的作用是将 PyGCN 源码阅读_PyTorch_22PyGCN 源码阅读_机器学习_25 中权重最大的那个, 作为无向图的节点 PyGCN 源码阅读_图嵌入_26 与节点 PyGCN 源码阅读_图嵌入_27 的边权. 下面具体来看:

对于 PyGCN 源码阅读_PyTorch_09 公式, 它可以分为两个部分: PyGCN 源码阅读_PyTorch_29PyGCN 源码阅读_GCN_30, 先讨论第一部分. 举个例子, 邻接矩阵 PyGCN 源码阅读_图嵌入 的形式如下:

PyGCN 源码阅读_图嵌入_32

这表示节点 PyGCN 源码阅读_GCN_33 的权重为 PyGCN 源码阅读_PyTorch_34, 节点 PyGCN 源码阅读_PyTorch_35 的权重为 PyGCN 源码阅读_Graph Embedding_36, 节点 PyGCN 源码阅读_机器学习_37 的权重为 PyGCN 源码阅读_机器学习_38, 节点 PyGCN 源码阅读_图嵌入_39 的权重为 PyGCN 源码阅读_图嵌入_40. 注意到节点 PyGCN 源码阅读_PyTorch_35 的权重(为 PyGCN 源码阅读_Graph Embedding_42)和节点 PyGCN 源码阅读_图嵌入_39 的权重(PyGCN 源码阅读_图嵌入_44)不一致. 现在对 PyGCN 源码阅读_图嵌入 进行转置, 得到的结果是:

PyGCN 源码阅读_Graph Embedding_46

可以看到, 通过转置, 我们可以得到节点 PyGCN 源码阅读_GCN_23 的权重, 同时 PyGCN 源码阅读_PyTorch_48 的结果是:

PyGCN 源码阅读_GCN_49

进一步得到 PyGCN 源码阅读_GCN_50 的结果为:

PyGCN 源码阅读_PyTorch_51

最后算出 PyGCN 源码阅读_PyTorch_29 的结果为:

PyGCN 源码阅读_机器学习_53

可以看到, 此时 PyGCN 源码阅读_PyTorch_29 不完全对称, 但问题只是出在节点 PyGCN 源码阅读_PyTorch_35 的权重(为 PyGCN 源码阅读_Graph Embedding_42)和节点 PyGCN 源码阅读_图嵌入_39 的权重(PyGCN 源码阅读_图嵌入_44)不一致的情况上, 此时 PyGCN 源码阅读_PyTorch_35 的权重为 PyGCN 源码阅读_GCN_60, 是 PyGCN 源码阅读_PyTorch_61PyGCN 源码阅读_PyTorch_62 之和. 为了处理这种情况, 需要用到下面介绍的 PyGCN 源码阅读_GCN_30 这一项来进行修正, 只保留 PyGCN 源码阅读_Graph Embedding_64, 即两者中的最大值.

由于 PyGCN 源码阅读_GCN_30 的结果为:

PyGCN 源码阅读_图嵌入_66

于是 PyGCN 源码阅读_GCN_67 的结果为:

PyGCN 源码阅读_Graph Embedding_68

为对称的矩阵. 此时原有向图的邻接矩阵 PyGCN 源码阅读_图嵌入 被转换为了无向图的矩阵. )

其中 PyGCN 源码阅读_GCN_30 这一项, 作者在 issue ​​​build symmetric adjacency matrix​​​ 中有谈到, 是为了处理其他情形, 比如当 i 指向 j 与 j 指向 i 的权重不一致时, 可以保留最大的权重:
PyGCN 源码阅读_图嵌入_71
那么 PyGCN 源码阅读_PyTorch_09 的结果为:
PyGCN 源码阅读_Graph Embedding_73
结果是对称的, 并且保留了其中权重最大的边.

构建完对称矩阵之后, 下一步是创建 PyGCN 源码阅读_Graph Embedding_74:

adj = normalize(adj + sp.eye(adj.shape[0]))

其中 ​​normalize​​ 定义为:

def normalize(mx):
"""Row-normalize sparse matrix"""
rowsum = np.array(mx.sum(1))
r_inv = np.power(rowsum, -1).flatten()
r_inv[np.isinf(r_inv)] = 0.
r_mat_inv = sp.diags(r_inv)
mx = r_mat_inv.dot(mx)
return mx

从这个代码来看, 作者实现的效果应该是: PyGCN 源码阅读_Graph Embedding_75.

作者在 issue ​​What is the difference between two adjacency matrix normalization?​​​ 说到可以将这个公式也看成是超参数, 在训练时可以尝试使用 PyGCN 源码阅读_Graph Embedding_75 以及 PyGCN 源码阅读_Graph Embedding_74.

作者实现的 TensorFlow 版本中 ​​https://github.com/tkipf/gcn​​​, 在 ​​https://github.com/tkipf/gcn/blob/master/gcn/utils.py​​​ 的 ​​normalize_adj​​​ 函数中实现了 PyGCN 源码阅读_Graph Embedding_74:

def normalize_adj(adj):
"""Symmetrically normalize adjacency matrix."""
adj = sp.coo_matrix(adj)
rowsum = np.array(adj.sum(1))
d_inv_sqrt = np.power(rowsum, -0.5).flatten()
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()

实际上, 在 ​​如何理解 Graph Convolutional Networks? (GCN)​​ 一文的高赞回答中说到, 常用的拉普拉斯矩阵实际上有三种:

  • Combinatorial Laplacian
  • PyGCN 源码阅读_PyTorch_79 被称为对称归一化拉普拉斯算子 Symmetric normalized Laplacian
  • PyGCN 源码阅读_PyTorch_80 被称为随机游走归一化拉普拉斯算子 Random walk normalized Laplacian

Scipy 中的 sparse matrix 转换为 PyTorch 中的 sparse matrix

可以复用的代码:

构建 sparse tensor, 一般需要 coordinate indices, values 以及 shape 等信息.

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
"""Convert a scipy sparse matrix to a torch sparse tensor."""
sparse_mx = sparse_mx.tocoo().astype(np.float32)
indices = torch.from_numpy(
np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
values = torch.from_numpy(sparse_mx.data)
shape = torch.Size(sparse_mx.shape)
return torch.sparse.FloatTensor(indices, values, shape)

模型定义

GCN 定义如下:

class GCN(nn.Module):
def __init__(self, nfeat, nhid, nclass, dropout):
super(GCN, self).__init__()

self.gc1 = GraphConvolution(nfeat, nhid)
self.gc2 = GraphConvolution(nhid, nclass)
self.dropout = dropout

def forward(self, x, adj):
x = F.relu(self.gc1(x, adj))
x = F.dropout(x, self.dropout, training=self.training)
x = self.gc2(x, adj)
return F.log_softmax(x, dim=1)

发现代码其实很简单, 把 ​​GraphConvolution​​​ 想象成 ​​nn.Linear​​. 上面是两层的 GCN 模型定义.

而 ​​GraphConvolution​​ 层定义如下:

class GraphConvolution(Module):
"""
Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
"""

def __init__(self, in_features, out_features, bias=True):
super(GraphConvolution, self).__init__()
# ....

def forward(self, input, adj):
support = torch.mm(input, self.weight)
output = torch.spmm(adj, support)
if self.bias is not None:
return output + self.bias
else:
return output

其输入为节点的特征 input 以及归一化的邻接矩阵 PyGCN 源码阅读_Graph Embedding_75. 该实现和如下公式是一致的:
PyGCN 源码阅读_PyTorch_82
另外, ​​​torch.spmm​​ 是稀疏矩阵的乘法.

如果节点的特征 input 不存在的话, 可以考虑将节点的 one-hot 表示作为特征输入到模型中.

模型训练

由于 GCN 模型 ​​forward​​​ 方法最后的输出为 ​​F.log_softmax(x, dim=1)​​, 因此, 在设置 Loss 时, 使用的是:

F.nll_loss

方法.

Embedding 获取

相关代码作者没有提供, 在 issue 中谈到这个问题: ​​Could I get the node embedding?​​

可以将隐藏层的 embedding 提取出来, 另外, 由于监督模型学出来的 embedding 非常 task-specialized. 作者还推荐他的另一个库: GAE

​​https://github.com/tkipf/gae​​

其他

其他可以参考学习, 复用的内容

准确率

可以复用.

def accuracy(output, labels):
preds = output.max(1)[1].type_as(labels)
correct = preds.eq(labels).double()
correct = correct.sum()
return correct / len(labels)

注意 ​​output.max(1)​​​ 的返回结果. ​​output.max(1)[1]​​ 表示最大值所在的 indice. 比如:

>>> a = torch.arange(6).view(3, 2)
>>> a
tensor([[0, 1],
[2, 3],
[4, 5]])
>>> a.max(1)
torch.return_types.max(
values=tensor([1, 3, 5]),
indices=tensor([1, 1, 1]))
>>> a.max(1)[1]
tensor([1, 1, 1])

举报

相关推荐

0 条评论