zoukankan      html  css  js  c++  java
  • keras 极简搭建VGG16 手写数字识别

    使用VGG16网络 完成迁移学习案例

    from keras.applications.vgg16 import VGG16
    from keras.models import Sequential
    from keras.layers import Conv2D, MaxPooling2D, Activation, Dropout, Flatten, Dense
    from keras.optimizers import SGD
    from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
    import numpy as np
    from keras.utils import  np_utils
    import cv2
    import pickle
    import matplotlib.pyplot as plt
    from keras.datasets import mnist
    
    # 得到适合网络的数据
    (X_train_data, Y_train), (X_test_data, Y_test) = mnist.load_data()  # 下载数据
    X_train_data = X_train_data.astype('float32')  # uint8-->float32
    X_test_data = X_test_data.astype('float32')
    X_train_data /= 255  # 归一化到0~1区间
    X_test_data /= 255
    # (60000, 48, 48, 3)
    X_train = []
    # (10000, 48, 48, 3)
    X_test = []
    # 把(27, 27, 1)维的数据转化成(48, 48, 3)维的数据
    for i in range(X_train_data.shape[0]):
        X_train.append(cv2.cvtColor(cv2.resize(X_train_data[i], (48, 48)), cv2.COLOR_GRAY2RGB))
    for i in range(X_test_data.shape[0]):
        X_test.append(cv2.cvtColor(cv2.resize(X_test_data[i], (48, 48)), cv2.COLOR_GRAY2RGB))
    
    X_train = np.array(X_train)
    X_test = np.array(X_test)
    # 独热编码
    y_train = np_utils.to_categorical(Y_train, num_classes=10)
    y_test = np_utils.to_categorical(Y_test, num_classes=10)
    
    # 构建网络
    vgg16_model = VGG16(weights='imagenet', include_top=False, input_shape=(48, 48, 3))
    for layer in vgg16_model.layers:
        layer.trainable = False # 别去调整之前的卷积层的参数
    
    top_model = Sequential()
    top_model.add(Flatten(input_shape=vgg16_model.output_shape[1:]))
    top_model.add(Dense(512, activation='relu'))
    top_model.add(Dropout(0.4))
    top_model.add(Dense(10, activation='softmax'))
    
    model = Sequential()
    model.add(vgg16_model)
    model.add(top_model)
    sgd = SGD(learning_rate=0.05, decay=1e-5)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
    model.fit(X_train, y_train, batch_size=128, epochs=15)
    model.evaluate(X_test, y_test)

    我跑了30轮数据,测试集上准确率在0.9833左右

  • 相关阅读:
    设计模式概述
    Android之.9.png图片的制作与使用
    2015-4-3~2015-5-28 第四届全国大学生软件设计大赛《解密陌生人》项目总结
    排序算法之快速排序
    AsyncTask那些事(更新中...)
    经典Android面试题
    import第三方库的头文件找不到的错误
    点击某个按钮在tableView某个位置动态插入一行cell
    NSUserDefaults:熟悉与陌生(转)
    更改UIsearchbar 的背景和cancel按钮(转)
  • 原文地址:https://www.cnblogs.com/abc23/p/12300712.html
Copyright © 2011-2022 走看看