zoukankan      html  css  js  c++  java
  • tensorflow入门02——搭建逻辑回归模型

    搭建逻辑回归模型对mnist数据集中手写字体进行分类

    1、加载数据

    minst=input_data.read_data_sets('data/',one_hot=True)
    trainimg=minst.train.images
    trainlabel=minst.train.labels
    testimg=minst.test.images
    testlabel=minst.test.labels
    print('minst loaded')
    

    2、设置变量

    x=tf.placeholder("float",[None,784])
    y=tf.placeholder("float",[None,10])
    W=tf.Variable(tf.zeros([784,10]))
    b=tf.Variable(tf.zeros([10]))

    3、回归模型

    这里回归模型用softmax做十分类任务

    #logistic regression model
    actv=tf.nn.softmax(tf.matmul(x,W)+b)
    #cost
    cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
    #optimizer
    learning_rate=0.01
    optm=tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)
    #prediction
    pred=tf.equal(tf.argmax(actv,1), tf.argmax(y,1))
    #accuracy
    accr=tf.reduce_mean(tf.cast(pred,"float"))

    4、开始训练

    #initializer
    init=tf.global_variables_initializer()
    sess=tf.InteractiveSession()
    
    training_epochs=50
    batch_size=100
    display_step=5
    
    sess=tf.Session()
    sess.run(init)
    
    for epoch in range(training_epochs):
        avg_cost=0.
        num_batch=int(minst.train.num_examples/batch_size)
        for i in range(num_batch):
            batch_xs,batch_ys=minst.train.next_batch(batch_size)
            feeds={x:batch_xs,y:batch_ys}
            sess.run(optm,feed_dict=feeds)
            avg_cost+=sess.run(cost,feeds)/num_batch
            #display
            if epoch%display_step==0:
                feeds_train={x:batch_xs,y:batch_ys}
                feeds_test={x:minst.test.images,y:minst.test.labels}
                train_acc=sess.run(accr,feed_dict=feeds_train)
                test_acc=sess.run(accr,feed_dict=feeds_test)
                print("Epoch:%03d/%03d cost:%.9f train_accr:%.3f test_accr:%.3f"%(epoch,training_epochs,avg_cost,train_acc,test_acc))
    print("Done")

    5、训练结果

  • 相关阅读:
    LeetCode24-Swap_Pairs
    LeeCode
    LeetCode3-Longest_Substring_Without_Repeating_Characters
    治愈 JavaScript 疲态的学习计划【转载】
    前端冷知识集锦[转载]
    知道这20个正则表达式,能让你少写1,000行代码[转载]
    关于简历和面试【整理自知乎】
    正念冥想方法
    一些职场经验【转载自知乎】
    犹太复国计划向世界展现了一个不一样的民族——观《犹太复国血泪史》有感
  • 原文地址:https://www.cnblogs.com/XiaoGao128/p/14276060.html
Copyright © 2011-2022 走看看