zoukankan      html  css  js  c++  java
  • 第一个神经网络

    1.keras训练神经网络的一般步骤

    • 导入数据,做数据处理,使数据符合模型要求
    • 定义网络结构
    • 定义损失函数、优化器、监控指标
    • 训练模型
    • 图形化

    2.使用MNIST数据集的一个例子

    from keras.datasets import mnist
    from keras import models
    from keras import layers
    from keras.utils import to_categorical
    import matplotlib.pyplot as plt
    #导入数据
    (train_x,train_y),(test_x,test_y) = mnist.load_data()
    #每个神经元表示一个像素,所以把3维数据转换成2维数据
    train_x = train_x.reshape(60000,28*28)
    train_x = train_x.astype('float32') /255
    test_x = test_x.reshape(10000,28*28)
    test_x = test_x.astype('float32') /255
    #把数据中类别用向量表示,以便使用交叉熵,6---->[0,0,0,0,0,0,1,0,0,0]
    train_y = to_categorical(train_y)
    test_y = to_categorical(test_y)
    #定义网络结构
    network = models.Sequential()
    #Dense表示全连接 network.add(layers.Dense(
    512,activation="relu",input_shape=(28*28,))) network.add(layers.Dense(10,activation="softmax")) #定义损失函数、优化器、监控指标 network.compile(optimizer="rmsprop",loss="categorical_crossentropy",metrics=['accuracy']) #训练模型 history = network.fit(train_x,train_y,batch_size=128,epochs=11,validation_data=(test_x,test_y)) history_dict = history.history loss = history_dict['loss'] val_loss = history_dict['val_loss'] acc = history_dict['acc'] val_acc = history_dict['val_acc'] epochs = range(1,11) #loss的图 plt.subplot(121) plt.plot(epochs,loss,'g',label = 'Training loss') plt.plot(epochs,val_loss,'b',label = 'Validation loss') plt.xlabel('Epochs') plt.ylabel('Loss') #显示图例 plt.legend() plt.subplot(122) plt.plot(epochs,acc,'g',label = 'Training accuracy') plt.plot(epochs,val_acc,'b',label = 'Validation accuracy') plt.xlabel('Epochs') plt.ylabel('accuracy') plt.legend() plt.show()

    结果图如下

    可以发现随着轮数的提升,训练集的损失在不断减少,验证集的损失在轮数为5左右达到了最小;同理随着轮数提升,训练集的精度在不断提升,验证集的精度在轮数为5左右达到最大;提示我们有可能发生了过拟合,可以把轮数改为5

  • 相关阅读:
    一些Vim使用的小技巧
    virtualbox centos安装增强工具和Centos与VirtualBox共享文件夹设置
    (转) centos7 RPM包之rpm命令
    (转)Navicat_12安装与破解,亲测可用!!!
    (转)2019年 React 新手学习指南 – 从 React 学习线路图说开去
    (转)react 项目构建
    (转)python3:类方法,静态方法和实例方法以及应用场景
    (转)SQLAlchemy入门和进阶
    (转)面向对象(深入)|python描述器详解
    (转)CentOS 7.6 上编译安装httpd 2.4.38
  • 原文地址:https://www.cnblogs.com/vshen999/p/10444358.html
Copyright © 2011-2022 走看看