zoukankan      html  css  js  c++  java
  • 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类

    1.载入数据以及预处理

    import numpy as np
    from keras.datasets import mnist
    from keras.utils import np_utils
    from keras.models import Sequential
    from keras.layers import *
    from keras.optimizers import SGD
    
    import os
    
    import tensorflow as tf
    
    # 载入数据
    (x_train,y_train),(x_test,y_test) = mnist.load_data()
    
    # 预处理
    # 将(60000,28,28)转化为(600000,784),好输入展开层
    x_train = x_train.reshape(x_train.shape[0],-1)/255.0
    x_test= x_test.reshape(x_test.shape[0],-1)/255.0
    # 将输出转化为one_hot编码
    y_train = np_utils.to_categorical(y_train,num_classes=10)
    y_test = np_utils.to_categorical(y_test,num_classes=10)

    2.创建网络打印训练结果

    # 创建网络
    model = Sequential([
        # 输入784输出10个
        Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
    ])
    # 编译
    # 自定义优化器
    sgd = SGD(lr=0.1)
    model.compile(optimizer=sgd,
                  loss='mse',
                  # 得到训练过程中的准确率
                  metrics=['accuracy'])
    
    model.fit(x_train,y_train,batch_size=32,epochs=10,validation_split=0.2)
    
    # 评估模型
    loss,acc = model.evaluate(x_test,y_test,)
    print('
    test loss',loss)
    print('test acc',acc)

    out:

    Epoch 1/10

    32/48000 [..............................] - ETA: 2:27 - loss: 0.0905 - acc: 0.1875
    1248/48000 [..............................] - ETA: 5s - loss: 0.0907 - acc: 0.1346

    ......

    ......

    Epoch 10/10

    45952/48000 [===========================>..] - ETA: 0s - loss: 0.0164 - acc: 0.9005
    47616/48000 [============================>.] - ETA: 0s - loss: 0.0163 - acc: 0.9008
    48000/48000 [==============================] - 2s 37us/step - loss: 0.0163 - acc: 0.9010 - val_loss: 0.0149 - val_acc: 0.9084

    32/10000 [..............................] - ETA: 4s
    3360/10000 [=========>....................] - ETA: 0s
    5824/10000 [================>.............] - ETA: 0s
    8512/10000 [========================>.....] - ETA: 0s
    10000/10000 [==============================] - 0s 20us/step

    test loss 0.015059704356454312
    test acc 0.908

  • 相关阅读:
    将execel表格的数据导入到mysql数据库
    清明听雨
    h5调用底层接口的一些知识
    微信小程序从零开始开发步骤(一)搭建开发环境
    Matplotlib
    Numpy
    pandas
    6 MapReduce的理解
    静态链表
    单链表
  • 原文地址:https://www.cnblogs.com/wigginess/p/13062739.html
Copyright © 2011-2022 走看看