  • keras mnist

    Based on TensorFlow 2.

    The simplest possible Keras demo, kept here as a backup so I can copy it directly whenever I need it.

    Model structure
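    The layer stack in the script can be traced by hand. The sketch below is plain Python with no TensorFlow. It assumes that 'same' padding keeps the 3x3 convolutions' spatial size unchanged and that each pooling layer is non-overlapping (stride equal to pool size, the Keras default). Under those assumptions it prints each layer's output shape and parameter count. The helper names `conv_params` and `trace_shapes` are hypothetical.

```python
# Trace output shapes and parameter counts for the conv/pool/dense stack
# used in the MNIST script. Assumptions: 3x3 kernels with 'same' padding
# (height/width unchanged) and pooling with stride == pool size.

def conv_params(k, c_in, c_out):
    """k*k*c_in*c_out weights plus one bias per output channel."""
    return k * k * c_in * c_out + c_out

def trace_shapes(size=28, channels=1):
    rows = []
    h, c = size, channels
    for filters, pool in [(32, 2), (64, 2), (128, 7)]:
        params = conv_params(3, c, filters)
        c = filters                      # 'same' padding: h stays the same
        rows.append((f"Conv2D {filters}", (h, h, c), params))
        h = h // pool                    # pooling shrinks height and width
        rows.append((f"AveragePooling2D {pool}x{pool}", (h, h, c), 0))
    rows.append(("Flatten", (h * h * c,), 0))
    rows.append(("Dense 10 softmax", (10,), h * h * c * 10 + 10))
    return rows

for name, shape, params in trace_shapes():
    print(f"{name:<24} {str(shape):<16} {params}")
print("total params:", sum(p for _, _, p in trace_shapes()))
```

    The last pooling layer uses a 7x7 window on a 7x7 map, so it reduces each channel to a single average, much like a global average pool. These per-layer counts should line up with what `model.summary()` reports for the model in the code below.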

    Code

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras import layers, models, optimizers, datasets
    # Public API; replaces the private tensorflow.python.keras.utils.np_utils
    from tensorflow.keras.utils import to_categorical
    
    
    # Load MNIST: 28x28 grayscale digit images with integer labels 0-9
    (x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
    print(x_train.shape, x_test.shape)  # (60000, 28, 28) (10000, 28, 28)
    print(y_train.shape, y_test.shape)  # (60000,) (10000,)
    
    
    # Add a channel axis (Conv2D expects NHWC input) and scale pixels to [0, 1]
    x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32') / 255
    x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype('float32') / 255
    # One-hot encode the labels to match the categorical_crossentropy loss
    y_train = to_categorical(y_train, num_classes=10)
    y_test = to_categorical(y_test, num_classes=10)
    print(x_train.shape, x_test.shape)  # (60000, 28, 28, 1) (10000, 28, 28, 1)
    print(y_train.shape, y_test.shape)  # (60000, 10) (10000, 10)
    
    
    # Three conv + average-pooling stages: 28x28 -> 14x14 -> 7x7 -> 1x1
    inp_img = layers.Input(shape=(28, 28, 1))
    layer1_conv = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(inp_img)
    layer1_pool = layers.AveragePooling2D((2, 2))(layer1_conv)    # 14x14x32
    layer2_conv = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(layer1_pool)
    layer2_pool = layers.AveragePooling2D((2, 2))(layer2_conv)    # 7x7x64
    layer3_conv = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(layer2_pool)
    layer3_pool = layers.AveragePooling2D((7, 7))(layer3_conv)    # 1x1x128 (global average)
    layer4_fc = layers.Flatten()(layer3_pool)                     # 128 features
    pred = layers.Dense(10, activation='softmax')(layer4_fc)      # class probabilities
    model = models.Model(inp_img, pred)
    # `lr` is deprecated and newer Keras versions no longer accept it; use `learning_rate`
    adam = optimizers.Adam(learning_rate=0.01)
    model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    model.summary()
    
    
    model.fit(x_train, y_train, epochs=10, batch_size=32)
    model.save('mnist_cnn.h5')  # the .h5 extension selects the HDF5 format
    loss, acc = model.evaluate(x_test, y_test)
    print('\ntest loss:', loss)
    print('\ntest accuracy:', acc)
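    The label handling in the script relies on `to_categorical` and the `'categorical_crossentropy'` loss. The minimal NumPy sketch below shows what these two compute on a toy batch. The helpers `one_hot` and `categorical_crossentropy` are hypothetical names, and the clipping constant `eps` is an assumption. Keras' own implementation may differ in details such as renormalizing the predicted probabilities.

```python
import numpy as np

def one_hot(labels, num_classes=None):
    """One-hot encode integer labels (the idea behind to_categorical)."""
    labels = np.asarray(labels, dtype=int)
    if num_classes is None:
        num_classes = int(labels.max()) + 1
    out = np.zeros((labels.size, num_classes), dtype='float32')
    out[np.arange(labels.size), labels] = 1.0   # set the true-class column
    return out

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over samples of -sum(y_true * log(y_pred))."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)    # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

y = one_hot([3, 0], num_classes=4)
p = np.array([[0.1, 0.1, 0.1, 0.7],     # softmax-style outputs, rows sum to 1
              [0.5, 0.2, 0.2, 0.1]])
print(y)
print(categorical_crossentropy(y, p))   # (-ln 0.7 - ln 0.5) / 2, about 0.525
```

    Because each label row is one-hot, only the log-probability of the true class contributes to the loss. This is why the script converts `y_train` and `y_test` before calling `fit`.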
    无情的摸鱼机器
  • Original article: https://www.cnblogs.com/wangtianning1223/p/14436368.html