zoukankan      html  css  js  c++  java
  • Tensorflow2.0笔记05——程序实现鸢尾花数据集分类

    Tensorflow2.0笔记

    本博客为Tensorflow2.0学习笔记,感谢北京大学微电子学院曹建老师

    4 程序实现鸢尾花数据集分类

    4.1 数据集回顾

    ​ 先回顾鸢尾花数据集,其提供了 150 组鸢尾花数据,每组包括鸢尾花的花萼长、花萼宽、花瓣长、花瓣宽 4 个输入特征,同时还给出了这一组特征对应的鸢尾花类别。类别包括狗尾鸢尾、杂色鸢尾、弗吉尼亚鸢尾三类, 分别用数字0、1、2 表示。使用此数据集代码如下:

    from sklearn.datasets import load_iris
    x_data = datasets.load_iris().data    # 返回 iris 数据集所有输入特征
    y_data = datasets.load_iris().target   # 返回 iris 数据集所有标签
    

    即从 sklearn 包中导出数据集,将输入特征赋值给 x_data 变量,将对应标签赋值给 y_data 变量。

    4.2 程序实现

    ​ 我们用神经网络实现鸢尾花分类仅需要三步:

    ​ (1) 准备数据,包括数据集读入、数据集乱序,把训练集和测试集中的数据配成输入特征和标签对,生成 train 和 test 即永不相见的训练集和测试集;

    ​ (2) 搭建网络,定义神经网络中的所有可训练参数;

    ​ (3) 优化这些可训练的参数,利用嵌套循环在 with 结构中求得损失函数 loss 对每个可训练参数的偏导数,更改这些可训练参数,为了查看效果,程序中可以加入每遍历一次数据集显示当前准确率,还可以画出准确率 acc 和损失函数 loss 的变化曲线图。

    4.3 程序

    ​ 以上部分的完整代码与解析如下:

    (1) 数据集读入

    from sklearn.datasets import datasets
    x_data = datasets.load_iris().data    # 返回 iris 数据集所有输入特征
    y_data = datasets.load_iris().target   # 返回 iris 数据集所有标签
    

    (2) 数据集乱序

    np.random.seed(116)    # 使用相同的 seed,使输入特征/标签一一对应
    np.random.shuffle(x_data)
    np.random.seed(116) 
    np.random.shuffle(y_data) tf.random.set_seed(116)
    

    (3) 数据集分割成永不相见的训练集和测试集:

    x_train = x_data[:-30] 
    y_train = y_data[:-30] 
    x_test = x_data[-30:] 
    y_test = y_data[-30:]
    

    (4) 配成[输入特征,标签]对,每次喂入一小撮(batch):

    train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32) 
    test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
    

    ​ 上述四小部分代码实现了数据集读入、数据集乱序、将数据集分割成永不相见的训练集和测试集、将数据配成[输入特征,标签]对。人类在认识这个世界的时候信息是没有规律的,杂乱无章的涌入大脑的,所以喂入神经网络的数据集也需要被打乱顺序。(2)部分实现了让数据集乱序,因为使用了同样的随机种子,所以打乱顺序后输入特征和标签仍然是一一对应的。(3)部分将打乱后的前 120 个数据取出来作为训练集,后 30 个数据作为测试集,为了公正评判神经网络的效果, 训练集和测试集没有交集。(4)部分使用 from_tensor_slices 把训练集的输入特征和标签配对打包,将每 32 组输入特征标签对打包为一个 batch,在喂入神经网络时会以 batch 为单位喂入。

    (5) 定义神经网路中所有可训练参数:

    w1 = tf.Variable(tf.random.truncated_normal([ 4,  3 ], stddev=0.1, seed=1)) 
    pyb1 = tf.Variable(tf.random.truncated_normal([ 3 ], stddev=0.1, seed=1))
    

    (6) 嵌套循环迭代,with 结构更新参数,显示当前 loss:

    for epoch in range(epoch):  #数据集级别迭代
    	for step, (x_train, y_train) in enumerate(train_db):  #batch 级别迭代
    		with tf.GradientTape() as tape:  # 记录梯度信息
    			(前向传播过程计算 y)
                (计算总 loss)
    		grads = tape.gradient(loss, [ w1, b1 ]) 
    		w1.assign_sub(lr * grads[0])  #参数自更新
    		b1.assign_sub(lr * grads[1])
    print("Epoch {}, loss: {}".format(epoch, loss_all/4))
    

    ​ 上述两部分完成了定义神经网路中所有可训练参数、嵌套循环迭代更新参数。(5) 部分定义了神经网络的所有可训练参数。只用了一层网络,因为输入特征是 4个,输出节点数等于分类数,是 3 分类,故参数w1 为 4 行 3 列的张量,b1 必须与w1 的维度一致,所以是 3。(6)部分用两层 for 循环进行更新参数:第一层 for 循环是针对整个数据集进行循环,故用 epoch 表示;第二层 for 循环是针对 batch 的,用 step 表示。在 with 结构中计算前向传播的预测结果 y ,计算损失函数 loss失函数 loss,分别对参数 w1 和参数b1 计算偏导数,更新参数 w1 和参数b1 的值,打印出这一轮 epoch 后的损失函数值。因为训练集有 120 组数据,batch 是 32, 每个 step 只能喂入 32 组数据,需要 batch 级别循环 4 次,所以 loss 除以 4,求得每次 step 迭代的平均 loss。

    ​ (7) 计算当前参数前向传播后的准确率,显示当前准确率 acc:

    for x_test, y_test in test_db:
    	y = tf.matmul(h, w) + b # y 为预测结果
    	y = tf.nn.softmax(y)       # y 符合概率分布
    	pred = tf.argmax(y, axis=1) # 返回 y 中最大值的索引即预测的分类
    	pred = tf.cast(pred, dtype=y_test.dtype) # 调整数据类型与标签一致
    	correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)
    	correct = tf.reduce_sum (correct) # 将每个 batch 的 correct 数加起来
    	total_correct += int (correct) # 将所有 batch 中的 correct 数加起来
    		total_number += x_test.shape [0] 
    	acc = total_correct / total_number 
    	print("test_acc:", acc)
    

    ​ (8) acc / loss 可视化:

    plt.title('Acc Curve')   # 图片标题
    plt.xlabel('Epoch') # x 轴名称
    plt.ylabel('Acc')    # y 轴名称
    plt.plot(test_acc, label="$Accuracy$")   # 逐点画出 
    test_acc 值并连线
    plt.legend() 
    plt.show()
    

    上述两部分完成了对准确率的计算并可视化准确率与 loss。(7)部分前向传播计算出 y ,使其符合概率分布并找到最大的概率值对应的索引号,调整数据类型与标 签一致,如果预测值和标签相等则 correct 变量自加一,准确率即预测对了的数量除以测试集中的数据总数。
    (9)部分可将计算出的准确率画成曲线图,通过设置图标题、设置 x 轴名称、设置 y 轴名称,标出每个 epoch 时的准确率并画出曲线,可用同样方法画出 loss 曲线。结果图如图 4.1 与 4.2。

    image-20210622002909481

    4.1 训练过程 loss 曲线

    image-20210622003014045

    图4.2 训练过程准确率曲线

  • 相关阅读:
    jQueryfocus,title,振动
    使用jQuery自动缩图片 (转载)
    jQuery10个小例子(jquery之旅).
    jQuery动态增加删除Tabs
    jQuery图片播放轮换
    jQuery插件上传控件美化
    Ajax简单
    jQuery仿QQ改版后的样式切换
    jQuery插件tooltip(超链接提示,图片提示).
    css分页样式
  • 原文地址:https://www.cnblogs.com/wind-and-sky/p/14916626.html
Copyright © 2011-2022 走看看