  • Learning MLPs with TensorFlow

    This post extends the multilayer perceptron (MLP) code from Section 4.4 of the book 《Tensorflow实战》 from one hidden layer to two.

    But a problem came up: why was the recognition rate only 11.35%, when the one-hidden-layer version reached 98%?
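    The 11.35% figure is itself a clue. Assuming the standard MNIST test split (10,000 images, of which 1,135 are the digit 1; the per-digit counts are hard-coded below), 11.35% is exactly the score of a network that answers "1" for every image. The sketch below only checks that arithmetic. It suggests the two-hidden-layer model had collapsed to a constant output rather than merely learning slowly, although the post itself does not confirm the cause.

```python
# Assumed label counts for digits 0-9 in the standard MNIST test split (10,000 images).
MNIST_TEST_COUNTS = [980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009]


def constant_predictor_accuracy(counts, predicted_class):
    """Accuracy of a model that outputs the same class for every input."""
    return counts[predicted_class] / sum(counts)


# Print the accuracy of every possible constant guess.
for digit in range(10):
    acc = constant_predictor_accuracy(MNIST_TEST_COUNTS, digit)
    print(f"always predict {digit}: {acc:.2%}")
```

    Only the constant guess "1" lands on 11.35%; the other constant guesses score between 8.92% and 10.32%.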

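    After changing the learning rate, the number of training batches, and the number of hidden-layer nodes, training behaved normally again.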

    However, the recognition rate was not noticeably better than with one hidden layer.
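    One detail worth checking in the listing that follows: the training loop feeds keep_prob: 1.0, so both tf.nn.dropout layers pass their inputs through unchanged and the deeper network trains with no dropout regularization at all. The pure-Python sketch below applies the same keep-and-rescale rule that tf.nn.dropout(x, keep_prob) uses in TF 1.x; inverted_dropout is an illustrative name, not a TensorFlow function. It shows that keep_prob=1.0 is the identity, while a smaller value zeroes some units and rescales the rest. Whether a training value such as 0.75 would raise accuracy here is untested.

```python
import random


def inverted_dropout(values, keep_prob, rng):
    """Illustrative inverted dropout: keep each unit with probability keep_prob
    and scale the survivors by 1 / keep_prob (the rule tf.nn.dropout uses in TF 1.x)."""
    # rng.random() returns a value in [0.0, 1.0), so keep_prob=1.0 keeps every unit.
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]


hidden = [0.5, 1.0, 2.0, 3.0]
print(inverted_dropout(hidden, 1.0, random.Random(0)))   # same values as hidden
print(inverted_dropout(hidden, 0.75, random.Random(0)))  # each entry is 0.0 or v / 0.75
```

    Scaling the survivors by 1/keep_prob keeps the expected activation unchanged, which is why evaluating with keep_prob: 1.0 needs no extra correction.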

    The code is as follows:

    # -*- coding: utf-8 -*-
    """
    Created on Fri Dec 15 16:21:21 2017

    @author: Administrator
    """

    #%%
    # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ==============================================================================

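    # TF 1.x API: tensorflow.examples.tutorials was removed in TF 2.x
    from tensorflow.examples.tutorials.mnist import input_data
    import tensorflow as tf

    # Load MNIST with one-hot labels and open an interactive session
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    sess = tf.InteractiveSession()

    # Create the model: 784 -> 500 -> 100 -> 10, with two ReLU hidden layers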
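    in_units = 784   # 28x28 pixels, flattened
    h1_units = 500   # first hidden layer
    h2_units = 100   # second hidden layer
    # Hidden-layer weights: small truncated-normal values to break symmetry
    W1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1))
    b1 = tf.Variable(tf.zeros([h1_units]))
    W2 = tf.Variable(tf.truncated_normal([h1_units, h2_units], stddev=0.1))
    b2 = tf.Variable(tf.zeros([h2_units]))
    # Output layer (10 classes) starts at zero
    W3 = tf.Variable(tf.zeros([h2_units, 10]))
    b3 = tf.Variable(tf.zeros([10]))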

    x = tf.placeholder(tf.float32, [None, in_units])
    # Probability of keeping a unit in tf.nn.dropout; feeding 1.0 disables dropout
    keep_prob = tf.placeholder(tf.float32)

    # Forward pass: ReLU -> dropout -> ReLU -> dropout -> softmax
    hidden1 = tf.nn.relu(tf.matmul(x, W1) + b1)
    hidden1_drop = tf.nn.dropout(hidden1, keep_prob)
    hidden2 = tf.nn.relu(tf.matmul(hidden1_drop, W2) + b2)
    hidden2_drop = tf.nn.dropout(hidden2, keep_prob)
    y = tf.nn.softmax(tf.matmul(hidden2_drop, W3) + b3)

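    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 10])
    # Clip y before the log: if a softmax output underflows to 0, tf.log returns -inf
    # and 0 * -inf turns the whole loss into NaN
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)), reduction_indices=[1]))
    train_step = tf.train.AdagradOptimizer(0.3).minimize(cross_entropy)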

    # Train: 1000 steps on mini-batches of 1000 images (keep_prob=1.0, so no dropout)
    tf.global_variables_initializer().run()
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(1000)
        train_step.run({x: batch_xs, y_: batch_ys, keep_prob: 1.0})

    # Test trained model
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
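    Computing the softmax first and then taking its log, as in -tf.reduce_sum(y_ * tf.log(y)), is fragile. If a softmax output underflows to exactly 0 (in float32 this happens once a logit trails the largest one by roughly 100), the log is -inf, the 0 * -inf term from the one-hot label becomes NaN, and the NaN then spreads into every weight. This is one plausible, unconfirmed explanation for the 11.35% run. The pure-Python sketch below contrasts that naive formula with the log-sum-exp form that works directly from the logits; naive_cross_entropy and stable_cross_entropy are illustrative names, not TensorFlow APIs. In TF 1.x, tf.nn.softmax_cross_entropy_with_logits(labels=..., logits=...) takes the pre-softmax logits for the same reason.

```python
import math


def naive_cross_entropy(logits, label):
    """Softmax first, then log, like -sum(y_ * log(y)) in the listing."""
    exps = [math.exp(z) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    one_hot = [1.0 if k == label else 0.0 for k in range(len(logits))]
    # Mimic IEEE float behaviour in TF: log(0) -> -inf (math.log(0.0) would raise).
    logs = [math.log(p) if p > 0.0 else float("-inf") for p in probs]
    # 0.0 * -inf is NaN, so a single underflowed class poisons the whole sum.
    return -sum(t * lp for t, lp in zip(one_hot, logs))


def stable_cross_entropy(logits, label):
    """Log-sum-exp form computed from the logits; never takes log(0)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


extreme = [50.0, -800.0]  # exp(-800) underflows to 0.0 even in float64
print(naive_cross_entropy(extreme, 0))   # nan
print(stable_cross_entropy(extreme, 0))  # 0.0
```

    On well-behaved logits the two functions agree; they differ only once an exponent underflows, which is exactly the case a confidently wrong or diverging network runs into.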

      

  • Original post: https://www.cnblogs.com/Jerry-PR/p/8043758.html