zoukankan      html  css  js  c++  java
  • Keras学习笔记——Hello Keras

    最近几年,随着AlphaGo的崛起,深度学习开始出现在各个领域,比如无人车、图像识别、物体检测、推荐系统、语音识别、聊天问答等等。因此具备深度学习的知识并能应用实践,已经成为很多开发者包括博主本人的下一个目标了。

    目前最流行的框架莫过于Tensorflow了,但是只要接触过它的人,就知道它使用起来是多么让人恐惧。Tensorflow对我们来说,仿佛是一门高深的Deep Learning学习语言,需要具备很深的机器学习和深度学习功底,才能玩得转。

    Keras正是在这种背景下应运而生的,它是一个对开发者很友好的框架,底层可以基于TensorFlow和Theano,使用起来仿佛是在搭积木。只要不停的添加已有的“层”,就可以实现各种复杂的深度网络模型。

    因此,开发者需要熟悉的不过是两点:如何搭建积木?都有什么积木可以用?

    安装

    安装的步骤直接按照官方文档来就行了,我笔记本的环境已经杂乱不堪,没有办法一步一步记录安装配置了。主要是安装python3.6,然后各种pip install就行了。

    参考文档:http://keras-cn.readthedocs.io/en/latest/for_beginners/keras_linux/

    基础概念

    在使用Keras前,首先要了解Keras里面关于模型如何创建。在上面可爱的小盆友的图片中,想要把积木罗列在一起,需要一个中心的木棍。那么Sequential就可以看做是这个木棍。

    剩下的工作就是add不同的层就行了:

    model = Sequential()
    model.add(Dense(32, input_shape=(784,)))
    model.add(Activation('relu'))
    

    建立好model后,相当于我们定义好了逻辑模型。此时就需要编译模型,生成对应的代码:

    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    

    其中optimizer是参数优化的方法,loss是损失函数的定义,metrics是衡量模型效果的指标。

    最后再灌入数据进行训练即可:

    model.fit(data, labels, epochs=10, batch_size=32)
    

    完整的例子

    代码已经上传到github:https://github.com/xinghalo/keras-examples/blob/master/keras-cn/mnist/mnist_mlp.py
    很多人hello world跑不通是因为网络问题,不能下载到对应的数据集。我这里把数据集也上传到对应的目录下了,修改对应的path即可。

    from __future__ import print_function
    
    import keras
    from keras.datasets import mnist
    from keras.models import Sequential
    from keras.layers import Dense, Dropout
    from keras.optimizers import RMSprop
    
    batch_size = 128
    num_classes = 10
    epochs = 3
    
    # the data, split between train and test sets
    (x_train, y_train), (x_test, y_test) = mnist.load_data("/Users/xingoo/PycharmProjects/keras-examples/keras-cn/mnist/mnist.npz")
    
    x_train = x_train.reshape(60000, 784)
    x_test = x_test.reshape(10000, 784)
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    print(x_train.shape[0], 'train samples')
    print(x_test.shape[0], 'test samples')
    
    # convert class vectors to binary class matrices
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)
    
    model = Sequential()
    model.add(Dense(512, activation='relu', input_shape=(784,)))
    model.add(Dropout(0.2))
    model.add(Dense(512, activation='relu'))
    model.add(Dropout(0.2))
    model.add(Dense(num_classes, activation='softmax'))
    
    model.summary()
    
    model.compile(loss='categorical_crossentropy',
                  optimizer=RMSprop(),
                  metrics=['accuracy'])
    
    history = model.fit(x_train, y_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        verbose=1,
                        validation_data=(x_test, y_test))
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])
    

    运行效果

    Using TensorFlow backend.
    /Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6
      return f(*args, **kwds)
    60000 train samples
    10000 test samples
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense_1 (Dense)              (None, 512)               401920    
    _________________________________________________________________
    dropout_1 (Dropout)          (None, 512)               0         
    _________________________________________________________________
    dense_2 (Dense)              (None, 512)               262656    
    _________________________________________________________________
    dropout_2 (Dropout)          (None, 512)               0         
    _________________________________________________________________
    dense_3 (Dense)              (None, 10)                5130      
    =================================================================
    Total params: 669,706
    Trainable params: 669,706
    Non-trainable params: 0
    _________________________________________________________________
    2018-05-25 17:15:22.294036: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
    Train on 60000 samples, validate on 10000 samples
    Epoch 1/3
      128/60000 [..............................] - ETA: 166s - loss: 2.4237 - acc: 0.0703
      640/60000 [..............................] - ETA: 38s - loss: 1.8688 - acc: 0.3812 
     1152/60000 [..............................] - ETA: 23s - loss: 1.5497 - acc: 0.5087
     1664/60000 [..............................] - ETA: 18s - loss: 1.3466 - acc: 0.5655
     2176/60000 [>.............................] - ETA: 15s - loss: 1.1902 - acc: 0.6167
     2688/60000 [>.............................] - ETA: 13s - loss: 1.0736 - acc: 0.6536
     3200/60000 [>.............................] - ETA: 12s - loss: 0.9968 - acc: 0.6778
     3712/60000 [>.............................] - ETA: 11s - loss: 0.9323 - acc: 0.7002
     4096/60000 [=>............................] - ETA: 11s - loss: 0.8971 - acc: 0.7109
    ...
    51328/60000 [========================>.....] - ETA: 0s - loss: 0.0733 - acc: 0.9775
    51840/60000 [========================>.....] - ETA: 0s - loss: 0.0733 - acc: 0.9774
    52352/60000 [=========================>....] - ETA: 0s - loss: 0.0735 - acc: 0.9774
    52864/60000 [=========================>....] - ETA: 0s - loss: 0.0733 - acc: 0.9775
    53376/60000 [=========================>....] - ETA: 0s - loss: 0.0736 - acc: 0.9774
    53888/60000 [=========================>....] - ETA: 0s - loss: 0.0734 - acc: 0.9775
    54400/60000 [==========================>...] - ETA: 0s - loss: 0.0736 - acc: 0.9774
    54912/60000 [==========================>...] - ETA: 0s - loss: 0.0740 - acc: 0.9773
    55424/60000 [==========================>...] - ETA: 0s - loss: 0.0744 - acc: 0.9772
    55936/60000 [==========================>...] - ETA: 0s - loss: 0.0746 - acc: 0.9771
    56448/60000 [===========================>..] - ETA: 0s - loss: 0.0749 - acc: 0.9771
    56960/60000 [===========================>..] - ETA: 0s - loss: 0.0751 - acc: 0.9772
    57472/60000 [===========================>..] - ETA: 0s - loss: 0.0756 - acc: 0.9772
    57984/60000 [===========================>..] - ETA: 0s - loss: 0.0754 - acc: 0.9772
    58496/60000 [============================>.] - ETA: 0s - loss: 0.0750 - acc: 0.9773
    59008/60000 [============================>.] - ETA: 0s - loss: 0.0750 - acc: 0.9773
    59520/60000 [============================>.] - ETA: 0s - loss: 0.0749 - acc: 0.9774
    60000/60000 [==============================] - 7s - loss: 0.0749 - acc: 0.9774 - val_loss: 0.0819 - val_acc: 0.9768
    Test loss: 0.0819479118524
    Test accuracy: 0.9768
    

    参考

    1. Keras中文官方文档:http://keras-cn.readthedocs.io/en/latest/getting_started/sequential_model/
    2. Keras github examples:https://github.com/keras-team/keras/blob/master/examples/mnist_mlp.py
    3. 神经网络(一):概念:https://blog.csdn.net/xierhacker/article/details/51771428
    4. 神经网络(二):感知机:https://blog.csdn.net/xierhacker/article/details/51816484
    5. 深度学习笔记二:多层感知机(MLP)与神经网络结构:https://blog.csdn.net/xierhacker/article/details/53282038
    6. 多层感知机:Multi-Layer Perceptron:https://blog.csdn.net/xholes/article/details/78461164
  • 相关阅读:
    第26月第26天 Domain=AVFoundationErrorDomain Code=-11850
    第26月第25天 ubuntu openjdk-8-jdk jretty
    第26月第23天 nsobject 单例 CFAbsoluteTimeGetCurrent
    第26月第22天 iOS瘦身之armv7 armv7s arm64选用 iOS crash
    第26月第20天 springboot
    第26月第18天 mybatis_spring_mvc pom
    python中的字符数字之间的转换函数
    python if else elif statement
    python 赋值魔法
    python print import使用
  • 原文地址:https://www.cnblogs.com/xing901022/p/9090969.html
Copyright © 2011-2022 走看看