0
点赞
收藏
分享

微信扫一扫

小黑维度分析与逐行调试:GraphNetwork

i奇异 2022-03-11 阅读 24

输入:

x (features):[batch_size*len(j),D_e]

edge_index:[2,num_edges]

edge_norm:[num_edges]

edge_type:[num_edges]

seq_lengths:[batch_size]

umask:[batch_size,num_seqs]

输出:

log_prob:[batch_size*len(j),num_classes]

代码

import torch
from torch_geometric.nn import RGCNConv, GraphConv
from model import MatchingAttention,classify_node_features
import torch.nn as nn
import random
class GraphNetwork(torch.nn.Module):
    """Two-layer graph encoder followed by a per-node classifier.

    An ``RGCNConv`` layer (relation-aware) and a ``GraphConv`` layer encode the
    node features. The encoder output is concatenated with the raw input
    features and passed to ``classify_node_features`` for the final prediction.

    Args:
        num_features: Size of each input node feature vector (``D_e``).
        num_classes: Number of output classes (width of ``smax_fc``).
        num_relations: Number of edge/relation types handled by ``RGCNConv``.
        max_seq_len: Accepted for interface compatibility; unused here.
        hidden_size: Output width of both graph-convolution layers.
        dropout: Probability for the ``nn.Dropout`` module handed to the
            classifier.
        no_cuda: Forwarded unchanged to ``classify_node_features``.
        num_bases: Number of basis matrices for ``RGCNConv``'s basis
            decomposition. Defaults to 30, the value previously hard-coded.
            Basis decomposition only saves parameters when it is smaller than
            ``num_relations``.
    """

    def __init__(self, num_features, num_classes, num_relations, max_seq_len,
                 hidden_size=64, dropout=0.5, no_cuda=False, num_bases=30):
        super().__init__()
        # Submodules are created in the original order so that random
        # initialisation and state_dict keys are unchanged.
        self.conv1 = RGCNConv(num_features, hidden_size, num_relations, num_bases=num_bases)
        self.conv2 = GraphConv(hidden_size, hidden_size)
        # NOTE(review): matchatt is built but never used in forward(); it is kept
        # so that existing checkpoints (state_dict keys) still load.
        self.matchatt = MatchingAttention(num_features + hidden_size,
                                          num_features + hidden_size,
                                          att_type='general2')
        self.linear = nn.Linear(num_features + hidden_size, hidden_size)
        self.linear_beta = nn.Linear(num_features + hidden_size, num_features + hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.smax_fc = nn.Linear(hidden_size, num_classes)
        self.no_cuda = no_cuda

    def forward(self, x, edge_index, edge_norm, edge_type, seq_lengths, umask):
        """Encode the graph and classify every node.

        Shapes below follow the annotations in the original write-up.

        Args:
            x: Node features, ``[batch_size*len(j), D_e]``.
            edge_index: Edge endpoints, ``[2, num_edges]``.
            edge_norm: ``[num_edges]``; accepted but not used by this forward.
            edge_type: Relation id of each edge, ``[num_edges]``, consumed by
                ``conv1``.
            seq_lengths: ``[batch_size]``; accepted but not used here.
            umask: ``[batch_size, num_seqs]``; accepted but not used here.

        Returns:
            The result of ``classify_node_features``. Per the original
            annotation, these are log-probabilities of shape
            ``[batch_size*len(j), num_classes]``.
        """
        # [N, D_e] -> [N, hidden_size]
        out = self.conv1(x, edge_index, edge_type)
        # [N, hidden_size] -> [N, hidden_size]
        out = self.conv2(out, edge_index)
        # Concatenate raw features with the graph encoding:
        # [N, D_e + hidden_size]
        emotions = torch.cat([x, out], dim=-1)
        # Inputs: emotions [N, D_e + hidden_size];
        #   linear: Linear(D_e + hidden_size -> hidden_size)
        #   linear_beta: Linear(D_e + hidden_size -> D_e + hidden_size)
        #   smax_fc: Linear(hidden_size -> num_classes)
        # Output: log_prob [N, num_classes]
        log_prob = classify_node_features(emotions, self.linear, self.linear_beta,
                                          self.dropout, self.smax_fc, self.no_cuda)
        return log_prob
    
    
# Smoke test with random inputs: runs one forward pass and prints the output
# shape, which should be [num_nodes, num_classes].
num_features = 200
num_classes = 7
num_relations = 8
max_seq_len = 18
num_edges = 250
batch_size = 7
num_seqs = 30
num_nodes = 53


model = GraphNetwork(num_features, num_classes, num_relations, max_seq_len)
x = torch.randn(num_nodes, num_features)
edge_index = torch.randint(0, num_nodes, (2, num_edges))
edge_norm = torch.randn(num_edges)
# Relation ids must lie in [0, num_relations) (randint's upper bound is
# exclusive). The original drew from [0, 10), producing ids 8 and 9 for which
# the model has no relation weights.
edge_type = torch.randint(0, num_relations, (num_edges,))
# Spread the num_nodes nodes over the batch so the lengths sum to num_nodes
# (53 nodes over 7 dialogues -> [8, 8, 8, 8, 7, 7, 7]).
base_len, extra = divmod(num_nodes, batch_size)
seq_lengths = [base_len + (i < extra) for i in range(batch_size)]
# 1 for valid positions of each dialogue, 0 for padding: [batch_size, num_seqs].
umask = (torch.arange(num_seqs) < torch.tensor(seq_lengths).unsqueeze(1)).long()
print(model(x, edge_index, edge_norm, edge_type, seq_lengths, umask).shape)

输出:

torch.Size([53, 7])

举报

相关推荐

0 条评论