  • (Original) A simple example of training on MNIST with TensorFlow eager execution

    Please cite the source when reposting:

    https://www.cnblogs.com/darkknightzh/p/9989586.html

    Code:

    https://github.com/darkknightzh/trainEagerMnist

    References:

    https://github.com/tensorflow/models/blob/master/official/mnist/mnist_eager.py

    https://github.com/madalinabuzau/tensorflow-eager-tutorials/blob/master/07_convolutional_neural_networks_for_emotion_recognition.ipynb

    Overall workflow

    To use eager execution in TensorFlow, the following lines are needed (if the third line is omitted, the static graph can still be used as usual):

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe
    tfe.enable_eager_execution()

    With eager mode enabled, TensorFlow feels as convenient as PyTorch: tf.placeholder is no longer needed, which makes the code noticeably simpler.

    At the moment, tf.keras.layers and tf.layers appear to support eager execution, while slim does not. A minimal check that eager is active is sketched below.
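
    A minimal sketch of what eager changes (values here are only illustrative): operations execute immediately and return concrete values, so no tf.Session or tf.placeholder is involved.

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe
    tfe.enable_eager_execution()

    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    y = tf.matmul(x, x)            # runs immediately, no Session.run needed
    print(tf.executing_eagerly())  # True
    print(y.numpy())               # [[ 7. 10.] [15. 22.]]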

    The overall flow is as follows; a sketch of building the training_data pipeline comes right after the pseudocode:

    initialize optimizer
    for epoch in range(epochs):
        for imgs, targets in training_data:
            with tf.GradientTape() as tape:
                logits = model(imgs, training=True)
                loss_value = calc_loss(logits, targets)
            grads = tape.gradient(loss_value, model.variables)
            optimizer.apply_gradients(zip(grads, model.variables), global_step=step_counter)
            update training_accuracy, total_loss
        test model
        save model
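
    A possible way (one of several) to build the training_data iterable used above, assuming eager execution is enabled: load MNIST through tf.keras.datasets and batch it with tf.data. Flattening to 784 matches the Reshape layer used by the models below.

    import tensorflow as tf

    (train_x, train_y), _ = tf.keras.datasets.mnist.load_data()
    train_x = train_x.reshape(-1, 28 * 28).astype('float32') / 255.0  # flatten and scale to [0, 1]
    train_y = tf.keras.utils.to_categorical(train_y, 10)              # one-hot labels

    training_data = tf.data.Dataset.from_tensor_slices((train_x, train_y)).shuffle(60000).batch(100)

    # with eager enabled, the dataset can be iterated directly
    for imgs, targets in training_data:
        print(imgs.shape, targets.shape)  # (100, 784) (100, 10)
        break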

    Creating the model

    A model can be created in any of the three ways below.

    1. A PyTorch-like approach

    First define the layers to be used in __init__, then override the call function to build the network; call is invoked during the forward pass. See the code below (a short usage sketch follows the class):

    class simpleModel(tf.keras.Model):
        def __init__(self, num_classes):
            super(simpleModel, self).__init__()

            input_shape = [28, 28, 1]
            data_format = 'channels_last'
            self.reshape = tf.keras.layers.Reshape(target_shape=input_shape, input_shape=(input_shape[0] * input_shape[1],))

            self.conv1 = tf.keras.layers.Conv2D(16, 5, padding="same", activation='relu')
            self.batch1 = tf.keras.layers.BatchNormalization()
            self.pool1 = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)

            self.conv2 = tf.keras.layers.Conv2D(32, 5, padding="same", activation='relu')
            self.batch2 = tf.keras.layers.BatchNormalization()
            self.pool2 = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)

            self.conv3 = tf.keras.layers.Conv2D(64, 5, padding="same", activation='relu')
            self.batch3 = tf.keras.layers.BatchNormalization()
            self.pool3 = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)

            self.conv4 = tf.keras.layers.Conv2D(64, 5, padding="same", activation='relu')
            self.batch4 = tf.keras.layers.BatchNormalization()
            self.pool4 = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)

            self.flat = tf.keras.layers.Flatten()
            self.fc5 = tf.keras.layers.Dense(1024, activation='relu')
            self.batch5 = tf.keras.layers.BatchNormalization()

            self.fc6 = tf.keras.layers.Dense(num_classes)
            self.batch6 = tf.keras.layers.BatchNormalization()

        def call(self, inputs, training=None):
            x = self.reshape(inputs)

            x = self.conv1(x)
            x = self.batch1(x, training=training)
            x = self.pool1(x)

            x = self.conv2(x)
            x = self.batch2(x, training=training)
            x = self.pool2(x)

            x = self.conv3(x)
            x = self.batch3(x, training=training)
            x = self.pool3(x)

            x = self.conv4(x)
            x = self.batch4(x, training=training)
            x = self.pool4(x)

            x = self.flat(x)
            x = self.fc5(x)
            x = self.batch5(x, training=training)

            x = self.fc6(x)
            x = self.batch6(x, training=training)
            # x = tf.layers.dropout(x, rate=0.3, training=training)
            return x

        def get_acc(self, target):
            correct_prediction = tf.equal(tf.argmax(self.logits, 1), tf.argmax(target, 1))
            acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
            return acc

        def get_loss(self):
            return self.loss

        def loss_fn(self, images, target, training):
            self.logits = self(images, training=training)  # invokes call(self, inputs, training=None)
            self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=target))
            return self.loss

        def grads_fn(self, images, target, training):  # no need to return loss and acc unless necessary
            with tfe.GradientTape() as tape:
                loss = self.loss_fn(images, target, training)
            return tape.gradient(loss, self.variables)
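
    A usage sketch for the class above (assumes eager execution is enabled as shown earlier; the batch size and learning rate are illustrative):

    model = simpleModel(num_classes=10)
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)

    images = tf.zeros((8, 28 * 28))                        # dummy batch of flattened images
    targets = tf.one_hot(tf.zeros(8, dtype=tf.int32), 10)  # dummy one-hot labels

    grads = model.grads_fn(images, targets, training=True)  # forward pass + gradients
    optimizer.apply_gradients(zip(grads, model.variables))  # update the parameters
    print(model.get_loss().numpy(), model.get_acc(targets).numpy())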

    2. Using tf.keras.Sequential directly

    As shown in the code below:

    def create_model1():
        data_format = 'channels_last'
        input_shape = [28, 28, 1]
        l = tf.keras.layers
        # pooling layers have no weights, so a single instance can be reused
        max_pool = l.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)
        # The model is a sequential chain of layers, so tf.keras.Sequential (a subclass of tf.keras.Model) gives a compact description.
        return tf.keras.Sequential(
            [
                l.Reshape(target_shape=input_shape, input_shape=(28 * 28,)),
                l.Conv2D(16, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
                l.BatchNormalization(),
                max_pool,

                l.Conv2D(32, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
                l.BatchNormalization(),
                max_pool,

                l.Conv2D(64, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
                l.BatchNormalization(),
                max_pool,

                l.Conv2D(64, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
                l.BatchNormalization(),
                max_pool,

                l.Flatten(),
                l.Dense(1024, activation=tf.nn.relu),
                l.BatchNormalization(),

                # l.Dropout(0.4),
                l.Dense(10),
                l.BatchNormalization()
            ])

    3. Using tf.keras.Sequential() with the add function

    As shown in the code below (a quick sanity check follows):

    def create_model2():
        data_format = 'channels_last'
        input_shape = [28, 28, 1]

        model = tf.keras.Sequential()

        model.add(tf.keras.layers.Reshape(target_shape=input_shape, input_shape=(input_shape[0] * input_shape[1],)))

        model.add(tf.keras.layers.Conv2D(16, 5, padding="same", activation='relu'))
        model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format))

        model.add(tf.keras.layers.Conv2D(32, 5, padding="same", activation='relu'))
        model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format))

        model.add(tf.keras.layers.Conv2D(64, 5, padding="same", activation='relu'))
        model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format))

        model.add(tf.keras.layers.Conv2D(64, 5, padding="same", activation='relu'))
        model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format))

        model.add(tf.keras.layers.Flatten())
        model.add(tf.keras.layers.Dense(1024, activation='relu'))
        model.add(tf.keras.layers.BatchNormalization())

        model.add(tf.keras.layers.Dense(10))
        model.add(tf.keras.layers.BatchNormalization())

        return model
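
    A quick sanity check (illustrative): the three constructors build the same architecture, so they can be used interchangeably in the training loop.

    model = create_model2()
    out = model(tf.zeros((2, 28 * 28)), training=False)
    print(out.shape)  # (2, 10): one logit per class for each image in the batch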

    Updating gradients with the dynamic graph

    When updating the gradients, the following lines are needed:

    with tf.GradientTape() as tape:
        logits = model(imgs, training=True)
        loss_value = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labs))
    grads = tape.gradient(loss_value, model.variables)
    optimizer.apply_gradients(zip(grads, model.variables), global_step=step_counter)

    The forward pass computes the logits, from which the loss is obtained; tape.gradient then yields the gradients, and apply_gradients applies them to the model, updating its parameters.

    Saving and loading the model

    1. Using tfe.Saver

    The code is as follows (a usage sketch comes after it):

    import os

    def saveModelV1(model_dir, model, global_step, modelname='model1'):
        tfe.Saver(model.variables).save(os.path.join(model_dir, modelname), global_step=global_step)

    def restoreModelV1(model_dir, model):
        dummy_input = tf.constant(tf.zeros((1, 28, 28, 1)))  # run the model once so its variables are created
        dummy_pred = model(dummy_input, training=False)

        saver = tfe.Saver(model.variables)  # restore the variables of the model
        saver.restore(tf.train.latest_checkpoint(model_dir))
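
    A hypothetical call sequence for the two helpers above (the directory name is a placeholder): save after training, then restore into a freshly constructed model of the same architecture.

    global_step = tf.train.get_or_create_global_step()
    saveModelV1('./ckpt_v1', model, global_step)

    new_model = simpleModel(num_classes=10)  # same architecture, fresh weights
    restoreModelV1('./ckpt_v1', new_model)   # its dummy forward pass creates the variables before restoring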

    2. Using tf.train.Checkpoint

    The code is as follows, again with a usage sketch after it:

    step_counter = tf.train.get_or_create_global_step()
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer, step_counter=step_counter)

    def saveModelV2(model_dir, checkpoint, modelname='model2'):
        checkpoint_prefix = os.path.join(model_dir, modelname)
        checkpoint.save(checkpoint_prefix)

    def restoreModelV2(model_dir, checkpoint):
        checkpoint.restore(tf.train.latest_checkpoint(model_dir))
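
    A hypothetical call sequence for the helpers above (the directory name is a placeholder). Unlike the tfe.Saver variant, which saves only model.variables, tf.train.Checkpoint also saves and restores the optimizer state and the step counter.

    saveModelV2('./ckpt_v2', checkpoint)     # writes model + optimizer + step counter
    restoreModelV2('./ckpt_v2', checkpoint)  # restores all three in place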

    Full code

    The code does not strictly follow the steps of the overall workflow and is for reference only; see https://github.com/darkknightzh/trainEagerMnist

    In that code, eagerFlag selects how eager is used: 0 disables eager (static graph), 1 uses the V1 saving style, and 2 uses the V2 style. When using the static graph, do not call tfe.enable_eager_execution(), otherwise an error is raised. See the code for details.
