zoukankan      html  css  js  c++  java
  • 用Keras搭建神经网络 简单模版(四)—— RNN Classifier 循环神经网络(手写数字图片识别)...

    # -*- coding: utf-8 -*-
    import numpy as np
    np.random.seed(1337)
    
    from keras.datasets import mnist
    from keras.utils import np_utils
    from keras.models import Sequential
    from keras.layers import SimpleRNN,Activation,Dense
    from keras.optimizers import Adam
    
    TIME_STEPS = 28 #图片的高
    INPUT_SIZE = 28 #图片的行
    BATCH_SIZE = 50 #每批训练多少图片
    BATCH_INDEX = 0 
    OUTPUT_SIZE = 10
    CELL_SIZE = 50
    LR = 0.001
    
    #下载mnist数据集
    # X shape (60000,28*28) ,y shape (10000)
    (X_train,y_train),(X_test,y_test) = mnist.load_data()
    
    # 数据预处理
    X_train = X_train.reshape(-1,28,28)/255
    X_test = X_test.reshape(-1,28,28)/255
    y_train = np_utils.to_categorical(y_train,num_classes=10)
    y_test = np_utils.to_categorical(y_test,num_classes=10)
    
    
    # 建模型
    model = Sequential()
    # RNN
    model.add(SimpleRNN(
            batch_input_shape=(None,TIME_STEPS,INPUT_SIZE),# 每次训练的量(None表示全部),图片大小
            output_dim=CELL_SIZE,
            ))
    # 输出层
    model.add(Dense(OUTPUT_SIZE))
    model.add(Activation('softmax'))
    
    # 优化器
    adam = Adam(LR)
    model.compile(optimizer=adam,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    
    # 训练
    for step in range(4001):
        X_batch=X_train[BATCH_INDEX:BATCH_SIZE+BATCH_INDEX,:,:]
        Y_batch=y_train[BATCH_INDEX:BATCH_SIZE+BATCH_INDEX,:]
        cost = model.train_on_batch(X_batch,Y_batch)
        
        BATCH_INDEX += BATCH_SIZE
        BATCH_INDEX = 0 if BATCH_INDEX>=X_train.shape[0] else BATCH_INDEX
        
        if step % 500 == 0:
            cost,accuracy = model.evaluate(X_test,y_test,batch_size=y_test.shape[0],verbose=False)
            print('test cost: ',cost,'test accuracy: ',accuracy)

  • 相关阅读:
    将Linux下python默认版本切换成替代版本
    ubuntu下卸载python2和升级python3.5
    Linux下安装theano
    梯度下降法
    使用Matlab实现对图片的缩放
    matlab 中的删除文件
    解决aws ec2的centos7设置时区无效
    yum安装redis5/mq/consul
    django web应用runserver模式下cpu占用高解决办法
    N1如何完美刷入armbian系统教程
  • 原文地址:https://www.cnblogs.com/caiyishuai/p/13270688.html
Copyright © 2011-2022 走看看