zoukankan      html  css  js  c++  java
  • TensorFlow slim(二) 使用TF-slim编程模板(一)

      TF-slim 模块是TensorFLow中比较实用的API之一,是一个用于模型构建、训练、评估复杂模型的轻量化库

      最近,在使用TF-slim API编写了一些项目模型后,发现TF-slim模块在搭建网络模型时具有相同的编写模式。这个编写模式主要包含四个部分:

    • __init__():
    • build_model():
    • fit():
    • predict():

    1. __init__():

      这部分相当于是一个main()函数,其中包含参数的设置,模型整体的连接等操作。具体来说:

      a. 设置参数

      由于是类的构造函数,所以需要在其中设置一些模型网络结构的参数、模型训练时的参数等等。例如

    • 学习率
    • batch_size
    • 训练代数
    • 各种文件的存放地址
    • ...
    • 对于网络结构复杂的模型,还可以将网络结构的table以列表的形式进行保存。便于后续建立模型时可以循环获取每层的超参数。
    1 self.lr = lr
    2 self.batch_size = batch_size
    3 self.epoch = epoch
    4 self.checkpoint_dir_load = checkpoint_dir
    5 self.checkpoint_dir = os.path.join(checkpoint_dir, filename + ".ckpt")
    6 self.logdir = logdir
    7 self.result_dir = result_dir

      b. 设置输入、输出的占位符placeholder

      由于TF-slim框架仍然采用的是tensorflow的那一套,不像tf.keras可以使用keras.layer.Input(),所以还需要使用占位符。例如

    1 self.input_image = tf.placeholder(tf.float32, shape=[None, 6000])
    2 self.input_image_raw = tf.reshape(self.input_image, shape=[-1, 6000, 1])
    3 
    4 self.input_image_label = tf.placeholder(tf.float32, shape=[None, 1, 10])
    5 self.input_label = tf.reshape(self.input_image_label, shape=[-1, 10])

      c. 初始化网络结构,生成训练输出和测试输出

      用于后续损失的计算以及优化器的生成,以及训练结果和测试结果的调用。

      此处会涉及到网络参数的重用,需要使用tf.variable_scope()来管理参数。

    1 with tf.variable_scope("Network_Structure") as scope:
    2     self.train_digits = self.build_model(is_trained=True)
    3     scope.reuse_variables()
    4     self.test_digits = self.build_model(is_trained=False)

      d. 损失函数和优化器的声明

      此处损失声明使用的是 输出的占位符和训练的输出。例如:

    1 self.loss = slim.losses.softmax_cross_entropy(logits=self.train_digits, onehot_labels=self.input_label, scope="loss")
    2 
    3 self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(loss=self.loss)

      e. 最终训练输出结果和测试输出结果的计算

      由于网络输出的结果不一定是最终的结果。对于多分类问题,需要将one_hot编码的结果显示为类值;对于回归问题,输出结果可能会需要反归一化。等等..

      如下述代码,多分类问题的one_hot转化为类标签,并进行准确率的计算。

     1 # result and accuracy of test
     2 self.predicts = tf.math.argmax(self.test_digits, 1)   # 将one_hot转化为类标签
     3 self.test_correction = tf.equal(self.predicts, tf.math.argmax(self.input_label, 1))
     4 self.accuracy = tf.reduce_mean(tf.cast(self.test_correction, "float"))
     5 tf.summary.scalar("test_accuracy", self.accuracy)
     6 
     7 # result and accuracy of train
     8 self.train_result = tf.math.argmax(self.train_digits, 1)
     9 self.train_correlation = tf.equal(self.train_result, tf.math.argmax(self.input_label, 1))
    10 self.train_accuracy = tf.reduce_mean(tf.cast(self.train_correlation, "float"))
    11 tf.summary.scalar("train_accuracy", self.accuracy)

    2. build_model():【可以是别的名字】

      这部分是为了使用tf-slim搭建网络模型结构。有些模型可能一个函数实现不了,需要多个函数。例如具有共享层的Siamese Network,在共享层后还有其他层。

      这一部分也实现了如同tf.keras搭建的模型"乐高式"堆叠,不需要手动为各层生成权重、偏执等参数。也是代码瘦身的重要环节。

     1 with slim.arg_scope([slim.conv1d], padding="SAME", stride=2, activation_fn=tf.nn.relu,
     2                     weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
     3                     weights_regularizer=slim.l2_regularizer(0.005)
     4                     ):
     5     net = slim.conv1d(self.input_image_raw, num_outputs=16, kernel_size=8, padding="VALID", scope='conv_1')
     6     tf.summary.histogram("conv_1", net)
     7     net = slim.conv1d(net, num_outputs=16, kernel_size=8, scope='conv_2')
     8     tf.summary.histogram("conv_2", net)
     9     def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_3")
    10     net = def_max_pool(net)
    11     # net = slim.nn.max_pool1d(net, ksize=2, strides=None, padding="VALID", data_format="NWC", name="max_pool_3")
    12     tf.summary.histogram("max_pool_3", net)
    13     net = slim.conv1d(net, num_outputs=64, kernel_size=4, scope="conv_4")
    14     tf.summary.histogram("conv_4", net)
    15     net = slim.conv1d(net, num_outputs=64, kernel_size=4, scope="conv_5")
    16     tf.summary.histogram("conv_5", net)
    17     def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_6")
    18     net = def_max_pool(net)
    19     # net = slim.nn.max_pool1d(net, ksize=2, strides=1, padding="VALID", name="max_pool_6")
    20     tf.summary.histogram("max_pool_6", net)
    21     net = slim.conv1d(net, num_outputs=256, kernel_size=4, scope="conv_7")
    22     tf.summary.histogram("conv_7", net)
    23     net = slim.conv1d(net, num_outputs=256, kernel_size=4, scope="conv_8")
    24     tf.summary.histogram("conv_8", net)
    25     def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_9")
    26     net = def_max_pool(net)
    27     # net = slim.nn.max_pool1d(net, ksize=1, strides=1, padding="VALID", name="max_pool_9")
    28     tf.summary.histogram("max_pool_9", net)
    29     net = slim.conv1d(net, num_outputs=512, kernel_size=2, stride=1, scope="conv_10")
    30     tf.summary.histogram("conv_10", net)
    31     net = slim.conv1d(net, num_outputs=512, kernel_size=2, stride=1, scope="conv_11")
    32     tf.summary.histogram("conv_11", net)
    33     def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_12")
    34     net = def_max_pool(net)
    35     # net = slim.nn.max_pool1d(net, ksize=1, strides=1, padding="VALID", name="max_pool_12")
    36     tf.summary.histogram("max_pool_12", net)
    37     net = tf.reduce_mean(net, axis=1, name="global_max_pool_13")   # 起全局平均池化的作用
    38     tf.summary.histogram("global_max_pool_13", net)
    39     net = slim.dropout(net, keep_prob=0.5, scope="dropout")
    40     tf.summary.histogram("dropout", net)
    41     digits = slim.fully_connected(net, num_outputs=num_class, activation_fn=tf.nn.softmax, scope="fully_connected_14")
    42     tf.summary.histogram("fully_connected_14", digits)
    43 return digits

    3. fit():

      看名字就知道这一部分需要完成的是训练部分的代码

      这一部分需要包含会话的启动、模型保存器的初始化、循环迭代、batch设置、数据集输入、输出数据获取、喂到网络中、保存模型、会话关闭等操作。如下述代码

     1 sess = tf.Session()  # 启动会话
     2 
     3 merge_summary_op = tf.summary.merge_all()
     4 summary_writer = tf.summary.FileWriter(self.logdir, sess.graph)
     5 
     6 saver = tf.train.Saver(max_to_keep=1)  # 生成保存器
     7 sess.run(tf.global_variables_initializer())   # 变量激活
     8 
     9 for step in range(self.epoch):    # 迭代
    10     print("Epoch:%d"%step)
    11     avg_cost = 0
    12     acc = 0
    13     total_batch = int(input_x.shape[0]/self.batch_size)   # 划分batch
    14     for batch_num in range(total_batch):   # batch迭代
    15         # 获取数据
    16         batch_xs = input_x[batch_num*self.batch_size:(batch_num+1)*self.batch_size, :]
    17         batch_ys = input_y[batch_num*self.batch_size:(batch_num+1)*self.batch_size, :]
    18         batch_ys = sess.run(tf.one_hot(batch_ys, depth=10))
    19         # 喂到损失 优化器等等
    20         _, loss, acc = sess.run([self.optimizer, self.loss, self.train_accuracy],
    21                                                         feed_dict={self.input_image: batch_xs,
    22                                                          self.input_image_label: batch_ys})
    23         avg_cost += loss / total_batch
    24         acc += acc /total_batch
    25 
    26         summary_str = sess.run(merge_summary_op, feed_dict={self.input_image: batch_xs,
    27                                                             self.input_image_label: batch_ys})
    28         summary_writer.add_summary(summary_str, global_step=step)
    29         print("Epoch:%d, batch: %d, avg_cost: %g, accuracy: %g" % (step, batch_num, avg_cost, acc))
    30     # 保存模型
    31     saver.save(sess, self.checkpoint_dir, global_step=step)
    32 sess.close()   # 会话关闭

    4. predict():

      从函数名可以知道这一部分是实现预测部分的代码。其相对于训练的过程要更简单。主要包括会话的启动、保存器的生成、权重的导入(模型的恢复)、预测、关闭会话。如下述代码

     1 sess = tf.Session()   # 会话的启动
     2 
     3 saver = tf.train.Saver()  # 保存器的生成
     4 
     5 module_file = tf.train.latest_checkpoint(self.checkpoint_dir_load)
     6 saver.restore(sess, module_file)    # 模型的恢复
     7 
     8 input_y = sess.run(tf.one_hot(input_y, depth=10))  # 获取输出
     9 # 获取预测结果和预测精度
    10 predicts, acc_test = sess.run([self.predicts, self.accuracy], feed_dict={self.input_image: input_x,
    11 # 关闭会话                                                                            self.input_image_label: input_y})
    12 sess.close()
    13 # print("test_accuracy: %f" %acc_test)
    14 return predicts, acc_test

      上述四步完成后,便可以编写一个main函数来调用这个类,实现需要的功能。.fit()和.predict()主要是在main()函数来调用。

  • 相关阅读:
    oracle——定时器时间设置
    Servlet上下文监听
    jsp开发中页面数据共享技术
    String类的创建
    Microsoft Enterprise Library 5.0 系列 Configuration Application Block
    How to print only the selected grid rows
    企业库的保存方法
    devexpress pictureedit 按钮调用其菜单功能
    Devexpress IDE 插件 卸载
    修改安装包快捷方式路径
  • 原文地址:https://www.cnblogs.com/monologuesmw/p/12631901.html
Copyright © 2011-2022 走看看