zoukankan      html  css  js  c++  java
  • 寒假学习日报(四十二)——Tensorflow实验

      今天把老师布置的tensorflow实验做了做,由于下载的是2.0版本,不少1.0版本函数调用都需要.compat.v1前缀,还要加入下面这个代码保证运行:

    # 保证session.run能够正常运行
    tf.compat.v1.disable_eager_execution()

      七个实验除了最后一个都已完成:

       解决的Bug就是本地MNIST数据集的导入问题,导入函数如下所示:

    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_Data", one_hot=True)

    这个是1.0版本的写法,2.0版本开始都使用keras模块了,但keras模块拿到数据后需要手动进行one_hot处理,我不清楚1.0版本中one_hot=True是咋处理的,因此废了一门心思去引入tutorials包。主要处理的BUG是一开始引入MNIST——Data总是报错,最后我把数据集删除后重新下了一个新的,引入成功。。。。

      最后就是问题代码:

    from __future__ import print_function
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import os
    
    # 保证session.run能够正常运行
    tf.compat.v1.disable_eager_execution()
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    mnist = input_data.read_data_sets("MNIST_Data", one_hot=True)
    # 设置训练参数
    learning_rate = 0.001
    training_step = 10000
    batch_size = 128
    display_step = 400
    # 设置双向循环神经网络参数
    num_input = 28
    timestep = 28
    num_hidden = 128
    num_classes = 10
    # 构造计算图输入变量
    X = tf.compat.v1.placeholder("float32", [None, timestep, num_input])
    Y = tf.compat.v1.placeholder("float32", [None, num_classes])
    # 设置权重和偏值
    weights = {
        'out': tf.Variable(tf.compat.v1.random_normal([2*num_hidden, num_classes]))
    }
    biases = {
        'out': tf.Variable(tf.compat.v1.random_normal([num_classes]))
    }
    
    
    # 自定义双向循环神经网络函数
    def BiRNN(X, weights, biases):
        x = tf.unstack(X, timestep, 1)
        lstm_fw_cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(num_hidden, forget_bias=1.0)
        lstm_bw_cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(num_hidden, forget_bias=1.0)
        try:
            outputs, _, _ = tf.compat.v1.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, x, dtype=tf.float32)
        except Exception:
            outputs = tf.compat.v1.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, x, dtype=tf.float32)
        return tf.matmul(outputs[-1], weights['out']) + biases['out']
    
    
    # 获取神经网络输出层logits,使用softmax激活函数将logits映射成各类取值概率,结果赋值给prediction
    logits = BiRNN(X, weights, biases)
    prediction = tf.nn.softmax(logits)
    # 通过交叉熵构建损失函数,并使用梯度下降法求解,将结果赋值给train_op。
    loss_op = tf.reduce_mean(tf.compat.v1.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=learning_rate)
    train_op = optimizer.minimize(loss_op)
    # 计算训练模型的准确率
    correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    # 初始化全部变量
    init = tf.compat.v1.global_variables_initializer()
    # Session
    with tf.compat.v1.Session() as sess:
        # 训练模型
        for step in range(1, training_step+1):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            batch_x = batch_x.reshape((batch_size, timestep, num_input))
            sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
            if step % display_step == 0 or step == 1:
                loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x, Y: batch_y})
                print("Step"+str(step)+",Minbatch Loss="+"{:.4f}".format(loss)+",Training Accuracy="+"{:.3f}".format(acc))
        print("Optimization Finished!")
        # 设置测试数据长度为128,通过训练模型对测试数据进行预测,打印测试数据的准确率
        test_len = 128
        test_data = mnist.test.images[:test_len].reshape((-1, timestep, num_input))
        test_label = mnist.test.labels["test_len"]
        print("Test Accuracy:", sess.run(accuracy, feed_dict={X: test_data, Y: test_label}))

     主要报ValueError错误,目前还在排查中。。。。

  • 相关阅读:
    安卓执行机制JNI、Dalvik、ART之间的比較 。android L 改动执行机制。
    Android studio 导入githubproject
    JS创建对象几种不同方法具体解释
    python 学习笔记 13 -- 经常使用的时间模块之time
    Version和Build的差别
    关于Java基础的一些笔试题总结
    vim编码方式配置的学习和思考
    从头认识java-15.5 使用LinkedHashSet须要注意的地方
    一篇文章,带你明确什么是过拟合,欠拟合以及交叉验证
    Spring -- Bean自己主动装配&Bean之间关系&Bean的作用域
  • 原文地址:https://www.cnblogs.com/20183711PYD/p/14423126.html
Copyright © 2011-2022 走看看