zoukankan      html  css  js  c++  java
  • 使用TensorFlow的Eager执行模式

    一、即时执行模式（Eager Execution）

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe
    tfe.enable_eager_execution()

    # Iterate the Collatz map starting from 12 until the value reaches 1,
    # printing every intermediate value, then report how many steps it took.
    a = tf.constant(12)
    counter = 0  # number of Collatz steps taken
    while not tf.equal(a, 1):
        if tf.equal(a % 2, 0):
            # Floor division keeps the int32 dtype; "/" is true division and
            # would silently promote the tensor to float64.
            a = a // 2
        else:
            a = 3 * a + 1
        counter += 1
        print(a)
    print("steps:", counter)

     二、用Eager执行模式构建MNIST模型

    # TensorFlow 1.x with the contrib.eager API (tfe.Iterator and
    # tfe.implicit_value_and_gradients are used by the MNIST class below).
    import tensorflow as tf
    import tensorflow.contrib.eager as tfe
    # Switch TensorFlow to eager (imperative) execution for the rest of the script.
    tfe.enable_eager_execution()
    # Legacy TF 1.x tutorial helper; provides read_data_sets, which MNIST.__init__
    # uses to load the dataset.
    from tensorflow.examples.tutorials.mnist import input_data

    class MNIST:
        """Softmax-regression MNIST classifier trained with TensorFlow eager execution.

        Reads MNIST from ``data/MNIST_data/`` with one-hot labels and learns a
        single dense layer: weights ``W`` of shape (784, 10) and bias ``b`` of
        shape (10,).
        """

        def __init__(self):
            self.mnist = input_data.read_data_sets("data/MNIST_data/", one_hot=True)
            # The parentheses let the chained Dataset calls span several lines
            # (a bare leading ".map(...)" line is a syntax error). Labels are cast
            # to float32, presumably to match the float32 model output.
            self.train_ds = (
                tf.data.Dataset.from_tensor_slices(
                    (self.mnist.train.images, self.mnist.train.labels))
                .map(lambda x, y: (x, tf.cast(y, tf.float32)))
                .shuffle(buffer_size=1000)
                .batch(100)
            )
            self.W = tf.get_variable(name="W", shape=(784, 10))
            self.b = tf.get_variable(name="b", shape=(10,))

        def softmax_model(self, image_batch):
            """Return softmax class probabilities for ``image_batch``.

            Args:
                image_batch: flattened images, one row of 784 values per example.

            Returns:
                A (batch, 10) tensor of probabilities.
            """
            model_output = tf.nn.softmax(tf.matmul(image_batch, self.W) + self.b)
            return model_output

        def cross_entropy(self, model_output, label_batch):
            """Return the batch-mean categorical cross-entropy.

            Args:
                model_output: softmax probabilities, one row per example.
                label_batch: one-hot labels with the same shape as ``model_output``.

            Returns:
                A scalar loss tensor (non-negative; lower is better).
            """
            # Clip exact zeros so tf.log never produces -inf, which would turn the
            # loss and its gradients into NaN.
            safe_output = tf.clip_by_value(model_output, 1e-10, 1.0)
            # Cross-entropy is the *negative* log-likelihood. Without the minus
            # sign, gradient descent would push the correct-class probability to 0.
            loss = tf.reduce_mean(
                -tf.reduce_sum(label_batch * tf.log(safe_output), axis=[1]))
            return loss

        @tfe.implicit_value_and_gradients
        def cal_gradient(self, image_batch, label_batch):
            """Compute the batch loss.

            The ``tfe.implicit_value_and_gradients`` decorator makes a call
            return ``(loss, grads_and_vars)`` instead of the bare loss.
            """
            return self.cross_entropy(self.softmax_model(image_batch), label_batch)

        def train(self):
            """Make one pass over ``self.train_ds`` with SGD (learning rate 0.5).

            Prints the loss after each batch.
            """
            optimizer = tf.train.GradientDescentOptimizer(0.5)

            for step, (image_batch, label_batch) in enumerate(tfe.Iterator(self.train_ds)):
                loss, grads_and_vars = self.cal_gradient(image_batch, label_batch)
                optimizer.apply_gradients(grads_and_vars)
                # "{}" must not contain a space: "{ }" names a field " " and raises KeyError.
                print("step: {} loss: {}".format(step, loss.numpy()))

        def evaluate(self):
            """Compute and print classification accuracy on the MNIST test split.

            Also stores the result in ``self.accuracy``.
            """
            model_test_output = self.softmax_model(self.mnist.test.images)
            model_test_label = self.mnist.test.labels
            correct_prediction = tf.equal(tf.argmax(model_test_output, 1),
                                          tf.argmax(model_test_label, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
            print("test accuracy = {}".format(self.accuracy.numpy()))

    # Script entry point: build the model, train it for one pass, then report
    # test accuracy.
    if __name__ == '__main__':
        mnist_model = MNIST()
        mnist_model.train()
        mnist_model.evaluate()

  • 相关阅读:
    faster rcnn学习(三)
    too many values to unpack (expected 2)
    RuntimeWarning: overflow encountered in ubyte_scalars
    C#中excel读取和写入
    C#中使用Sql对Excel条件查询
    IIS上部署MVC网站,打开后ExtensionlessUrlHandler-Integrated-4.0解决方法IIS上部署MVC网站,打开后500错误
    C#微信公众平台账号开发,从零到整,步骤详细。
    VS快捷键大全
    ASP.NET将文件写到另一服务器
    开放api接口签名验证
  • 原文地址:https://www.cnblogs.com/zhaop8078/p/9571861.html
Copyright © 2011-2022 走看看