zoukankan      html  css  js  c++  java
  • TensorFlow 2.0——手写数字数据集预测（多元逻辑回归）

    import tensorflow as tf
    import numpy as np
    import matplotlib.pylab as plt
    
    # Switch to a Chinese-capable font so the Chinese legend labels render, and turn
    # off the Unicode minus sign: with the Chinese font set, the "-" on the axes
    # would otherwise not display.
    plt.rcParams.update({
        "font.family": 'SimHei',
        'axes.unicode_minus': False,
    })
    
    # Load the MNIST handwritten-digit data
    mnist = tf.keras.datasets.mnist
    (train_x, train_y), (test_x, test_y) = mnist.load_data()

    #   One-hot lookup table: row d is the one-hot code for digit d (0-9), float64
    y_hot = np.eye(10)
    #   Convert every integer training label to its one-hot row with a single
    #   vectorised indexing step instead of a Python loop over all samples
    train_Y = y_hot[train_y]
    print('train_Y:', train_Y, train_Y.shape)
    
    #   Flatten each 28*28 image into a 784-long row, then append a constant-1 column
    #   so the last row of the weight matrix acts as the bias term.
    #   Reshaping the whole array at once flattens in the same row-major order as
    #   reshaping each image individually, without a Python loop over every sample.
    #   The pixels are converted to float64, the same dtype as the ones column.
    #   Training set
    train_X1 = train_x.reshape(train_x.shape[0], -1).astype(np.float64)
    ones = np.ones((train_x.shape[0], 1))
    print('ones.shape:', ones.shape)
    print('train_X1.shape:', train_X1.shape)
    train_X = tf.concat([train_X1, ones], axis=1)
    #   Test set
    test_X1 = test_x.reshape(test_x.shape[0], -1).astype(np.float64)
    ones = np.ones((test_x.shape[0], 1))
    test_X = tf.concat([test_X1, ones], axis=1)
    #   Reshape the integer label arrays into (n, 1) column vectors
    train_y = np.reshape(train_y, (-1, 1))
    test_y = np.reshape(test_y, (-1, 1))
    #   Accuracy history: one entry per evaluation step
    acc_train = []
    acc_test = []
    #   Hyper-parameters
    n_iters = 1500              #   number of gradient-descent iterations
    learn_rate = 5e-12          #   learning rate (presumably this small because the pixels are
                                #   unscaled 0-255 values and the loss is summed over all samples)
    eval_every = 20             #   log loss and accuracy every this many iterations
    #   Hoisted conversions: one-hot targets as a tensor, and integer labels as int64
    #   tensors so they compare directly with the int64 output of tf.argmax
    train_targets = tf.convert_to_tensor(train_Y)
    train_labels = tf.cast(train_y, tf.int64)
    test_labels = tf.cast(test_y, tf.int64)


    def _accuracy(logits, labels):
        """Return the fraction of rows whose highest-scoring class equals the label.

        argmax is taken on the logits directly: the sigmoid is monotonic, so this
        picks the same class as argmax over the probabilities. The probabilities are
        deliberately not rounded first, because every row whose scores are all below
        0.5 would round to all zeros, and argmax of all zeros is always class 0.

        Args:
            logits: (n_samples, 10) class scores.
            labels: (n_samples, 1) int64 true digits.

        Returns:
            The accuracy as a Python float in [0, 1].
        """
        predicted = tf.reshape(tf.argmax(logits, axis=1), (-1, 1))
        return float(tf.reduce_mean(tf.cast(tf.equal(predicted, labels), tf.float32)))


    #   Initialise the weights: 785 rows (784 pixel weights + 1 bias) x 10 classes
    w = tf.Variable(np.random.randn(785, 10)*0.0001)
    print('初试w:',w,w.shape)
    for i in range(n_iters):
        with tf.GradientTape() as tape:
            logits = tf.matmul(train_X, w)
            #   Same per-element loss as -(y*log(p) + (1-y)*log(1-p)) with p = sigmoid(logits),
            #   but evaluated from the logits so a saturated p never produces log(0).
            loss = tf.reduce_sum(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=train_targets, logits=logits))
        dl_dw = tape.gradient(loss, w)
        if i % eval_every == 0:
            #   Evaluate before the update so the logged loss and accuracies refer to the
            #   same weights. Test logits are computed only here and outside the tape, so
            #   they are not recorded for differentiation on every iteration.
            acc = _accuracy(logits, train_labels)
            acc2 = _accuracy(tf.matmul(test_X, w), test_labels)
            print('i:{}, loss:{}, w:{}'.format(i, loss, w))
            #   Training-set accuracy
            acc_train.append(acc)
            print('acc:', acc)
            #   Test-set accuracy
            acc_test.append(acc2)
            print('acc2:', acc2)
            print()
        w.assign_sub(learn_rate * dl_dw)
    
    #   Plot the training- and test-set accuracy curves recorded during training
    for history, caption in ((acc_train, '训练集正确率'), (acc_test, '测试集正确率')):
        plt.plot(history, label=caption)
    plt.legend()
    plt.show()

  • 相关阅读:
    centos6.5 系统乱码解决 i18n --摘自http://blog.csdn.net/yangkai_hudong/article/details/19033393
    openssl pem转cer
    nginx 重装添加http_ssl_module模块
    ios 利用airprint实现无线打印(配合普通打印机)
    centos nginx server_name 配置域名访问规则
    MySQL Innodb数据库性能实践——热点数据性能
    jQuery中的DOM操作
    C++函数学习笔记
    jQuery选择器容易忽视的小知识大问题
    写给自己的话
  • 原文地址:https://www.cnblogs.com/cxhzy/p/13462119.html
Copyright © 2011-2022 走看看