zoukankan      html  css  js  c++  java
  • Keras手写识别例子(1)----softmax

    转自:https://morvanzhou.github.io/tutorials/machine-learning/keras/2-2-classifier/#测试模型

    下载数据

    # download the mnist to the path '~/.keras/datasets/' if it is the first time to be called
    # X shape (60,000 28x28), y shape (10,000, )
    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    data预处理:

    X_train = X_train.reshape(X_train.shape[0], -1) / 255.   # normalize
    X_test = X_test.reshape(X_test.shape[0], -1) / 255.      # normalize
    y_train = np_utils.to_categorical(y_train, num_classes=10)
    y_test = np_utils.to_categorical(y_test, num_classes=10)

    导入包:

    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("./", one_hot=True)
    X_train=mnist.train.images
    Y_train=mnist.train.labels
    X_test=mnist.test.images
    Y_test=mnist.test.labels

    因为(X_train, y_train), (X_test, y_test) = mnist.load_data()需从网上下载数据,由于网络限制,下载失败。

    可以先在官网yann.lecun.com/exdb/mnist/上下载四个数据(train-images-idx3-ubyte.gz、train-labels-idx1-ubyte.gz、t10k-images-idx3-ubyte.gz、t10k-labels-idx1-ubyte.gz

    在当前目录,不要解压!

    #input_data.py该模块在tensorflow.examples.tutorials.mnist下,直接加载来读取上面四个压缩包。

    #四个压缩包形式为特殊形式。非图片和标签,要解析。

    from tensorflow.examples.tutorials.mnist import input_data

    #加载数据路径为"./",为当前路径,自动加载数据,用one-hot方式处理好数据。

    #read_data_sets是input_data.py里面的一个函数,主要是将数据解压之后,放到对应的位置。 第一个参数为路径,写"./"表示当前路径,其会判断该路径下有没有数据,没有的话会自动下载数据。

    mnist = input_data.read_data_sets("./", one_hot=True)  

    相关的包:

    model.Sequential():用来一层一层的去建立神经层。

    layers.Dense,表示这个神经层是全连接层。

    layers.Activation,激励函数

    optimizers.RMSprop,优化器采用RMSprop,加速神经网络训练方法。

    Keras工作流程:

    1. 定义训练数据:输入张量和目标张量
    2. 定义层组成的网络(或模型),将输入映射到目标
    3. 配置学习过程:选择损失函数、优化器和需要监控的指标
    4. 调用模型的fit方法在训练数据上进行迭代

    代码:

    import numpy as np
    np.random.seed(1337)  # for reproducibility
    from keras.datasets import mnist
    from keras.models import Sequential from keras.layers import Dense, Activation from keras.optimizers import RMSprop #读取数据,其中,X_train为55000*784,Y_train为55000*10,X_test为10000*784,Y_test大小为10000*10. from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("./", one_hot=True) X_train=mnist.train.images Y_train=mnist.train.labels X_test=mnist.test.images Y_test=mnist.test.labels

    #建立神经网络模型,一共两层,第一层输入784个变量,输出为32,激活函数为relu,第二层输入是上层的输出32,输出为10,激活函数为softmax。 model = Sequential([ Dense(32, input_dim=784), Activation('relu'), Dense(10), Activation('softmax'), ]) #采用RMSprop来求解模型,设学习率lr为0.001,以及别的参数。 rmsprop = RMSprop(lr=0.001, rho=0.9, epsilon=1e-08, decay=0.0) #激活模型,优化器为rmsprop,损失函数为交叉熵,metric,里面可以放入需要计算的,比如cost、accuracy、score等 model.compile(optimizer=rmsprop, loss='categorical_crossentropy', metrics=['accuracy']) #训练网络,用fit函数,导入数据,训练次数为20,每批处理32个 model.fit(X_train, Y_train, nb_epoch=20, batch_size=32) #测试模型 print(' Testing ------------') # Evaluate the model with the metrics we defined earlier loss, accuracy = model.evaluate(X_test, Y_test) print('test loss: ', loss) print('test accuracy: ', accuracy)

     结果:

     
  • 相关阅读:
    webservice呈现调用导致呈现当机.
    【ES6】Promise用法
    腾讯地图JavaScript API调用
    【微信小程序】微信开发者工具快捷键汇总
    【微信小程序】小程序模拟调用本地json接口数据
    【微信小程序】小程序系统API
    【git和GitHub】分布式版本控制Git和代码远程仓库GitHub
    【微信小程序】组件化开发
    【微信小程序】小程序开发注意事项
    【微信小程序】引入外部js 方法
  • 原文地址:https://www.cnblogs.com/Lee-yl/p/8572443.html
Copyright © 2011-2022 走看看