zoukankan      html  css  js  c++  java
  • Java机器学习框架deeplearing4j入门教程

    1.添加项目
    maven添加依赖 or 导入jar包 or 使用jvm

    <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    
    <groupId>YOURPROJECTNAME.com</groupId>
    <artifactId>YOURPROJECTNAME</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>
    
    <name>YOURNAME</name>
    <url>http://maven.apache.org</url>
    
    <properties>
    <nd4j.backend>nd4j-native-platform</nd4j.backend>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <shadedClassifier>bin</shadedClassifier>
    <java.version>1.7</java.version>
    <nd4j.version>0.6.0</nd4j.version>
    <dl4j.version>0.6.0</dl4j.version>
    <datavec.version>0.6.0</datavec.version>
    <arbiter.version>0.6.0</arbiter.version>
    <guava.version>19.0</guava.version>
    <logback.version>1.1.7</logback.version>
    <jfreechart.version>1.0.13</jfreechart.version>
    <maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
    <exec-maven-plugin.version>1.4.0</exec-maven-plugin.version>
    <maven.minimum.version>3.3.1</maven.minimum.version>
    </properties>
    
    <dependencyManagement>
    <dependencies>
    <dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>${nd4j.version}</version>
    </dependency>
    <dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-7.5-platform</artifactId>
    <version>${nd4j.version}</version>
    </dependency>
    </dependencies>
    </dependencyManagement>
    
     
    
    <dependencies>
    <dependency>
    <groupId>junit</groupId>
    <artifactId>junit</artifactId>
    <version>3.8.1</version>
    <scope>test</scope>
    </dependency>
    <!-- ND4J后端。每个DL4J项目都需要一个。一般将artifactId指定为"nd4j-native-platform"或者"nd4j-cuda-7.5-platform" -->
    <dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>${nd4j.backend}</artifactId>
    </dependency>
    
    <!-- DL4J核心功能 -->
    <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>${dl4j.version}</version>
    </dependency>
    
    <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-nlp</artifactId>
    <version>${dl4j.version}</version>
    </dependency>
    
    <!-- deeplearning4j-ui用于HistogramIterationListener + 可视化:参见http://deeplearning4j.org/cn/visualization -->
    <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-ui</artifactId>
    <version>${dl4j.version}</version>
    </dependency>
    
    <!-- 强制指定使用UI/HistogramIterationListener时的guava版本 -->
    <dependency>
    <groupId>com.google.guava</groupId>
    <artifactId>guava</artifactId>
    <version>${guava.version}</version>
    </dependency>
    
    <!-- datavec-data-codec:仅用于在视频处理示例中加载视频数据 -->
    <dependency>
    <artifactId>datavec-data-codec</artifactId>
    <groupId>org.datavec</groupId>
    <version>${datavec.version}</version>
    </dependency>
    
    <!-- 用于前馈/分类/MLP*和前馈/回归/RegressionMathFunctions示例 -->
    <dependency>
    <groupId>jfree</groupId>
    <artifactId>jfreechart</artifactId>
    <version>${jfreechart.version}</version>
    </dependency>
    
    <!-- Arbiter:用于超参数优化示例 -->
    <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>arbiter-deeplearning4j</artifactId>
    <version>${arbiter.version}</version>
    </dependency>
    </dependencies>
    </project>

    2.项目引用库

    import org.deeplearning4j.nn.multilayer._
    import org.deeplearning4j.nn.graph._
    import org.deeplearning4j.nn.conf._
    import org.deeplearning4j.nn.conf.inputs._
    import org.deeplearning4j.nn.conf.layers._
    import org.deeplearning4j.nn.weights._
    import org.deeplearning4j.optimize.listeners._
    import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator
    import org.deeplearning4j.eval.Evaluation
    
    import org.nd4j.linalg.learning.config._ // for different updaters like Adam, Nesterovs, etc.
    import org.nd4j.linalg.activations.Activation // defines different activation functions like RELU, SOFTMAX, etc.
    import org.nd4j.linalg.lossfunctions.LossFunctions // mean squared error, multiclass cross entropy, etc.

    3.准备加载数据
    dl4j有数据迭代器。帮助批处理和迭代数据集。Deeplearning4j带有一个内置的BaseDatasetIteratorEMNIST 实现,
    称为EmnistDataSetIterator。这个特殊的迭代器是一个便利实用程序,用于处理数据的下载和准备。
    可以创建多个数据迭代器,用于训练模型或者评估模型等。
    创建迭代器代码

    import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator //引入数据迭代器库
    
    val batchSize = 16 // how many examples to simultaneously train in the network //数据集大小
    val emnistSet = EmnistDataSetIterator.Set.BALANCED
    val emnistTrain = new EmnistDataSetIterator(emnistSet, batchSize, true) //实例化训练迭代器
    val emnistTest = new EmnistDataSetIterator(emnistSet, batchSize, false) //实例化评估迭代器

    4.建立神经网络
    在dl4j中使用的任何与神经网络有关的操作是在NeuralNetConfiguration类中的。可在此处配置超参数和算法的学习方式。

    val outputNum = EmnistDataSetIterator.numLabels(emnistSet) // total output classes
    val rngSeed = 123 // integer for reproducability of a random number generator
    val numRows = 28 // number of "pixel rows" in an mnist digit
    val numColumns = 28
    
    val conf = new NeuralNetConfiguration.Builder()
    .seed(rngSeed)
    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    .updater(new Adam())
    .l2(1e-4)
    .list()
    .layer(new DenseLayer.Builder()
    .nIn(numRows * numColumns) // Number of input datapoints.
    .nOut(1000) // Number of output datapoints.
    .activation(Activation.RELU) // Activation function.
    .weightInit(WeightInit.XAVIER) // Weight initialization.
    .build())
    .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
    .nIn(1000)
    .nOut(outputNum)
    .activation(Activation.SOFTMAX)
    .weightInit(WeightInit.XAVIER)
    .build())
    .pretrain(false).backprop(true)
    .build()


    5.训练模型
    现在我们已经构建了一个NeuralNetConfiguration,我们可以使用配置来实例化一个MultiLayerNetwork。当我们init()在网络上调用该 方法时,它会在网络上应用所选的权重初始化,并允许我们将数据传递给训练。如果我们想在培训期间看到损失分数,我们也可以将听众传递给网络。
    实例化模型有一个fit()接受数据集迭代器(扩展的迭代器BaseDatasetIterator),单个DataSet或ND数组(实现INDArray)的方法。由于我们的EMNIST迭代器已经扩展了迭代器基类,我们可以直接传递它来适应。如果我们想要训练多个时代,DL4J还提供了一个MultipleEpochsIterator可以为我们处理多个时代的类。

    // create the MLN
    val network = new MultiLayerNetwork(conf)
    network.init()
    
    // pass a training listener that reports score every 10 iterations
    val eachIterations = 5
    network.addListeners(new ScoreIterationListener(eachIterations))
    
    // fit a dataset for a single epoch
    // network.fit(emnistTrain)
    
    // fit for multiple epochs
    // val numEpochs = 2
    // network.fit(new MultipleEpochsIterator(numEpochs, emnistTrain))


    6.评估模型
    Deeplearning4j公开了几种工具来评估模型的性能。您可以执行基本评估并获取精度和准确度等指标,或使用接收器操作特性(ROC)。请注意,通用ROC类适用于二进制分类器,而ROCMultiClass适用于分类器,例如我们在此构建的模型。

    A MultiLayerNetwork方便地有一些内置的方法来帮助我们进行评估。您可以将包含测试/验证数据的数据集迭代器传递给evaluate()方法。

    // evaluate basic performance
    val eval = network.evaluate(emnistTest)
    eval.accuracy()
    eval.precision()
    eval.recall()
    
    // evaluate ROC and calculate the Area Under Curve
    val roc = network.evaluateROCMultiClass(emnistTest)
    roc.calculateAverageAUC()
    
    val classIndex = 0
    roc.calculateAUC(classIndex)
    
    // optionally, you can print all stats from the evaluations
    print(eval.stats())
    print(roc.stats())
    // evaluate basic performance
    val eval = network.evaluate(emnistTest)
    eval.accuracy()
    eval.precision()
    eval.recall()
    
    // evaluate ROC and calculate the Area Under Curve
    val roc = network.evaluateROCMultiClass(emnistTest)
    roc.calculateAverageAUC()
    
    val classIndex = 0
    roc.calculateAUC(classIndex)
    
    // optionally, you can print all stats from the evaluations
    print(eval.stats())
    print(roc.stats())
  • 相关阅读:
    Atcoder Grand Contest 038 F
    洛谷 P5502
    Codeforces 1010F
    洛谷 P4621
    洛谷 P5518
    Oracle-切换当用用户的模式
    Oracle-DBV数据文件校验工具
    【转载】Oracle-通过增量备份前滚的反手解决物理备库归档缺失,损坏,gap问题
    Oracle-对比SAA与STA
    Oracle-SAA
  • 原文地址:https://www.cnblogs.com/liaohai/p/9620947.html
Copyright © 2011-2022 走看看