0
点赞
收藏
分享

微信扫一扫

机器学习笔记 - 使用Deeplearning4j和浅层神经网络进行分类

一、环境设置

        我们使用deeplearning4j (dl4j) 库创建一个简单的神经网络,该库是一种现代且强大的机器学习工具。

        让我们将所需的库添加到我们的 Maven pom.xml文件中,如下配置。

<properties>
    <dl4j-master.version>1.0.0-M2</dl4j-master.version>
    <java.version>1.8</java.version>
</properties>

<dependencies>
		<!-- deeplearning4j-core: contains main functionality and neural networks -->
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-core</artifactId>
			<version>${dl4j-master.version}</version>
		</dependency>
		<dependency>
			<groupId>org.nd4j</groupId>
			<artifactId>nd4j-native</artifactId>
			<version>${dl4j-master.version}</version>
		</dependency>
</dependencies>

        其中,nd4j-native-platform依赖项是可选的。

        它依赖于可用于许多不同平台(macOS、Windows、Linux、Android 等)的本机库。如果我们想在支持 CUDA 的显卡上执行计算,我们也可以将后端切换到诸如nd4j-cuda-8.0-platform。

二、准备数据

        我们使用鸢尾花数据集的分类。这是一组从不同物种的花中收集的数据(Iris setosa、Iris versicolor和Iris virginica)。

        数据集下载地址。

        这些物种的花瓣和萼片的长度和宽度不同。很难编写一个精确的算法来对输入数据项进行分类(即确定一朵特定的花属于什么物种)。但是一个训练有素的神经网络可以快速分类它并且几乎没有错误。

        我们将使用此数据的 CSV 版本,其中第 0 - 3 列包含物种的不同特征,第 4 列包含记录的类别或物种,用值 0、1 或 2 编码(整理数据成为如下格式):

1、读取和打乱数据

        这里使用datavec库来执行此操作。

        创建CSVRecordReader时,我们可以指定要跳过的行数(例如,如果文件有标题行)和分隔符(在我们的例子中是逗号)。

        我们使用DataSetIterator接口的多种实现中的任何一种进行数据的遍历。因为数据集可能非常庞大,分页或缓存会派上用场。

        不过鸢尾花数据集只包含 150 条记录,所以让我们通过调用iterator.next()一次将所有数据读入内存。

        我们还指定了类列的索引,在我们的例子中它与特征计数 (4列特征)和类的总数(一共三类花) 相同。

        另外,我们打乱数据集以避免有规律的排序。

        我们指定一个常量随机种子 (42) 而不是默认的System.currentTimeMillis()调用,这样打乱之后的结果总是相同的,以便准确测试模型的能力,否则每次数据都不同也不好解释。

private static final int CLASSES_COUNT = 3;
private static final int FEATURES_COUNT = 4;


RecordReader recordReader = new CSVRecordReader(0, ',');
recordReader.initialize(new FileSplit(new File("C:\\Users\\zyh\\Desktop\\iris.csv")));
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 150, FEATURES_COUNT, CLASSES_COUNT);
DataSet allData = iterator.next();
allData.shuffle(42);

2、规范化和拆分数据集

        在训练之前我们应该对数据做的另一件事是对其进行标准化。

        标准化是一个两阶段的过程:

        (1)收集有关数据的一些统计信息(拟合)

        (2)以某种方式更改(转换)数据以使其统一

        对于不同类型的数据,标准化可能会有所不同。

        例如,如果我们要处理各种尺寸的图像,我们应该首先收集尺寸统计信息,然后将图像缩放到统一的尺寸。

        但是对于数字,归一化通常意味着将它们转换为所谓的正态分布。

        这里使用NormalizerStandardize类完成归一化。

DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(allData);
normalizer.transform(allData);

        然后将数据集拆分为训练集和测试集。

SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();

3、进行模型的配置

        我们使用链式的方式配置我们的网络。

        (1)activation(),我们选择双曲正切 (tanh) 函数做为激活函数。

        (2)weightInit ()方法指定了为网络设置初始权重的多种方法之一,这里我们使用高斯分布 ( WeightInit.XAVIER )。

        (3)学习率是一个重要的参数,它深刻地影响着网络的学习能力,我们这里使用new Nesterovs(0.1, 0.9),学习率为0.1,动量为0.9。

        (4)另外我们指定了new L2Regularization(0.0001)设置 l2 正则化,正则化用来“惩罚”网络权重过大并防止过度拟合。

        (5)接下来,我们创建一个密集(也称为完全连接)层的网络。第一层应包含与训练数据中的列相同数量的节点 (4)。第二个密集层将包含三个节点。这是我们可以改变的值,但前一层的输出数量必须相同。最终输出层应包含与类数(3)匹配的节点数。网络结构如图所示:

         (6)最后设置了反向传播(最有效的训练方法之一)的类型backpropType(BackpropType.Standard)。另一种选择是BackpropType.TruncatedBPTT,通常用在循环神经网络RNN中。

List<Regularization> regularization = new ArrayList<>();
regularization.add(new L2Regularization(0.0001));

        MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(0.1, 0.9))
                .regularization(regularization)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3).build())
                .layer(1, new DenseLayer.Builder().nIn(3).nOut(3).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(3).nOut(CLASSES_COUNT).build())
                .backpropType(BackpropType.Standard)
                .build();

4、进行训练并保存模型

        这里我们使用MultiLayerNetwork进行多层网络对象的创建,然后调用init方法进行初始化。

        之后我们设置一个监听,指示按每多少Iteration打印的分数。

        之后循环进行网络训练,这里为什么要循环进行训练?是因为我们使用的是DataSet,如果使用的是DataSetIterator,则写法可以是model.fit(DataSetIterator, 1000),就不需要循环;

        最后我们保存训练好的模型到压缩文件,其中会包含这样的文件。

MultiLayerNetwork model = new MultiLayerNetwork(configuration);
model.init();
model.setListeners(new ScoreIterationListener(1));

for (int i = 0; i < 10; i++) {
    model.fit(trainingData);
}

model.save(new File("C:\\Users\\zyh\\Desktop\\iris.model.zip"));

5、进行验证

INDArray output = model.output(testData.getFeatures());

Evaluation eval = new Evaluation(CLASSES_COUNT);
eval.eval(testData.getLabels(), output);
System.out.println(eval.stats());

三、代码以及Log输出

1、完整代码

public class MultiLayerWithIris {

    private static final int CLASSES_COUNT = 3;
    private static final int FEATURES_COUNT = 4;

    public static void train() throws IOException, InterruptedException {

        DataSet allData;
        RecordReader recordReader = new CSVRecordReader(0, ',');
        recordReader.initialize(new FileSplit(new File("D:\\iris.csv")));
        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 150, FEATURES_COUNT, CLASSES_COUNT);
        allData = iterator.next();
        allData.shuffle(42);

        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(allData);
        normalizer.transform(allData);

        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
        DataSet trainingData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();

        List<Regularization> regularization = new ArrayList<>();
        regularization.add(new L2Regularization(0.0001));

        MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(0.1, 0.9))
                .regularization(regularization)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3).build())
                .layer(1, new DenseLayer.Builder().nIn(3).nOut(3).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(3).nOut(CLASSES_COUNT).build())
                .backpropType(BackpropType.Standard)
                .build();

        //configuration.setIterationCount(1);

        MultiLayerNetwork model = new MultiLayerNetwork(configuration);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        for (int i = 0; i < 10; i++) {
            model.fit(trainingData);
        }

        model.save(new File("D:\\iris.model.zip"));


        INDArray output = model.output(testData.getFeatures());

        Evaluation eval = new Evaluation(CLASSES_COUNT);
        eval.eval(testData.getLabels(), output);
        System.out.println(eval.stats());
    }

}

2、Log输出

举报

相关推荐

0 条评论