zoukankan      html  css  js  c++  java
  • 【tensorflow】mnist-精简版模型

    所有的ML模型或者DL 模型 都是下面这四个固定套路的步骤

    1.获取到所需数据

    2.开始搭建模型

    3.计算采用何种loss函数

    4.选择batch,epoch,feed数据 

    from tensorflow.examples.tutorials.mnist import input_data
    import tensorflow as tf
    
    mnist = input_data.read_data_sets('./tmp/tensorflow/mnist/input_data',one_hot=True) # 下载数据
    x = tf.placeholder(tf.float32,[None,784]) # 输入占位符
    yresult = tf.placeholder(tf.float32,[None,10]) #输入数据真实的label
    
    w = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(x,w) + b) # 用不用激励函数 都可以的其实
    cross_entropy = -tf.reduce_sum(yresult * tf.log(y)) # loss  值
    train_setp = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #梯度下降法
    init  = tf.initialize_all_variables()
    
    with tf.Session() as sess:
        sess.run(init)
        for i in range(1000):
            batch_xs, batch_ys = mnist.train.next_batch(100)
            argv1,loss = sess.run([train_setp,cross_entropy],feed_dict={x:batch_xs,yresult:batch_ys}) #如果想知道corss_entropy试试变化值 加入就好。
            if i % 200 == 0:
                print (loss)
    
        current_prediction = tf.equal(tf.argmax(y,1),tf.argmax(yresult,1)) # compare real and calculate
        accuracy = tf.reduce_mean(tf.cast(current_prediction,tf.float32))  # 数据类型转换 然后求匹配上的概率
        result = sess.run(accuracy,feed_dict={x:mnist.test.images,yresult:mnist.test.labels}) # test数据入口
        print(str(result * 100) + '%')
    
    关注公众号 海量干货等你
  • 相关阅读:
    eureka注册中心搭建
    MySQL基本查询语句
    通过IP地址和子网掩码与运算计算相关地址
    Linux bash 快捷键列表【转】
    Linux 系统简介【转】
    Linux Vim -d 比较两个文件
    无连接和面向连接协议的区别【转】
    ARP报文详解
    Linux 指定网卡 ping
    Linux Windows Java 快速生成指定大小的空文件
  • 原文地址:https://www.cnblogs.com/sowhat1412/p/12734364.html
Copyright © 2011-2022 走看看