  • TensorFlow Getting Started 02: Building a Logistic Regression Model

    Build a logistic regression model that classifies the handwritten digits in the MNIST dataset (TensorFlow 1.x API).

    1. Load the data

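    # TensorFlow 1.x API; the input_data helper is not available in TensorFlow 2
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data

    # downloads MNIST into data/ if needed; one_hot=True encodes each label as a length-10 vector
    minst=input_data.read_data_sets('data/',one_hot=True)
    trainimg=minst.train.images     # (55000, 784): flattened 28x28 images
    trainlabel=minst.train.labels   # (55000, 10): one-hot labels
    testimg=minst.test.images       # (10000, 784)
    testlabel=minst.test.labels     # (10000, 10)
    print('mnist loaded')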
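    With one_hot=True every label becomes a length-10 vector with a single 1 at the digit's index, and every image is a flattened 28x28 = 784-value vector. A minimal pure-Python sketch of that encoding and its inverse (the helpers one_hot and decode are hypothetical, not part of the MNIST loader):

```python
def one_hot(label, num_classes=10):
    """Length-num_classes vector with 1.0 at index `label` and 0.0 elsewhere."""
    if not 0 <= label < num_classes:
        raise ValueError("label must be in [0, num_classes)")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

def decode(vec):
    """Inverse of one_hot: index of the largest entry (what tf.argmax(y, 1) does per row)."""
    return max(range(len(vec)), key=vec.__getitem__)

IMAGE_SIZE = 28 * 28   # each MNIST image is flattened to 784 pixel values

print(one_hot(3))
print(decode(one_hot(7)), IMAGE_SIZE)
```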
    

    2. Define the inputs and parameters

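    x=tf.placeholder("float",[None,784])   # a batch of flattened 28x28 images
    y=tf.placeholder("float",[None,10])    # the matching one-hot labels
    W=tf.Variable(tf.zeros([784,10]))      # weights: 784 pixels -> 10 class scores
    b=tf.Variable(tf.zeros([10]))          # one bias per class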
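    With these shapes, tf.matmul(x,W)+b maps a batch of N images ([N, 784]) to N rows of 10 class scores (logits), and b is broadcast across rows. Starting W and b at zero is fine here because softmax regression has a convex cost and no hidden units whose symmetry would need breaking. The pure-Python sketch below checks the same shape arithmetic with made-up toy sizes (3 features instead of 784, 2 classes instead of 10); the linear helper is hypothetical:

```python
def linear(x_batch, W, b):
    """Compute x_batch @ W + b with plain lists.

    x_batch: N rows of D features; W: D rows of K columns; b: K biases.
    Returns N rows of K logits, mirroring tf.matmul(x, W) + b.
    """
    D, K = len(W), len(b)
    out = []
    for row in x_batch:
        if len(row) != D:
            raise ValueError("feature count must match the number of rows in W")
        out.append([sum(row[d] * W[d][k] for d in range(D)) + b[k] for k in range(K)])
    return out

# Toy sizes: N=2 examples, D=3 features, K=2 classes.
x_batch = [[1.0, 0.0, 2.0],
           [0.0, 1.0, 1.0]]
W = [[0.5, -0.5],
     [1.0, 0.0],
     [0.0, 1.0]]
b = [0.1, -0.1]
logits = linear(x_batch, W, b)
print(len(logits), len(logits[0]), logits)
```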

    3. Regression model

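    The model is softmax (multinomial logistic) regression: a single linear layer xW + b followed by softmax, which turns the 10 scores into a probability for each digit class.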
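    Written out, the quantities that actv, cost and optm below compute are as follows. Notation assumed here: z is one row of logits, X (N x 784) stacks a batch's images, Y (N x 10) their one-hot labels, Ŷ the softmax outputs, and η the learning rate. The gradients are the standard softmax cross-entropy result, written by hand here; TensorFlow derives them automatically.

```latex
\begin{aligned}
z &= xW + b, \qquad \hat{y}_k = \operatorname{softmax}(z)_k = \frac{e^{z_k}}{\sum_{j=1}^{10} e^{z_j}}, \quad k = 1,\dots,10 \\
J(W,b) &= -\frac{1}{N}\sum_{n=1}^{N}\sum_{k=1}^{10} y_{n,k}\,\log \hat{y}_{n,k} \\
\frac{\partial J}{\partial W} &= \frac{1}{N} X^{\top}(\hat{Y} - Y), \qquad
\frac{\partial J}{\partial b} = \frac{1}{N}\sum_{n=1}^{N} (\hat{y}_n - y_n) \\
W &\leftarrow W - \eta\,\frac{\partial J}{\partial W}, \qquad b \leftarrow b - \eta\,\frac{\partial J}{\partial b}, \qquad \eta = 0.01
\end{aligned}
```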

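    # logistic (softmax) regression model: class probabilities for each image
    actv=tf.nn.softmax(tf.matmul(x,W)+b)
    # cost: mean cross-entropy between the one-hot labels and the predicted probabilities
    # note: tf.log(actv) is -inf if a probability underflows to 0, giving an inf/NaN cost;
    # tf.nn.softmax_cross_entropy_with_logits on the raw logits avoids that
    cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),axis=1))
    # optimizer: plain gradient descent on the cost
    learning_rate=0.01
    optm=tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)
    # prediction: does the most probable class match the true class?
    pred=tf.equal(tf.argmax(actv,1), tf.argmax(y,1))
    # accuracy: fraction of correct predictions
    accr=tf.reduce_mean(tf.cast(pred,"float"))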
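    The same cost and accuracy can be sketched in plain Python. This version computes log(softmax) as z - logsumexp(z), so it never takes log(0) even for extreme logits; that is the stability trick fused cross-entropy ops rely on, shown here as an illustration rather than TensorFlow's actual implementation. Function names and toy values are hypothetical:

```python
import math

def softmax(logits):
    """Softmax with the max logit subtracted first, so exp() cannot overflow."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(logits):
    """log(softmax(z)) computed as z - logsumexp(z), which never evaluates log(0)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

def cross_entropy(batch_logits, batch_labels):
    """Batch mean of -sum_k y_k * log(p_k): the quantity `cost` computes."""
    total = 0.0
    for z, y in zip(batch_logits, batch_labels):
        total += -sum(yk * lp for yk, lp in zip(y, log_softmax(z)))
    return total / len(batch_logits)

def argmax(v):
    return max(range(len(v)), key=v.__getitem__)

def accuracy(batch_logits, batch_labels):
    """Fraction of rows whose top-scoring class matches the one-hot label, like `accr`."""
    hits = sum(argmax(z) == argmax(y) for z, y in zip(batch_logits, batch_labels))
    return hits / len(batch_logits)

logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.0]]
labels = [[1, 0, 0], [1, 0, 0]]
print(softmax(logits[0]))
print(cross_entropy(logits, labels))
print(accuracy(logits, labels))   # first row correct, second wrong -> 0.5
print(softmax([1000.0, 0.0]))     # no overflow despite the huge logit
```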

    4. Train the model

    # initializer for W and b
    init=tf.global_variables_initializer()

    training_epochs=50
    batch_size=100
    display_step=5   # report accuracy every 5 epochs

    sess=tf.Session()
    sess.run(init)

    for epoch in range(training_epochs):
        avg_cost=0.
        num_batch=minst.train.num_examples//batch_size
        for i in range(num_batch):
            batch_xs,batch_ys=minst.train.next_batch(batch_size)
            feeds={x:batch_xs,y:batch_ys}
            # one gradient step, then add this batch's share to the epoch's average cost
            sess.run(optm,feed_dict=feeds)
            avg_cost+=sess.run(cost,feed_dict=feeds)/num_batch
        # display: once per display_step epochs, after the whole epoch has run
        if epoch%display_step==0:
            feeds_train={x:batch_xs,y:batch_ys}   # training accuracy on the last mini-batch
            feeds_test={x:minst.test.images,y:minst.test.labels}
            train_acc=sess.run(accr,feed_dict=feeds_train)
            test_acc=sess.run(accr,feed_dict=feeds_test)
            print("Epoch:%03d/%03d cost:%.9f train_accr:%.3f test_accr:%.3f"%(epoch,training_epochs,avg_cost,train_acc,test_acc))
    print("Done")
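    With the default read_data_sets split (55,000 training images), num_batch is 55000 // 100 = 550, so 50 epochs perform 27,500 gradient steps. The loop's mechanics can be sketched without TensorFlow as mini-batch gradient descent for softmax regression in plain Python: reshuffle each epoch, take one step per batch using the gradient p - y, and accumulate the average cost the way avg_cost does. The toy data, the class-index labels (instead of one-hot), and the hyperparameters lr=0.5 and batch_size=4 are made up for illustration:

```python
import math
import random

def forward(x, W, b):
    """Logits x @ W + b for one example (plain lists)."""
    return [sum(x[d] * W[d][k] for d in range(len(x))) + b[k] for k in range(len(b))]

def predict(x, W, b):
    z = forward(x, W, b)
    return max(range(len(z)), key=z.__getitem__)   # like tf.argmax(actv, 1)

def train(data, labels, num_classes, epochs=50, batch_size=4, lr=0.5, seed=0):
    """Mini-batch gradient descent for softmax regression; labels are class indices."""
    rng = random.Random(seed)
    D = len(data[0])
    W = [[0.0] * num_classes for _ in range(D)]   # zero init, like tf.zeros
    b = [0.0] * num_classes
    order = list(range(len(data)))
    num_batch = len(data) // batch_size
    history = []                                  # average cost per epoch
    for epoch in range(epochs):
        rng.shuffle(order)                        # reshuffle every epoch
        avg_cost = 0.0
        for i in range(num_batch):
            batch = order[i * batch_size:(i + 1) * batch_size]
            gW = [[0.0] * num_classes for _ in range(D)]
            gb = [0.0] * num_classes
            batch_cost = 0.0
            for n in batch:
                x, y = data[n], labels[n]
                z = forward(x, W, b)
                m = max(z)
                lse = m + math.log(sum(math.exp(v - m) for v in z))
                batch_cost -= z[y] - lse          # -log softmax(z)[y]
                for k in range(num_classes):
                    err = math.exp(z[k] - lse) - (1.0 if k == y else 0.0)   # p_k - y_k
                    gb[k] += err / len(batch)
                    for d in range(D):
                        gW[d][k] += x[d] * err / len(batch)
            # one gradient-descent step; the cost above is measured before the step
            # (the TF loop measures it just after)
            for k in range(num_classes):
                b[k] -= lr * gb[k]
                for d in range(D):
                    W[d][k] -= lr * gW[d][k]
            avg_cost += batch_cost / len(batch) / num_batch
        history.append(avg_cost)
    return W, b, history

# Toy, linearly separable data: 2 features, 2 classes.
data = [[-1.0, -1.0], [-1.2, -0.8], [-0.8, -1.1], [-0.9, -0.9],
        [1.0, 1.0], [1.1, 0.9], [0.8, 1.2], [0.9, 1.0]]
labels = [0, 0, 0, 0, 1, 1, 1, 1]
W, b, history = train(data, labels, num_classes=2)
print("first/last epoch cost:", history[0], history[-1])
print("train accuracy:", sum(predict(x, W, b) == y for x, y in zip(data, labels)) / len(data))
```

    Computing the gradient per example and averaging mirrors what reduce_mean in cost implies; TensorFlow does the same thing vectorized over the whole batch.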

    5. Training results

  • Original post: https://www.cnblogs.com/XiaoGao128/p/14276060.html