zoukankan      html  css  js  c++  java
  • 5.keras-Dropout剪枝操作的应用

    keras-Dropout剪枝操作的应用

    1.载入数据以及预处理

    import numpy as np
    from keras.datasets import mnist
    from keras.utils import np_utils
    from keras.models import Sequential
    from keras.layers import *
    from keras.optimizers import SGD
    
    import os
    
    import tensorflow as tf
    
    # 载入数据
    (x_train,y_train),(x_test,y_test) = mnist.load_data()
    
    # 预处理
    # 将(60000,28,28)转化为(600000,784),好输入展开层
    x_train = x_train.reshape(x_train.shape[0],-1)/255.0
    x_test= x_test.reshape(x_test.shape[0],-1)/255.0
    # 将输出转化为one_hot编码
    y_train = np_utils.to_categorical(y_train,num_classes=10)
    y_test = np_utils.to_categorical(y_test,num_classes=10)

    2.创建网络打印训练结果

    # 创建网络
    model = Sequential([
      Dense(units=128,input_dim=784,bias_initializer='one',activation='tanh'),
      # Dropout进行减枝,使得部分训练参数失效,避免过拟和
      Dropout(0.4),
      Dense(units=128,bias_initializer='one',activation='tanh'),
      Dropout(0.4),
      Dense(units=10,bias_initializer='one',activation='softmax')
    ]) 
    # 编译
    # 自定义优化器
    sgd = SGD(lr=0.1) model.compile(optimizer=sgd,
            
            # 运用交叉熵 loss='categorical_crossentropy', metrics=['accuracy']) model.fit(x_train,y_train,batch_size=32,epochs=10,validation_split=0.2) # 评估模型 loss,acc = model.evaluate(x_test,y_test,) print(' test loss',loss) print('test acc',acc)

    out:

    Epoch 1/10

    32/48000 [..............................] - ETA: 5:04 - loss: 2.7763 - acc: 0.1250
    576/48000 [..............................] - ETA: 21s - loss: 2.6202 - acc: 0.1354

    ......

    ......

    Epoch 10/10

    47744/48000 [============================>.] - ETA: 0s - loss: 0.1830 - acc: 0.9448
    48000/48000 [==============================] - 3s 72us/step - loss: 0.1831 - acc: 0.9449 - val_loss: 0.1210 - val_acc: 0.9649

    32/10000 [..............................] - ETA: 0s
    1824/10000 [====>.........................] - ETA: 0s
    3616/10000 [=========>....................] - ETA: 0s
    5472/10000 [===============>..............] - ETA: 0s
    7456/10000 [=====================>........] - ETA: 0s
    9440/10000 [===========================>..] - ETA: 0s
    10000/10000 [==============================] - 0s 27us/step

    test loss 0.11740412595644593
    test acc 0.9652

  • 相关阅读:
    个人作业-Alpha项目测试
    第三次作业-结对编程
    第二次作业
    第一次阅读作业
    canal同步mysql数据至es5.5.0
    工作一周年小结
    Java集合操作 遍历list并转map
    网易秋招校招编程题
    堆外内存总结
    网易秋招内推编程题题解
  • 原文地址:https://www.cnblogs.com/wigginess/p/13062814.html
Copyright © 2011-2022 走看看