zoukankan      html  css  js  c++  java
  • tensorflow入门02——搭建逻辑回归模型

    搭建逻辑回归模型对mnist数据集中手写字体进行分类

    1、加载数据

    minst=input_data.read_data_sets('data/',one_hot=True)
    trainimg=minst.train.images
    trainlabel=minst.train.labels
    testimg=minst.test.images
    testlabel=minst.test.labels
    print('minst loaded')
    

    2、设置变量

    x=tf.placeholder("float",[None,784])
    y=tf.placeholder("float",[None,10])
    W=tf.Variable(tf.zeros([784,10]))
    b=tf.Variable(tf.zeros([10]))

    3、回归模型

    这里回归模型用softmax做十分类任务

    #logistic regression model
    actv=tf.nn.softmax(tf.matmul(x,W)+b)
    #cost
    cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
    #optimizer
    learning_rate=0.01
    optm=tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)
    #prediction
    pred=tf.equal(tf.argmax(actv,1), tf.argmax(y,1))
    #accuracy
    accr=tf.reduce_mean(tf.cast(pred,"float"))

    4、开始训练

    #initializer
    init=tf.global_variables_initializer()
    sess=tf.InteractiveSession()
    
    training_epochs=50
    batch_size=100
    display_step=5
    
    sess=tf.Session()
    sess.run(init)
    
    for epoch in range(training_epochs):
        avg_cost=0.
        num_batch=int(minst.train.num_examples/batch_size)
        for i in range(num_batch):
            batch_xs,batch_ys=minst.train.next_batch(batch_size)
            feeds={x:batch_xs,y:batch_ys}
            sess.run(optm,feed_dict=feeds)
            avg_cost+=sess.run(cost,feeds)/num_batch
            #display
            if epoch%display_step==0:
                feeds_train={x:batch_xs,y:batch_ys}
                feeds_test={x:minst.test.images,y:minst.test.labels}
                train_acc=sess.run(accr,feed_dict=feeds_train)
                test_acc=sess.run(accr,feed_dict=feeds_test)
                print("Epoch:%03d/%03d cost:%.9f train_accr:%.3f test_accr:%.3f"%(epoch,training_epochs,avg_cost,train_acc,test_acc))
    print("Done")

    5、训练结果

  • 相关阅读:
    根据出生日期来计算年龄
    tomcat 7 7.0.73 url 参数 大括号 {} 不支持 , 7.0.67支持
    hdu 1272(并查集)
    hdu 1558(计算几何+并查集)
    hdu 1856(hash+启发式并查集)
    hdu 1598(最小生成树)
    poj 3164(最小树形图模板)
    hdu 2489(状态压缩+最小生成树)
    hdu 3371(启发式合并的最小生成树)
    hdu 1301(最小生成树)
  • 原文地址:https://www.cnblogs.com/XiaoGao128/p/14276060.html
Copyright © 2011-2022 走看看