0
点赞
收藏
分享

微信扫一扫

分类预测 | MATLAB实现基于Attention-LSTM的多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)


分类预测 | MATLAB实现基于Attention-LSTM的多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)


目录

  • 分类预测 | MATLAB实现基于Attention-LSTM的多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)
  • 效果一览
  • 基本介绍
  • 程序设计
  • 参考资料


效果一览

分类预测 | MATLAB实现基于Attention-LSTM的多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)_分类


分类预测 | MATLAB实现基于Attention-LSTM的多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)_多特征分类预测_02


分类预测 | MATLAB实现基于Attention-LSTM的多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)_多特征分类预测_03


分类预测 | MATLAB实现基于Attention-LSTM的多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)_Attention-LSTM_04


分类预测 | MATLAB实现基于Attention-LSTM的多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)_Attention-LSTM_05

基本介绍

分类预测 | MATLAB实现基于Attention-LSTM的多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)

程序设计

  • 完整程序和数据下载(资源处下载):Attention-LSTM的多特征数据分类预测

% 需要学习的参数
lstmweight = params.lstm.weights;
lstmrecurrentWeights = params.lstm.recurrentWeights;
lstmbias = params.lstm.bias;
% 不同批次间传递的参数(这里假设每一轮epoch中,不同Batch间的state是传递的,但不学习;
h0 = state.lstm.h0;
c0 = state.lstm.c0;
[Lstm_Y,h0,c0] = lstm(Train_X,h0,c0,lstmweight,lstmrecurrentWeights,lstmbias);

Htt = dlarray(Lstm_Y(:,:,1:end-1),'SBSC');    %转变成CNN输入格式,’SS为

%% Attention
Attentionweight  = params.attention.weight; % 计算得分权重
Att = dlarray(squeeze(sum(CnnHttAtt .* dlarray(Attentionweight,'SC'),2)),'SBC'); %'C'维度为cnn卷积后的每一行
Ht = Lstm_Y(:,:,end);       % 参考向量

HtAfter = dlarray(repmat(Ht,[1,1,50]),'SBC');
f = squeeze(sum(HtAfter.*Att,1));
socre = sigmoid(f);                   % 计算得分'CB'
socre = dlarray(repmat(socre,[1,1,6]),'CBS'); 

% 组成Vt
CnnAfterRow = dlarray(squeeze(CnnHtt),'CSB');    % 满足与socre维度一致
Vt = sum(CnnAfterRow .*socre,2);
Vt = squeeze(Vt);



%% Attention输出
weight1 = params.attenout.weight1;
bias1 = params.attenout.bias1;
weight2 = params.attenout.weight2;
bias2 = params.attenout.bias2;
Hthat = fullyconnect(Vt,weight1,bias1) + fullyconnect(Ht,weight2,bias2);

%% 全连接层前置层(降维)
LastWeight = params.fullyconnect.weight1;
LastBias = params.fullyconnect.bias1 ;
FullyconnectInput = fullyconnect(Hthat,LastWeight,LastBias);
FullyconnectInput = relu(FullyconnectInput);

举报

相关推荐

0 条评论