zoukankan      html  css  js  c++  java
  • 【转载】 Tensorflow如何直接使用预训练模型(vgg16为例)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
    本文链接:https://blog.csdn.net/weixin_44633882/article/details/89054159
     

    ------------------------------------------------------------------------------------------

    主流的CNN模型基本都会使用VGG16或者ResNet等网络作为预训练模型,正好有个朋友和我说发给他一个VGG16的预训练模型和代码,我就整理了一下。在这里也分享一下,方便大家直接使用。

    系统环境

    • Tensorflow-gpu 1.12.0
    • Python 3.5.2

    资料来源

    官方slim说明
    https://github.com/tensorflow/models/tree/1af55e018eebce03fb61bba9959a04672536107d/research/slim

    主页里直接可以看到所提供的模型列表和下载链接。

    我们选择vgg16来做个示范哈,虽然vgg16的准确率现在已经不算高。

    拿到vgg_16.ckpt模型文件!
    直接贴上代码

    vgg16预训练模型使用代码

    import os
    import numpy as np
    import tensorflow as tf
    slim = tf.contrib.slim
    PROJECT_PATH = os.path.dirname(os.path.abspath(os.getcwd()))
    # 预训练模型位置
    tf.app.flags.DEFINE_string('pretrained_model_path',  os.path.join(PROJECT_PATH, 'data/vgg_16.ckpt'), '')
    FLAGS = tf.app.flags.FLAGS
    
    def vgg_arg_scope(weight_decay=0.1):
      """定义 VGG arg scope.
      Args:
        weight_decay: The l2 regularization coefficient.
      Returns:
        An arg_scope.
      """
      with slim.arg_scope([slim.conv2d, slim.fully_connected],
                          activation_fn=tf.nn.relu,
                          weights_regularizer=slim.l2_regularizer(weight_decay),
                          biases_initializer=tf.zeros_initializer()):
        with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
          return arg_sc
    
    def vgg16(inputs,scope='vgg_16'):
        with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
            # Collect outputs for conv2d, fully_connected and max_pool2d.
            with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],):
                                # outputs_collections=end_points_collection):
                net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
                net = slim.max_pool2d(net, [2, 2], scope='pool1')
                net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
                net = slim.max_pool2d(net, [2, 2], scope='pool2')
                net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
                net = slim.max_pool2d(net, [2, 2], scope='pool3')
                net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
                net = slim.max_pool2d(net, [2, 2], scope='pool4')
                net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
                # net = slim.max_pool2d(net, [2, 2], scope='pool5')
                # net = slim.fully_connected(net, 4096, scope='fc6')
                # net = slim.dropout(net, 0.5, scope='dropout6')
                # net = slim.fully_connected(net, 4096, scope='fc7')
                # net = slim.dropout(net, 0.5, scope='dropout7')
                # net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc8')
            return net
    
    def net():
        input_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_image')
        with slim.arg_scope(vgg_arg_scope()):
            conv5_3 = vgg16(input_image)      # vgg16网络
    
        init = tf.global_variables_initializer()
        # restore预训练模型op
        if FLAGS.pretrained_model_path is not None:
            variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path,
                                                                 slim.get_trainable_variables(),
                                                                 ignore_missing_vars=True)
        with tf.Session() as sess:
            sess.run(init)
            if FLAGS.pretrained_model_path is not None:
                # resotre 预训练模型
                variable_restore_op(sess)
            a = sess.run([conv5_3],feed_dict={input_image:np.arange(360000).reshape(1,300,400,3)})
    
    if __name__ == '__main__':
        net()
        print(tf.trainable_variables())

    讲一讲,代码里要注意的地方吧,也比较简单易懂。

    1.vgg_arg_scope

    def vgg_arg_scope(weight_decay=0.1):
      with slim.arg_scope([slim.conv2d, slim.fully_connected],
                          activation_fn=tf.nn.relu,
                          weights_regularizer=slim.l2_regularizer(weight_decay),
                          biases_initializer=tf.zeros_initializer()):
        with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
          return arg_sc

    vgg_arg_scope()函数返回了一个scope参数空间,使用起来就是with slim.arg_scope(vgg_arg_scope()):

    它规定了[slim.conv2d, slim.fully_connected]都要满足什么变量参数,比如:激活函数,参数初始化。


    activation_fn=tf.nn.relu来说,所有在这个变量空间中的conv2d卷积和fully_connected全连接都是指定了relu作为激活函数。

    当然,这里存在覆盖是可以的,可以嵌套arg_scope进行设置,内层空间覆盖了外层空间,最内层的就是slim.conv2d()里传入指定的参数了,这是覆盖了所有外层的。变量空间在我看来,非常方便,也使网络定义变得简单。

    2.slim.repeat()

    VGG16中比如一个conv,其中做了3次相同的卷积,写出来的代码就很长,使用repeat()就简单一句话net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')增强了代码可读性,而有人可能会问,那三层卷积层怎么进行标识呢?
    当然没问题,你输出变量会发现是类似conv5/conv5_1,在_后面递增自动标记区分。

    3.代码里是每个层是如何拿到自己对应的模型参数呢?

    这个应该是有些人的困惑吧,毕竟不知道这个,也只能拿着代码直接用。这个的关键是变量空间
    网络定义完成了,你可以通过 print(tf.trainable_variables()) 来获得所有网络中的变量。
    我贴出来 vgg16 中的变量,太多了,捡重要的说,就说说 conv1,可以看到变量是这么标识的 vgg_16/conv1/conv1_1/weights,前面有很多前缀,就和龙母报出来自己一堆头衔一样,其实是起到一个定位效果。
     

        # [<tf.Variable 'vgg_16/conv1/conv1_1/weights:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
        # <tf.Variable 'vgg_16/conv1/conv1_1/biases:0' shape=(64,) dtype=float32_ref>,
        # <tf.Variable 'vgg_16/conv1/conv1_2/weights:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
        # <tf.Variable 'vgg_16/conv1/conv1_2/biases:0' shape=(64,) dtype=float32_ref>,
        # <tf.Variable 'vgg_16/conv2/conv2_1/weights:0' shape=(3, 3, 64, 128) dtype=float32_ref>,

    在代码里,我们要让每个层在预训练模型里找到自己对应的参数,就必须这么定义变量空间。

        with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
            with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d]):
                net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')

    看到了 scope 和 ‘vgg_16’两个,其实 scope 我们也传入的是 ’vgg_16’,tf.variable_scope() 的参数,前两个是 name_or_scope, default_name默认名称是当 name_or_scope 为空时,使用的默认名称。

    这么整理一下,'vgg_16',  后面的slim.repeat()里的scope='conv1',还有自动标记的 conv1_1
    连起来就是 vgg_16/conv1/conv1_1

    4.  预训练模型restore。

    先准备op,而且若 pretrained_model_path 不为空,才加入和使用 variable_restore_op

    if FLAGS.pretrained_model_path is not None:
        variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path,
                                                                slim.get_trainable_variables(),
                                                                ignore_missing_vars=True)

    Session()中使用

    if FLAGS.pretrained_model_path is not None:
        variable_restore_op(sess)

    讲解完毕!哦,还有补充一下,一般vgg16来说,只会拿conv5_3的输出,继续做fine-tune。所以,你只用conv5_3,测试的时候是不用在意输入图片的大小的,因为都是卷积嘛。但是,我测试的时候,传入了个(1,3,6,3)的数组,出现了这么一个错。想了想,嗯,应该是这个数组做不了那么多次卷积的,所以Tensorflow报错了。(这里只是简单记录一下),所以用一个大一些的数组传入就可以啦

    2019-04-06 12:20:14.650154: F tensorflow/stream_executor/cuda/cuda_dnn.cc:542] Check failed: cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd, dims.data(), strides.data()) == CUDNN_STATUS_SUCCESS (3 vs. 0)batch_descriptor: {count: 1 feature_map_count: 128 spatial: 0 1  value_min: 0.000000 value_max: 0.000000 layout: BatchDepthYX}
    bash: line 1:  2492 Aborted                 (core dumped) env "PYTHONUNBUFFERED"="1" "PYTHONPATH"="/tmp/pycharm_project_299:/home/benke/.pycharm_helpers/pycharm_matplotlib_backend" "PYCHARM_HOSTED"="1" "JETBRAINS_REMOTE_RUN"="1" "PYCHARM_MATPLOTLIB_PORT"="65407" "PYTHONIOENCODING"="UTF-8" '/opt/anaconda3/bin/python' '-u' '/tmp/pycharm_project_299/data/vgg.py'

    ---------------------------------------------------------------------------------------------

    转者注:

    tensorflow官方预训练模型下载链接:

    https://github.com/tensorflow/models/tree/master/research/slim

  • 相关阅读:
    观察者模式
    系统高并发网络图书室
    java keytool
    ant 脚本使用技巧
    Unsupported major.minor version 51.0 错误解决方案
    Oracle的网络监听配置
    win8 JDK环境变量不生效
    javax.mail
    xmlrpc
    网络时间同步
  • 原文地址:https://www.cnblogs.com/devilmaycry812839668/p/12427443.html
Copyright © 2011-2022 走看看