一、环境设置
我们使用deeplearning4j (dl4j) 库创建一个简单的神经网络,该库是一种现代且强大的机器学习工具。
让我们将所需的库添加到我们的 Maven pom.xml文件中,如下配置。
<properties>
<dl4j-master.version>1.0.0-M2</dl4j-master.version>
<java.version>1.8</java.version>
</properties>
<dependencies>
<!-- deeplearning4j-core: contains main functionality and neural networks -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
</dependencies>
其中,nd4j-native-platform依赖项是可选的。
它依赖于可用于许多不同平台(macOS、Windows、Linux、Android 等)的本机库。如果我们想在支持 CUDA 的显卡上执行计算,我们也可以将后端切换到诸如nd4j-cuda-8.0-platform。
二、准备数据
我们使用鸢尾花数据集的分类。这是一组从不同物种的花中收集的数据(Iris setosa、Iris versicolor和Iris virginica)。
数据集下载地址。
这些物种的花瓣和萼片的长度和宽度不同。很难编写一个精确的算法来对输入数据项进行分类(即确定一朵特定的花属于什么物种)。但是一个训练有素的神经网络可以快速分类它并且几乎没有错误。
我们将使用此数据的 CSV 版本,其中第 0 - 3 列包含物种的不同特征,第 4 列包含记录的类别或物种,用值 0、1 或 2 编码(整理数据成为如下格式):
1、读取和打乱数据
这里使用datavec库来执行此操作。
创建CSVRecordReader时,我们可以指定要跳过的行数(例如,如果文件有标题行)和分隔符(在我们的例子中是逗号)。
我们使用DataSetIterator接口的多种实现中的任何一种进行数据的遍历。因为数据集可能非常庞大,分页或缓存会派上用场。
不过鸢尾花数据集只包含 150 条记录,所以让我们通过调用iterator.next()一次将所有数据读入内存。
我们还指定了类列的索引,在我们的例子中它与特征计数 (4列特征)和类的总数(一共三类花) 相同。
另外,我们打乱数据集以避免有规律的排序。
我们指定一个常量随机种子 (42) 而不是默认的System.currentTimeMillis()调用,这样打乱之后的结果总是相同的,以便准确测试模型的能力,否则每次数据都不同也不好解释。
private static final int CLASSES_COUNT = 3;
private static final int FEATURES_COUNT = 4;
RecordReader recordReader = new CSVRecordReader(0, ',');
recordReader.initialize(new FileSplit(new File("C:\\Users\\zyh\\Desktop\\iris.csv")));
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 150, FEATURES_COUNT, CLASSES_COUNT);
DataSet allData = iterator.next();
allData.shuffle(42);
2、规范化和拆分数据集
在训练之前我们应该对数据做的另一件事是对其进行标准化。
标准化是一个两阶段的过程:
(1)收集有关数据的一些统计信息(拟合)
(2)以某种方式更改(转换)数据以使其统一
对于不同类型的数据,标准化可能会有所不同。
例如,如果我们要处理各种尺寸的图像,我们应该首先收集尺寸统计信息,然后将图像缩放到统一的尺寸。
但是对于数字,归一化通常意味着将它们转换为所谓的正态分布。
这里使用NormalizerStandardize类完成归一化。
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(allData);
normalizer.transform(allData);
然后将数据集拆分为训练集和测试集。
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();
3、进行模型的配置
我们使用链式的方式配置我们的网络。
(1)activation(),我们选择双曲正切 (tanh) 函数做为激活函数。
(2)weightInit ()方法指定了为网络设置初始权重的多种方法之一,这里我们使用高斯分布 ( WeightInit.XAVIER )。
(3)学习率是一个重要的参数,它深刻地影响着网络的学习能力,我们这里使用new Nesterovs(0.1, 0.9),学习率为0.1,动量为0.9。
(4)另外我们指定了new L2Regularization(0.0001)设置 l2 正则化,正则化用来“惩罚”网络权重过大并防止过度拟合。
(5)接下来,我们创建一个密集(也称为完全连接)层的网络。第一层应包含与训练数据中的列相同数量的节点 (4)。第二个密集层将包含三个节点。这是我们可以改变的值,但前一层的输出数量必须相同。最终输出层应包含与类数(3)匹配的节点数。网络结构如图所示:
(6)最后设置了反向传播(最有效的训练方法之一)的类型backpropType(BackpropType.Standard)。另一种选择是BackpropType.TruncatedBPTT,通常用在循环神经网络RNN中。
List<Regularization> regularization = new ArrayList<>();
regularization.add(new L2Regularization(0.0001));
MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.weightInit(WeightInit.XAVIER)
.updater(new Nesterovs(0.1, 0.9))
.regularization(regularization)
.list()
.layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3).build())
.layer(1, new DenseLayer.Builder().nIn(3).nOut(3).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(3).nOut(CLASSES_COUNT).build())
.backpropType(BackpropType.Standard)
.build();
4、进行训练并保存模型
这里我们使用MultiLayerNetwork进行多层网络对象的创建,然后调用init方法进行初始化。
之后我们设置一个监听,指示按每多少Iteration打印的分数。
之后循环进行网络训练,这里为什么要循环进行训练?是因为我们使用的是DataSet,如果使用的是DataSetIterator,则写法可以是model.fit(DataSetIterator, 1000),就不需要循环;
最后我们保存训练好的模型到压缩文件,其中会包含这样的文件。
MultiLayerNetwork model = new MultiLayerNetwork(configuration);
model.init();
model.setListeners(new ScoreIterationListener(1));
for (int i = 0; i < 10; i++) {
model.fit(trainingData);
}
model.save(new File("C:\\Users\\zyh\\Desktop\\iris.model.zip"));
5、进行验证
INDArray output = model.output(testData.getFeatures());
Evaluation eval = new Evaluation(CLASSES_COUNT);
eval.eval(testData.getLabels(), output);
System.out.println(eval.stats());
三、代码以及Log输出
1、完整代码
public class MultiLayerWithIris {
private static final int CLASSES_COUNT = 3;
private static final int FEATURES_COUNT = 4;
public static void train() throws IOException, InterruptedException {
DataSet allData;
RecordReader recordReader = new CSVRecordReader(0, ',');
recordReader.initialize(new FileSplit(new File("D:\\iris.csv")));
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 150, FEATURES_COUNT, CLASSES_COUNT);
allData = iterator.next();
allData.shuffle(42);
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(allData);
normalizer.transform(allData);
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();
List<Regularization> regularization = new ArrayList<>();
regularization.add(new L2Regularization(0.0001));
MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.weightInit(WeightInit.XAVIER)
.updater(new Nesterovs(0.1, 0.9))
.regularization(regularization)
.list()
.layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3).build())
.layer(1, new DenseLayer.Builder().nIn(3).nOut(3).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(3).nOut(CLASSES_COUNT).build())
.backpropType(BackpropType.Standard)
.build();
//configuration.setIterationCount(1);
MultiLayerNetwork model = new MultiLayerNetwork(configuration);
model.init();
model.setListeners(new ScoreIterationListener(1));
for (int i = 0; i < 10; i++) {
model.fit(trainingData);
}
model.save(new File("D:\\iris.model.zip"));
INDArray output = model.output(testData.getFeatures());
Evaluation eval = new Evaluation(CLASSES_COUNT);
eval.eval(testData.getLabels(), output);
System.out.println(eval.stats());
}
}