zoukankan      html  css  js  c++  java
  • tensorflow dropout实现

    指定keep_prob即可,下面的例子使用了占位符。为了简便起见,直接给keep_prob赋一个定值可能更好,但占位符在每次运行时都可以指定keep_prob的值。

    keep_prob = tf.placeholder('float')
    
    L1 = ...
    
    L1_d = tf.nn.dropout(L1, keep_prob)
    
    # Train
    sess.run(optimizer, feed_dict={X: batch_xs, Y: batch_ys, keep_prob: 0.7})
    # Evaluation
    print("Accuracy", accuracy.eval({X: mnist.test.images, Y: mnist.test.labels, keep_prob: 1}))
    

    更详细的例子:

    # dropout (keep_prob) rate  0.7 on training, but should be 1 for testing
    keep_prob = tf.placeholder(tf.float32)
    
    W1 = tf.get_variable("W1", shape=[784, 512])
    b1 = tf.Variable(tf.random_normal([512]))
    L1 = tf.nn.relu(tf.matmul(X, W1) + b1)
    L1 = tf.nn.dropout(L1, keep_prob=keep_prob)
    
    W2 = tf.get_variable("W2", shape=[512, 512])
    b2 = tf.Variable(tf.random_normal([512]))
    L2 = tf.nn.relu(tf.matmul(L1, W2) + b2)
    L2 = tf.nn.dropout(L2, keep_prob=keep_prob)# train model
    for epoch in range(training_epochs):
        ...
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            feed_dict = {X: batch_xs, Y: batch_ys, keep_prob: 0.7}
            c, _ = sess.run([cost, optimizer], feed_dict=feed_dict)
            avg_cost += c / total_batch
    
    # Test model and check accuracy
    correct_prediction = tf.equal(tf.argmax(hypothesis, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print('Accuracy:', sess.run(accuracy, feed_dict={
         X: mnist.test.images, Y: mnist.test.labels, keep_prob: 1}))
    
    
  • 相关阅读:
    生成1--n的全排列
    小P的秘籍
    小P的字符串
    小P的金字塔
    2198: 小P当志愿者送餐
    交换排序(快速排序重点)
    关于上级机构的冲突性测试bug修复
    系统当前时间system.currenttimemillis与new Date().getTime() 区别
    Servlet中(Session、cookies、servletcontext)的基本用法
    默认图片展示(个人信息模块)
  • 原文地址:https://www.cnblogs.com/wanghongze95/p/13842483.html
Copyright © 2011-2022 走看看