zoukankan      html  css  js  c++  java
  • tf.keras训练iris数据集

    import tensorflow as tf
    import os
    from sklearn import datasets
    import numpy as np
    
    # 加载数据集
    """
    其中测试集的输入特征 x_test 和标签 y_test 可以像 x_train 和 y_train 一样直接从数据集获取, 也可以如上述在 fit 中按比例从训练集中划分,
    本例选择从训练集中划分,所以只需加载 x_train, y_train 即可
    """
    x_train = datasets.load_iris().data
    y_train = datasets.load_iris().target
    
    # 数据集乱序
    np.random.seed(116)
    np.random.shuffle(x_train)
    np.random.seed(116)
    np.random.shuffle(y_train)
    tf.random.set_seed(116)
    
    # 逐层搭建网络结构
    """
    使用了单层全连接网络,第一个参数表示神经元个数,第二个参数表示网络所使用的激活函数,第三个参数表示选用的正则化方法。
    """
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
    ])
    
    # 配置训练方法
    """
    本 例 使 用 SGD 优 化 器 , 并 将 学 习 率 设 置 为 0.1 , 选 择SparseCategoricalCrossentrop 作为损失函数。 
    由于神经网络输出使用了softmax 激活函数,使得输出是概率分布,而不是原始输出, 
    所以需要将from_logits 参数设置为 False。 鸢尾花数据集给的标签是 0, 1, 2 这样的数值,而 网络前向传播的输出为概率分布 ,
    所 以 metrics 需 要 设 置 为sparse_categorical_accuracy。
    """
    model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
                  loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['sparse_categorical_accuracy']
                 )
    
    # 模型训练
    """
    在 fit 中执行训练过程,x_train,y_train 分别表示网络的输入特征和标签,batch_size表示一次喂入神经网络的数据量, epochs表示数据集的迭代次数,
    validation_split 表示数据集中测试集的划分比例, validation_freq 表示每迭代 20 次在测试集上测试一次准确率。
    """
    model.fit(x_train, y_train, epochs=500, batch_size=32,validation_freq=20, validation_split=0.2)
    
    # 打印网络结构,统计参数数目
    model.summary()
    

      输出结果:

     

     可以看到对于一个输入为4输出为3的全连接网络,参数总共是15个

  • 相关阅读:
    sqlserver中判断表或临时表是否存在
    Delphi 简单方法搜索定位TreeView项
    hdu 2010 水仙花数
    hdu 1061 Rightmost Digit
    hdu 2041 超级楼梯
    hdu 2012 素数判定
    hdu 1425 sort
    hdu 1071 The area
    hdu 1005 Number Sequence
    hdu 1021 Fibonacci Again
  • 原文地址:https://www.cnblogs.com/GumpYan/p/13564218.html
Copyright © 2011-2022 走看看