zoukankan      html  css  js  c++  java
  • TensorFlow逻辑回归

    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
    #from tensorflow.examples.tutorials.mnist import input_data
    import input_data
    #导入实验所需的数据
    mnist = input_data.read_data_sets("D:大二Java大三寒假作业大三寒假作业深度学习算法部分",one_hot = True)
    #设置训练参数
    learning_rate=0.01
    training_epochs=25
    batch_size=100
    display_step=1
    
    #构造计算图,使用占位符placeholder函数构造变量x,y,
    x=tf.placeholder(tf.float32,[None,784])
    y=tf.placeholder(tf.float32,[None,10])
    #使用Variable函数,设置模型的初始权重
    W=tf.Variable(tf.zeros([784,10]))
    b=tf.Variable(tf.zeros([10]))
    #构造逻辑回归模型
    pred=tf.nn.softmax(tf.matmul(x,W)+b)
    #构造代价函数cost
    cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
    #使用梯度下降法求最小值,即最优解
    optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
    #初始化全部变量
    init=tf.global_variables_initializer()
    #.使用tf.Session()创建Session会话对象,会话封装了Tensorflow运行时的状态和控制
    with tf.Session() as sess:
        sess.run(init)
        #调用会话对象sess的run方法,运行计算图,即开始训练模型
        for epoch in range(training_epochs):
            avg_cost = 0
            total_batch = int(mnist.train.num_examples / batch_size)
            for i in range(total_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})
                avg_cost += c  / total_batch
            if (epoch+1) % display_step == 0:
                print("Epoch:", '%04d' % (epoch + 1), "Cost:","{:.09f}".format(avg_cost))
        print("Optimization Finished!")
        #测试模型
        correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
        #评估模型的准确度
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print("Accuracy:", accuracy.eval({x: mnist.test.images[:3000], y: mnist.test.labels[:3000]}))
    from tensorflow.examples.tutorials.mnist import input_data不能使用,由于我的D:anaconda3Libsite-packages	ensorflowexamples下缺少tutorials不能下载
    examples所有可以用import input_data代替。
    如果所给的文件出现问题是,需要重新下载四个文件

     

     

     
  • 相关阅读:
    命令拷屏之网络工具
    PHP 设计模式 笔记与总结(1)命名空间 与 类的自动载入
    Java实现 计蒜客 1251 仙岛求药
    Java实现 计蒜客 1251 仙岛求药
    Java实现 计蒜客 1251 仙岛求药
    Java实现 蓝桥杯 算法训练 字符串合并
    Java实现 蓝桥杯 算法训练 字符串合并
    Java实现 蓝桥杯 算法训练 字符串合并
    Java实现 LeetCode 143 重排链表
    Java实现 LeetCode 143 重排链表
  • 原文地址:https://www.cnblogs.com/1234yyf/p/14276301.html
Copyright © 2011-2022 走看看