zoukankan      html  css  js  c++  java
  • 大三寒假学习进度(4)

    tensorflow学习

    • 鸢尾花分类

    步骤

    1 · 准备数据,包括数据集读入、数据集乱序,把训练集和测试集中的数据配成输入特征和标签对,生成 train 和 test 即永不相见的训练集和测试集;
    2 · 搭建网络,定义神经网络中的所有可训练参数;
    3 · 优化这些可训练的参数,利用嵌套循环在 with 结构中求得损失函数 loss对每个可训练参数的偏导数,更改这些可训练参数,为了查看效果,程序中可以加入每遍历一次数据集显示当前准确率,还 可以画出准确率 acc 和损失函数 loss的变化曲线图。

    代码实现

    from sklearn import datasets
    import tensorflow as tf
    import numpy as np
    from matplotlib import pyplot as plt
    #读入数据并进行数据分割处理
    x_data =datasets.load_iris().data
    y_data =datasets.load_iris().target
    
    np.random.seed(116)
    np.random.shuffle(x_data)
    np.random.seed(116)
    np.random.shuffle(y_data)
    tf.random.set_seed(116)
    
    x_train = x_data[:-30]
    y_train = y_data[:-30]
    x_test = x_data[-30:]
    y_test = y_data[-30:]
    
    x_train = tf.cast(x_train,tf.float32)
    x_test = tf.cast(x_test,tf.float32)
    
    train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
    test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
    # 设置参数
    w1 = tf.Variable(tf.random.truncated_normal([4,3],stddev =0.1,seed=1))
    b1 = tf.Variable(tf.random.truncated_normal([3],stddev=0.1,seed=1))
    
    lr =0.1#学习率
    train_loss_results=[]
    test_acc=[]
    epoch = 500#循环次数
    loss_all = 0
    # 训练部分
    for epoch in range(epoch):
        for step,(x_train,y_train) in enumerate(train_db):
            with tf.GradientTape() as tape:
                y = tf.matmul(x_train, w1) + b1
                y = tf.nn.softmax(y)
                y_ = tf.one_hot(y_train, depth=3)
                loss = tf.reduce_mean(tf.square(y_ - y))
                loss_all += loss.numpy()
    
        grads = tape.gradient(loss,[w1,b1])
        # 更新参数
        w1.assign_sub(lr * grads[0])
        b1.assign_sub(lr * grads[1])
    
        print("Epoch {}, loss: {}".format(epoch, loss_all/4))
        train_loss_results.append(loss_all/4)
        loss_all=0
    
        total_correct, total_number = 0,0
        for x_test , y_test in test_db:
            y = tf.matmul(x_test,w1)+b1
            y =tf.nn.softmax(y)
            pred = tf.argmax(y,axis=1)
            pred = tf.cast(pred,dtype=y_test.dtype)
            correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)
            correct = tf.reduce_sum(correct) # 将每个 batch 的 correct 数加起来
            total_correct += int(correct) # 将所有 batch 中的 correct 数加起来
            total_number += x_test.shape[0]
        acc = total_correct / total_number
        test_acc.append(acc)
        print("test_acc:", acc)
        print("--------------------------------")
    # 画图
    plt.title('Loss Function Curve')  # 图片标题
    plt.xlabel('Epoch')  # x 轴名称
    plt.ylabel('Loss')  # y 轴名称
    plt.plot(train_loss_results, label="$Loss$")  #
    plt.legend()
    plt.show()
    
    
    plt.title('Acc Curve') # 图片标题
    plt.xlabel('Epoch') # x 轴名称
    plt.ylabel('Acc') # y 轴名称
    plt.plot(test_acc, label="$Accuracy$") # 逐点画出 test_acc 值并连线
    plt.legend()
    plt.show()
    
    
    
    

    结果


  • 相关阅读:
    对象的访问定位——如何找到对象
    对象的结构
    对象在内存中的布局-对象的创建
    java的内存模型--jmm
    redis 持久化之rdb总结
    简单说springmvc的工作原理
    抽象类和接口的区别
    hashcode和equals的作用区别及联系
    DBC物品中打包物品参数设置
    关于GOM引擎启动时显示:windows socket error: 在其上下文中,该请求的地址无效。 (10049), on API 'bind'
  • 原文地址:https://www.cnblogs.com/--lzx1--/p/14304385.html
Copyright © 2011-2022 走看看