Preface
Implementation
Basis decomposition module.
Basis decomposition is introduced in “Modeling Relational Data with Graph Convolutional Networks” (https://arxiv.org/abs/1703.06103) and can be described as follows:
$$W_o = \sum_{b=1}^{B} a_{ob} V_b$$
Each weight output $W_o$ is essentially a linear combination of basis transformations $V_b$ with coefficients $a_{ob}$.
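To make the formula concrete, here is a tiny worked instance in PyTorch (the framework the code below uses). The sizes and values are hypothetical, chosen for illustration: B = 2 bases of shape 2×2 and a single output with coefficients a_11 = 2, a_12 = 3. The helper `compose` is also hypothetical and simply transcribes the sum with explicit loops.

```python
import torch

# Two hypothetical 2x2 basis matrices V_1 (identity) and V_2 (swap), stacked as (B, 2, 2).
V = torch.tensor([[[1.0, 0.0],
                   [0.0, 1.0]],
                  [[0.0, 1.0],
                   [1.0, 0.0]]])

# Hypothetical coefficients a_{ob} for one output (num_outputs = 1): a_11 = 2, a_12 = 3.
a = torch.tensor([[2.0, 3.0]])

def compose(a, V):
    """Return W with W[o] = sum_b a[o, b] * V[b], using explicit loops over o and b."""
    num_outputs, num_bases = a.shape
    W = torch.zeros(num_outputs, *V.shape[1:])
    for o in range(num_outputs):
        for b in range(num_bases):
            W[o] += a[o, b] * V[b]
    return W

W = compose(a, V)
print(W.shape)  # torch.Size([1, 2, 2])
print(W[0])     # 2 * identity + 3 * swap, i.e. [[2, 3], [3, 2]]
```

Each output is just a weighted sum of the shared bases; only the coefficient row differs from one output to the next.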
It is useful as a form of regularization on a large parameter matrix, since all outputs share the same bases. For this reason, the number of weight outputs is usually larger than the number of bases.
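Why this acts as regularization is easiest to see by counting parameters. The helper functions and the sizes below (100 outputs, 10 bases, 16×16 weights) are hypothetical, picked only to make the arithmetic visible.

```python
from math import prod

def full_param_count(num_outputs, shape):
    """Parameters if every output W_o is an independent, freely learned tensor."""
    return num_outputs * prod(shape)

def basis_param_count(num_outputs, num_bases, shape):
    """Parameters under basis decomposition: B basis tensors V_b plus the a_{ob} table."""
    return num_bases * prod(shape) + num_outputs * num_bases

shape = (16, 16)
print(full_param_count(100, shape))       # 25600
print(basis_param_count(100, 10, shape))  # 2560 + 1000 = 3560
```

With num_outputs=10 and num_bases=10 the decomposed form needs 2660 parameters versus 2560 for independent weights, so it saves nothing. That is consistent with the warning WeightBasis emits when num_outputs <= num_bases.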
Here we take dglnn.WeightBasis as an example and walk through its code.
import torch as th
from torch import nn
from dgl.base import dgl_warning  # DGL's internal warning helper

class WeightBasis(nn.Module):
    def __init__(self,
                 shape,
                 num_bases,
                 num_outputs):
        super(WeightBasis, self).__init__()
        self.shape = shape
        self.num_bases = num_bases
        self.num_outputs = num_outputs

        if num_outputs <= num_bases:
            dgl_warning('The number of weight outputs should be larger than the number'
                        ' of bases.')

        # basis tensors V_b, stored together with shape (num_bases,) + shape
        self.weight = nn.Parameter(th.Tensor(self.num_bases, *shape))
        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
        # linear combination coefficients a_{ob}, shape (num_outputs, num_bases)
        self.w_comp = nn.Parameter(th.Tensor(self.num_outputs, self.num_bases))
        nn.init.xavier_uniform_(self.w_comp, gain=nn.init.calculate_gain('relu'))

    def forward(self):
        r"""Forward computation

        Returns
        -------
        weight : torch.Tensor
            Composed weight tensor of shape ``(num_outputs,) + shape``
        """
        # generate all weights from bases: flatten each basis to a row,
        # combine rows with the coefficients, then restore the original shape
        weight = th.matmul(self.w_comp, self.weight.view(self.num_bases, -1))
        return weight.view(self.num_outputs, *self.shape)
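The forward pass never loops over b. The sketch below checks, with plain tensors standing in for self.weight and self.w_comp (random values, hypothetical sizes), that the flatten, matmul, reshape steps in forward give the same result as a direct transcription of the sum via torch.einsum.

```python
import torch

torch.manual_seed(0)
num_bases, num_outputs, shape = 3, 5, (4, 6)

bases = torch.randn(num_bases, *shape)        # stands in for self.weight (the V_b)
w_comp = torch.randn(num_outputs, num_bases)  # stands in for self.w_comp (the a_ob)

# Same steps as WeightBasis.forward: flatten each basis to a row, matmul, reshape back.
fast = torch.matmul(w_comp, bases.view(num_bases, -1)).view(num_outputs, *shape)

# Direct transcription of W_o = sum_b a_ob V_b, contracting over the basis index b.
direct = torch.einsum('ob,bij->oij', w_comp, bases)

print(fast.shape)                               # torch.Size([5, 4, 6])
print(torch.allclose(fast, direct, atol=1e-6))  # True
```

Because view flattens all trailing dimensions, the matmul form works for a shape of any rank; the einsum equivalent for arbitrary rank would be written 'ob,b...->o...'.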
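Main Ideas
- Parameter sharing: all relations share the same bases, which effectively prevents overfitting on rare (infrequently observed) relations.
- Basis decomposition, expressing a transformation as a weighted combination of a small set of basis elements, is a very common representation (the Fourier transform of graph signals, which expands a signal in the eigenvectors of the graph Laplacian, reflects the same idea) and an important technique in machine learning.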
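The parameter-sharing point can be observed through autograd. In this sketch the tensor sizes and the toy loss are hypothetical, and plain tensors stand in for WeightBasis.weight and WeightBasis.w_comp. A loss that involves only relation 0 still produces gradients on every basis, while only row 0 of the coefficient table changes.

```python
import torch

torch.manual_seed(0)
num_bases, num_outputs, in_dim, out_dim = 2, 4, 3, 3

# Hypothetical stand-ins for WeightBasis.weight (V_b) and WeightBasis.w_comp (a_ob).
bases = torch.randn(num_bases, in_dim, out_dim, requires_grad=True)
w_comp = torch.randn(num_outputs, num_bases, requires_grad=True)

W = torch.einsum('ob,bij->oij', w_comp, bases)  # one composed weight per relation

# Toy loss touching only relation 0, e.g. a batch that contains a single relation type.
h = torch.randn(8, in_dim)
loss = (h @ W[0]).pow(2).sum()
loss.backward()

# Every shared basis gets a gradient from relation 0, so relations 1..3, which reuse
# the same bases, are also affected; their own coefficient rows receive zero gradient.
print(bases.grad.abs().sum(dim=(1, 2)) > 0)  # tensor([True, True])
print(w_comp.grad[1:].abs().sum().item())    # 0.0
```

This is the intuition behind the overfitting claim: a rare relation owns only its few coefficients a_ob, while most of its effective weight lives in bases trained by all relations.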