zoukankan      html  css  js  c++  java
  • keras 极简搭建VGG16 手写数字识别

    使用VGG16网络 完成迁移学习案例

    from keras.applications.vgg16 import VGG16
    from keras.models import Sequential
    from keras.layers import Conv2D, MaxPooling2D, Activation, Dropout, Flatten, Dense
    from keras.optimizers import SGD
    from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
    import numpy as np
    from keras.utils import  np_utils
    import cv2
    import pickle
    import matplotlib.pyplot as plt
    from keras.datasets import mnist
    
    # 得到适合网络的数据
    (X_train_data, Y_train), (X_test_data, Y_test) = mnist.load_data()  # 下载数据
    X_train_data = X_train_data.astype('float32')  # uint8-->float32
    X_test_data = X_test_data.astype('float32')
    X_train_data /= 255  # 归一化到0~1区间
    X_test_data /= 255
    # (60000, 48, 48, 3)
    X_train = []
    # (10000, 48, 48, 3)
    X_test = []
    # 把(27, 27, 1)维的数据转化成(48, 48, 3)维的数据
    for i in range(X_train_data.shape[0]):
        X_train.append(cv2.cvtColor(cv2.resize(X_train_data[i], (48, 48)), cv2.COLOR_GRAY2RGB))
    for i in range(X_test_data.shape[0]):
        X_test.append(cv2.cvtColor(cv2.resize(X_test_data[i], (48, 48)), cv2.COLOR_GRAY2RGB))
    
    X_train = np.array(X_train)
    X_test = np.array(X_test)
    # 独热编码
    y_train = np_utils.to_categorical(Y_train, num_classes=10)
    y_test = np_utils.to_categorical(Y_test, num_classes=10)
    
    # 构建网络
    vgg16_model = VGG16(weights='imagenet', include_top=False, input_shape=(48, 48, 3))
    for layer in vgg16_model.layers:
        layer.trainable = False # 别去调整之前的卷积层的参数
    
    top_model = Sequential()
    top_model.add(Flatten(input_shape=vgg16_model.output_shape[1:]))
    top_model.add(Dense(512, activation='relu'))
    top_model.add(Dropout(0.4))
    top_model.add(Dense(10, activation='softmax'))
    
    model = Sequential()
    model.add(vgg16_model)
    model.add(top_model)
    sgd = SGD(learning_rate=0.05, decay=1e-5)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
    model.fit(X_train, y_train, batch_size=128, epochs=15)
    model.evaluate(X_test, y_test)

    我跑了30轮数据,测试集上准确率在0.9833左右

  • 相关阅读:
    再谈spark部署搭建和企业级项目接轨的入门经验(博主推荐)
    CSS基础3——使用CSS格式化元素内容的字体
    利用MySQL 的GROUP_CONCAT函数实现聚合乘法
    POJ Octal Fractions(JAVA水过)
    组件接口(API)设计指南-文件夹
    Nginx 因 Selinux 服务导致无法远程訪问
    host字段变复杂了
    hdu 1251 统计难题 初识map
    “那个人样子好怪。”“我也看到了,他好像一条狗。”
    pomelo 协议
  • 原文地址:https://www.cnblogs.com/abc23/p/12300712.html
Copyright © 2011-2022 走看看