zoukankan      html  css  js  c++  java
  • 学习进度笔记16

    今天通过观看老师分享的TensorFlow教学视频,完成了逻辑回归模型的创建,内容比较复杂代码也比较陌生,理解起来非常费力。

    具体代码如下:

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    from tensorflow.examples.tutorials.mnist import input_data

    mnist = input_data.read_data_sets('data/',one_hot=True)
    #print
    trainimg = mnist.train.images
    trainlabel = mnist.train.labels
    testimg = mnist.test.images
    testlabel = mnist.test.labels
    print("trainlabel:",type(trainlabel),"shape:",trainlabel.shape)
    print("trainimg:",type(trainimg),"shape:",trainimg.shape)
    print("testlabel:",type(testlabel),"shape:",testlabel.shape)
    print("testimg:",type(testimg),"shape:",testimg.shape)
    print(trainlabel[0])

    x = tf.placeholder("float",[None,784])
    y = tf.placeholder("float",[None,10])
    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    #计算模型
    actv = tf.nn.softmax(tf.matmul(x,W) + b)
    #损失
    cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
    #学习参数
    learning_rate = 0.01
    #梯度下降优化器
    optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
    #对比预测值的索引和真实值的索引
    pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1))
    #把true或false转换成1或0求均值作为精度
    accr = tf.reduce_mean(tf.cast(pred,"float"))

    init = tf.global_variables_initializer()
    training_epochs = 50
    batch_size = 100
    display_step = 5
    sess = tf.Session()
    sess.run(init)
    for epoch in range(training_epochs):
    avg_cost = 0
    num_batch = int(mnist.train.num_examples/batch_size)
    for i in range(num_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    sess.run(optm,feed_dict={x:batch_xs,y:batch_ys})
    feeds = {x:batch_xs,y:batch_ys}
    avg_cost += sess.run(cost,feed_dict=feeds)/num_batch
    if epoch % display_step ==0:
    feed_train = {x:batch_xs,y:batch_ys}
    feed_test = {x:mnist.test.images,y:mnist.test.labels}
    train_acc = sess.run(accr,feed_dict=feed_train)
    test_acc = sess.run(accr,feed_dict=feed_test)
    print("Epoch: %03d/%03d cost: %.9f train_acc: %.3f test_acc: %.3f"
    % (epoch,training_epochs,avg_cost,train_acc,test_acc))
    print("Done")
    运行截图:

  • 相关阅读:
    CKEditor4x word导入不保存格式的解决方案
    为了希望正式开始开发
    HTTP权威指南-URL与资源
    HTTP权威指南-HTTP概述
    同源策略和跨域访问
    普通Html文件图片上传(Storing Image To DB)
    PostgreSQL时间戳提取的特殊需求
    (转)百度前端规范、腾讯前端规范
    整理一下嵌入式WEB开发中的各种屏蔽(转)
    Excel表格指定列数据转换成文本
  • 原文地址:https://www.cnblogs.com/lijiawei1-2-3/p/14307191.html
Copyright © 2011-2022 走看看