  • Backbone: LeNet

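    import tensorflow as tf
    from tensorflow import keras
    
    # Load MNIST: 60,000 training and 10,000 test images (28x28 grayscale digits)
    (train_x,train_y),(test_x,test_y) = tf.keras.datasets.mnist.load_data()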

    Split a validation set off the training set

    valid_x = train_x[:5000]
    valid_y = train_y[:5000]
    train_x = train_x[5000:]
    train_y = train_y[5000:]
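    The slicing above holds out the first 5,000 of MNIST's 60,000 training images for validation and keeps the remaining 55,000 for training; the 10,000 test images stay untouched. A minimal sketch of the same split on a synthetic array (a stand-in for the real data, so the shapes can be checked without downloading MNIST; `split_validation` is a hypothetical helper name):

```python
import numpy as np

def split_validation(x, y, n_valid=5000):
    """Hold out the first n_valid samples as a validation set (same slicing as above)."""
    return (x[n_valid:], y[n_valid:]), (x[:n_valid], y[:n_valid])

# Synthetic stand-in with MNIST's training-set shape: 60000 images of 28x28 pixels.
x = np.zeros((60000, 28, 28), dtype=np.uint8)
y = np.arange(60000) % 10

(train_x, train_y), (valid_x, valid_y) = split_validation(x, y)
print(train_x.shape, valid_x.shape)  # (55000, 28, 28) (5000, 28, 28)
```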

    Display the first 10 training images

    import matplotlib.pyplot as plt
    plt.figure(figsize=(12,5))
    # Show the first 10 training images in a 2x5 grid, each labelled with its digit
    for i in range(10):
        plt.subplot(2,5,i+1)
        plt.imshow(train_x[i])
        plt.xlabel(train_y[i])
        plt.xticks([])
        plt.yticks([])
    plt.savefig('lenet-mnist')
    plt.show()

    Normalize the data (standardize to zero mean and unit variance)

    from sklearn.preprocessing import StandardScaler
    import numpy as np
    # Fit one global pixel mean/std on the training set only,
    # then apply the same statistics to the validation and test sets
    scaler = StandardScaler()
    train_x_scaled = scaler.fit_transform(train_x.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
    valid_x_scaled = scaler.transform(valid_x.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
    test_x_scaled = scaler.transform(test_x.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
    # Add a trailing channel axis, (N, 28, 28) -> (N, 28, 28, 1), as Conv2D expects
    train_x = tf.reshape(train_x_scaled,(-1,28,28,1))
    valid_x = tf.reshape(valid_x_scaled,(-1,28,28,1))
    test_x = tf.reshape(test_x_scaled,(-1,28,28,1))
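    Because every pixel is reshaped into a single column, StandardScaler learns just one mean and one standard deviation for the whole image set and applies z = (x - mean) / std. A minimal NumPy sketch of that computation on random stand-in data (the helper names `fit_standardizer` and `standardize` are hypothetical); the statistics are fitted on the training images only and then reused for the other split:

```python
import numpy as np

def fit_standardizer(train):
    """One global mean/std over all pixels, like StandardScaler on reshape(-1, 1)."""
    flat = train.astype(np.float32).ravel()
    return flat.mean(), flat.std()  # population std (ddof=0), as StandardScaler uses

def standardize(x, mean, std):
    """Apply z = (x - mean) / std and add a trailing channel axis for Conv2D."""
    z = (x.astype(np.float32) - mean) / std
    return z.reshape(-1, x.shape[1], x.shape[2], 1)

# Random stand-in data with MNIST-like pixel values in [0, 255].
rng = np.random.default_rng(0)
train = rng.integers(0, 256, size=(100, 28, 28))
test = rng.integers(0, 256, size=(20, 28, 28))

mean, std = fit_standardizer(train)
train_z = standardize(train, mean, std)
test_z = standardize(test, mean, std)  # reuses the training statistics, no refit
```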

    Build the LeNet network

    # LeNet-style CNN: two conv/pool stages followed by three fully connected layers
    model = keras.models.Sequential()
    model.add(keras.layers.Conv2D(6,(5,5),activation='relu',input_shape = (28,28,1)))
    model.add(keras.layers.MaxPool2D(2,2))
    model.add(keras.layers.Conv2D(16,(5,5),activation='relu'))
    model.add(keras.layers.MaxPool2D(2,2))
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(120,activation='relu'))
    model.add(keras.layers.Dense(84,activation='relu'))
    # softmax (not sigmoid) so the 10 outputs form a probability distribution,
    # which is what sparse_categorical_crossentropy expects
    model.add(keras.layers.Dense(10,activation='softmax'))
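    The layer sizes can be checked by hand: with the default 'valid' (unpadded) convolution, each 5x5 kernel shrinks the side length by 4, and each 2x2 pooling with stride 2 halves it. A small pure-Python sketch of that arithmetic and of the parameter counts (weights plus one bias per output unit); the numbers should agree with `model.summary()` for the network above, assuming the layer settings shown:

```python
def conv_out(size, kernel, stride=1):
    """Spatial size after a 'valid' (unpadded) convolution."""
    return (size - kernel) // stride + 1

def pool_out(size, pool=2):
    """Spatial size after non-overlapping max pooling (stride equal to pool size)."""
    return size // pool

def conv_params(kernel, in_ch, out_ch):
    return (kernel * kernel * in_ch + 1) * out_ch  # weights + one bias per filter

def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases

s = 28
s = pool_out(conv_out(s, 5))  # Conv2D(6, 5x5): 28 -> 24, MaxPool: 24 -> 12
s = pool_out(conv_out(s, 5))  # Conv2D(16, 5x5): 12 -> 8, MaxPool: 8 -> 4
flat = s * s * 16             # Flatten: 4 * 4 * 16 = 256

params = [
    conv_params(5, 1, 6),     # 156
    conv_params(5, 6, 16),    # 2416
    dense_params(flat, 120),  # 30840
    dense_params(120, 84),    # 10164
    dense_params(84, 10),     # 850
]
print(flat, sum(params))  # 256 44426
```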

    Compile the network

    import os
    lenet_logdir = 'lenet_logdir'
    os.makedirs(lenet_logdir, exist_ok=True)
    out_model = os.path.join(lenet_logdir,'lenet_mnist.h5')
    # TensorBoard logs go to lenet_logdir; the best model (lowest val_loss) is saved to out_model
    callbacks = [keras.callbacks.TensorBoard(lenet_logdir),
                 keras.callbacks.ModelCheckpoint(out_model,save_best_only=True)]
    model.compile(optimizer = 'sgd',
                  loss = 'sparse_categorical_crossentropy',
                  metrics = ['accuracy'])
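    The loss `sparse_categorical_crossentropy` takes integer class labels (as MNIST provides) and, by default, a probability distribution per sample, i.e. a softmax output. For one sample the loss is -log p[label]; Keras averages it over the batch. A pure-Python sketch of that per-sample loss, with hypothetical scores for three classes, also showing that sigmoid outputs are independent per class and need not sum to 1:

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(z):
    """Squash one score into (0, 1) independently of the other classes."""
    return 1.0 / (1.0 + math.exp(-z))

def sparse_categorical_crossentropy(label, probs):
    """Per-sample loss: label is an integer class index, probs a distribution."""
    return -math.log(probs[label])

logits = [2.0, 1.0, 0.1]  # hypothetical scores for 3 classes
probs = softmax(logits)
loss = sparse_categorical_crossentropy(0, probs)
sig = [sigmoid(z) for z in logits]

print(sum(probs), loss)  # probabilities sum to 1; loss is -log(probs[0])
print(sum(sig))          # sigmoid outputs do not sum to 1 in general
```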

    Train the network

    history = model.fit(train_x,train_y,
                        epochs = 30,
                        validation_data = (valid_x,valid_y),
                        callbacks=callbacks)
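    With `save_best_only=True`, the ModelCheckpoint callback writes the file only in epochs where the monitored quantity improves (its default monitor is `val_loss`, minimized), so after 30 epochs the checkpoint holds the best-validation model rather than the last one. A simplified pure-Python model of that rule; `checkpoint_epochs` is a hypothetical helper and the loss values are made up for illustration:

```python
def checkpoint_epochs(val_losses):
    """Return the 0-based epochs at which a save_best_only checkpoint would be written.

    Simplified model of ModelCheckpoint(save_best_only=True) with monitor='val_loss':
    save only when val_loss is lower than the best value seen so far.
    """
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            saved.append(epoch)
    return saved

# Hypothetical validation losses, for illustration only.
print(checkpoint_epochs([0.30, 0.12, 0.09, 0.10, 0.07, 0.08]))  # [0, 1, 2, 4]
```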

    Visualize the training history

    import pandas as pd
    # Plot training/validation loss and accuracy over all 30 epochs
    pd.DataFrame(history.history).plot(figsize = (10,4))
    plt.grid(True)
    plt.gca().set_ylim(0,1)
    plt.show()

    Test the model

    model.evaluate(test_x,test_y)
    10000/10000 [==============================] - 1s 56us/sample - loss: 0.0352 - accuracy: 0.9899
    [0.03522909182251751, 0.9899]
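    A test accuracy of 0.9899 means 9,899 of the 10,000 test images were classified correctly. With integer labels, the 'accuracy' metric compares the index of the largest predicted probability with the label, so the same figure can be recomputed from `model.predict(test_x)`. A NumPy sketch of that computation on made-up outputs (`accuracy_from_probs` is a hypothetical helper):

```python
import numpy as np

def accuracy_from_probs(probs, labels):
    """Fraction of samples whose highest-probability class equals the integer label."""
    preds = np.argmax(probs, axis=1)
    return float(np.mean(preds == labels))

# Hypothetical outputs for 4 samples over 3 classes (the model above has 10).
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.5, 0.4, 0.1]])
labels = np.array([0, 1, 2, 1])
print(accuracy_from_probs(probs, labels))  # 0.75
```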
  • Original article: https://www.cnblogs.com/peiziming/p/13232441.html