输入:
features:[batch_size*len(j),D_e]
edge_index:[2,num_edges]
edge_norm:[num_edges]
edge_type:[num_edges]  (relation ids, each in [0, num_relations))
seq_lengths:[batch_size]
umask:[batch_size,num_seqs]
输出:
log_prob:[batch_size*len(j),num_classes]
代码:
import torch
from torch_geometric.nn import RGCNConv, GraphConv
from model import MatchingAttention,classify_node_features
import torch.nn as nn
import random
class GraphNetwork(torch.nn.Module):
    """Relational GCN + GraphConv encoder followed by a per-node classifier.

    Node features are passed through an ``RGCNConv`` layer (basis
    decomposition with 30 bases) and then a ``GraphConv`` layer. The result
    is concatenated with the raw input features. That concatenation is
    handed to ``classify_node_features`` together with this module's linear,
    dropout and output layers.

    Notes (from the visible code):
        * ``max_seq_len`` is accepted but not stored or used.
        * ``self.matchatt`` is constructed but never used in ``forward``.
    """

    def __init__(self, num_features, num_classes, num_relations, max_seq_len,
                 hidden_size=64, dropout=0.5, no_cuda=False):
        """Build the layers.

        Args:
            num_features: size D_e of each input node feature vector.
            num_classes: number of output classes.
            num_relations: number of edge relation types for ``RGCNConv``.
            max_seq_len: unused by this module.
            hidden_size: width of both graph-convolution outputs.
            dropout: drop probability for the ``nn.Dropout`` layer.
            no_cuda: forwarded unchanged to ``classify_node_features``.
        """
        super().__init__()
        # Width of the concatenation [raw features | graph features].
        concat_dim = num_features + hidden_size

        # Creation order is significant: it fixes parameter initialisation
        # order (RNG consumption) and the state_dict key order.
        self.conv1 = RGCNConv(num_features, hidden_size, num_relations,
                              num_bases=30)
        self.conv2 = GraphConv(hidden_size, hidden_size)
        self.matchatt = MatchingAttention(concat_dim, concat_dim,
                                          att_type='general2')
        self.linear = nn.Linear(concat_dim, hidden_size)
        self.linear_beta = nn.Linear(concat_dim, concat_dim)
        self.dropout = nn.Dropout(dropout)
        self.smax_fc = nn.Linear(hidden_size, num_classes)
        self.no_cuda = no_cuda

    def forward(self, x, edge_index, edge_norm, edge_type, seq_lengths, umask):
        """Encode the graph and classify every node.

        Args:
            x: node features, [batch_size*len(j), D_e].
            edge_index: [2, num_edges] edge list.
            edge_norm: [num_edges]; accepted but not used here.
            edge_type: [num_edges] relation id per edge, consumed by ``conv1``.
            seq_lengths: [batch_size]; accepted but not used here.
            umask: [batch_size, num_seqs]; accepted but not used here.

        Returns:
            Whatever ``classify_node_features`` returns. The surrounding
            notes describe it as log-probabilities of shape
            [batch_size*len(j), num_classes].
        """
        # Relation-aware message passing: [N, D_e] -> [N, hidden_size].
        relational = self.conv1(x, edge_index, edge_type)
        # Relation-agnostic smoothing: [N, hidden_size] -> [N, hidden_size].
        smoothed = self.conv2(relational, edge_index)
        # Keep the raw features next to the graph features: [N, D_e + hidden_size].
        node_repr = torch.cat((x, smoothed), dim=-1)
        return classify_node_features(
            node_repr,
            self.linear,
            self.linear_beta,
            self.dropout,
            self.smax_fc,
            self.no_cuda,
        )
if __name__ == "__main__":
    # Smoke test on a random graph: only the output shape is checked.
    num_features = 200  # D_e
    num_classes = 7
    num_relations = 8
    max_seq_len = 18
    num_edges = 250
    batch_size = 7
    num_seqs = 30  # padded width of umask
    num_nodes = 53  # total utterances in the batch, i.e. sum(seq_lengths)

    model = GraphNetwork(num_features, num_classes, num_relations, max_seq_len)

    x = torch.randn(num_nodes, num_features)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    edge_norm = torch.randn(num_edges)
    # Relation ids index the RGCNConv relation weights, so they must lie in
    # [0, num_relations); larger ids name relations the layer does not have.
    edge_type = torch.randint(0, num_relations, (num_edges,))

    # Spread the nodes over the batch so that sum(seq_lengths) == num_nodes,
    # matching the [batch_size*len(j), D_e] layout of x.
    base, extra = divmod(num_nodes, batch_size)
    seq_lengths = [base + 1 if i < extra else base for i in range(batch_size)]
    # 1 for real utterance positions, 0 for padding up to num_seqs.
    umask = (
        torch.arange(num_seqs).unsqueeze(0)
        < torch.tensor(seq_lengths).unsqueeze(1)
    ).long()

    print(model(x, edge_index, edge_norm, edge_type, seq_lengths, umask).shape)
输出:
torch.Size([53, 7])