zoukankan      html  css  js  c++  java
  • cifar10数据集训练

    下载数据集

    Cifar10数据集总共有6万张32*32像素点的彩色图片和标签,涵盖十个分类:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车。

    其中5万张用于训练,1万张用于测试。

    import tensorflow as tf
    from tensorflow import keras
    from matplotlib import pyplot as plt
    import numpy as np
    from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense,Dropout
    
    cifar10 = keras.datasets.cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    搭建网络结构

    model = keras.models.Sequential([
        Conv2D(128, (3, 3), activation='relu',padding='same'),
        keras.layers.BatchNormalization(),
        MaxPool2D((2, 2)),
        Dropout(0.3),
        Conv2D(256, (3, 3), activation='relu',padding='same'),
        keras.layers.BatchNormalization(),
        MaxPool2D((2, 2)),
        Dropout(0.3),
        Conv2D(512, (3, 3), activation='relu',padding='same'),
        keras.layers.BatchNormalization(),
        MaxPool2D((2, 2)),
        Flatten(),
        Dropout(0.5),
        Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(0.1)),
        Dropout(0.5),
        Dense(10, activation='softmax')
    ])

    编译模型

    model.compile(optimizer=keras.optimizers.Adam(lr=0.0001), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=['accuracy'])

    训练模型

    history = model.fit(x_train, y_train, epochs=100, batch_size=16,verbose=1,validation_data=(x_test, y_test),validation_freq=1)

    可视化acc/loss曲线

    #显示训练集和测试集的acc和loss曲线
    plt.rcParams['font.sans-serif']=['SimHei']
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    
    plt.subplot(1, 2, 1)
    plt.plot(acc, label='训练Acc')
    plt.plot(val_acc, label='测试Acc')
    plt.title('Acc曲线')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(loss, label='训练Loss')
    plt.plot(val_loss, label='测试Loss')
    plt.title('Loss曲线')
    plt.legend()
    plt.show()

  • 相关阅读:
    gbk学习笔记
    在freebsd下编译nodejs,出现无法找到execinfo.h头文件的错误
    php 截取GBK文档某个位置开始的n个字符
    linux下,phpstorm配置oracle jdk
    gb2312学习笔记
    freebsd下vim默认的vi操作方式太难用,可通过启用vim自带配置文件解决
    freebsd通过ssh远程登陆慢,用户认证时间长解决办法
    php输出全部gb2312编码内的汉字
    visibility:hidden 与 display:none 的区别
    java 实现文件/文件夹复制、删除、移动(二)
  • 原文地址:https://www.cnblogs.com/fengyumeng/p/13991729.html
Copyright © 2011-2022 走看看