zoukankan      html  css  js  c++  java
  • 用keras实现基本的图像分类任务

    数据集介绍

    fashion mnist数据集是mnist的进阶版本,有10种对应的结果

    训练集有60000个,每一个都是28*28的图像,每一个对应一个标签(0-9)表示

    测试集有10000个

    代码
    import tensorflow as tf
    import keras
    import numpy as np
    import matplotlib.pyplot as plt
    
    #导入fashioin_mnist数据集
    fashion_mnist = keras.datasets.fashion_mnist
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
    
    #分别于0-9对应
    class_names = ['上衣','裤子','套衫','裙子','外套','凉鞋','衬衫','运动鞋','包包','踝靴']
    
    #压缩像素值到0-1之间
    train_images = train_images / 255.0
    test_images = test_images / 255.0
    
    #查看前几个数据的图像
    plt.figure(figsize=(10,10))
    for i in range(25):
        plt.subplot(5,5,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(train_images[i], cmap=plt.cm.binary)
        plt.xlabel(class_names[train_labels[i]])
        
    model = keras.Sequential([
        keras.layers.Flatten(input_shape=(28, 28)),   #输入图像大小为28*28
        keras.layers.Dense(128, activation=tf.nn.relu),  #用relu函数作为激活函数
        keras.layers.Dense(10, activation=tf.nn.softmax)   #softmax之后输出10个值,分别表示对应的概率
    ])
    
    model.compile(optimizer=tf.train.AdamOptimizer(),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    
    model.fit(train_images,train_labels,epochs= 10)  #运行完准确率有91.13%
    
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    
    print('Test accuracy:', test_acc)	#运行完在测试集上的准确率为88.58%
    #测试集的准确率小于训练集,说明过拟合
    

    参考

    https://www.tensorflow.org/tutorials/keras/basic_classification?hl=zh-cn

  • 相关阅读:
    10-10-12分页机制(xp)
    段间跳转之任务门
    段间跳转之TSS段
    mysql索引
    cat /proc/meminfo
    This system is not registered to Red Hat Subscription Management报错
    CentOS 6.5安装zabbix
    KVM(系统虚拟化模块)安装
    Linux时区更改
    学习ruby/rails, rvm是必不可少的工具之一
  • 原文地址:https://www.cnblogs.com/MartinLwx/p/10070618.html
Copyright © 2011-2022 走看看