zoukankan      html  css  js  c++  java
  • 用Keras搭建神经网络 简单模版(四)—— RNN Classifier 循环神经网络(手写数字图片识别)...

    # -*- coding: utf-8 -*-
    import numpy as np
    np.random.seed(1337)
    
    from keras.datasets import mnist
    from keras.utils import np_utils
    from keras.models import Sequential
    from keras.layers import SimpleRNN,Activation,Dense
    from keras.optimizers import Adam
    
    TIME_STEPS = 28 #图片的高
    INPUT_SIZE = 28 #图片的行
    BATCH_SIZE = 50 #每批训练多少图片
    BATCH_INDEX = 0 
    OUTPUT_SIZE = 10
    CELL_SIZE = 50
    LR = 0.001
    
    #下载mnist数据集
    # X shape (60000,28*28) ,y shape (10000)
    (X_train,y_train),(X_test,y_test) = mnist.load_data()
    
    # 数据预处理
    X_train = X_train.reshape(-1,28,28)/255
    X_test = X_test.reshape(-1,28,28)/255
    y_train = np_utils.to_categorical(y_train,num_classes=10)
    y_test = np_utils.to_categorical(y_test,num_classes=10)
    
    
    # 建模型
    model = Sequential()
    # RNN
    model.add(SimpleRNN(
            batch_input_shape=(None,TIME_STEPS,INPUT_SIZE),# 每次训练的量(None表示全部),图片大小
            output_dim=CELL_SIZE,
            ))
    # 输出层
    model.add(Dense(OUTPUT_SIZE))
    model.add(Activation('softmax'))
    
    # 优化器
    adam = Adam(LR)
    model.compile(optimizer=adam,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    
    # 训练
    for step in range(4001):
        X_batch=X_train[BATCH_INDEX:BATCH_SIZE+BATCH_INDEX,:,:]
        Y_batch=y_train[BATCH_INDEX:BATCH_SIZE+BATCH_INDEX,:]
        cost = model.train_on_batch(X_batch,Y_batch)
        
        BATCH_INDEX += BATCH_SIZE
        BATCH_INDEX = 0 if BATCH_INDEX>=X_train.shape[0] else BATCH_INDEX
        
        if step % 500 == 0:
            cost,accuracy = model.evaluate(X_test,y_test,batch_size=y_test.shape[0],verbose=False)
            print('test cost: ',cost,'test accuracy: ',accuracy)

  • 相关阅读:
    AUC ROC PR曲线
    L1,L2范数和正则化 到lasso ridge regression
    目标函数和损失函数
    logistic回归和线性回归
    [转]如何处理不均衡数据?
    将Maven项目打包成可执行 jar文件(引用第三方jar)
    Postgresql VACUUM COPY等
    linux安装xgboost
    java社区推荐
    rabbitmq-java api
  • 原文地址:https://www.cnblogs.com/caiyishuai/p/13270688.html
Copyright © 2011-2022 走看看