zoukankan      html  css  js  c++  java
  • tensorflow 加载预训练模型进行 finetune 的操作解析

    这是一篇需要仔细思考的博客;

    预训练模型

    tensorflow 在 1.0 之后移除了 models 模块,这个模块实现了很多模型,并提供了部分预训练模型的权重;

    图像识别模型的权重下载地址  https://github.com/tensorflow/models/tree/master/research/slim

    模型加载

    首先需要了解模型保存的形式,包含了 checkpoint、data、meta 等文件;

    模型加载不仅可以从 data 加载训练好的权重,还可以从 meta 加载计算图,

    加载计算图我们可以理解为引入了 计算节点和变量,引入变量很重要,这样我们无需自己去创造变量,

    加载计算图返回的是个 Saver 对象,如果没有通过 加载图引入变量,也没有自己创造变量,是无法创建 Saver 对象的,没有 Saver 对象,就无法加载预训练权重;

    下面我用代码解释上面的逻辑;

    首先做个数据准备:写个最简单的计算图,然后保存

    with tf.name_scope('scope1'):
        w1 = tf.Variable(1, name='w1')
        v1 = tf.Variable(2, name='v1')
    
    with tf.name_scope('scope2'):
        w2 = tf.Variable(3, name='w2')
        v2 = tf.Variable(4, name='v2')
    
    out = tf.add(w1*v1, w2*v2)
    init = tf.global_variables_initializer()
    
    saver = tf.train.Saver()
    
    sess = tf.Session()
    sess.run(init)
    print(sess.run(out))
    saver.save(sess, 'data/test.ckpt')

    记住这里 w1 的值为 1,后面有用

    加载预训练权重有两种思路

    1. 先加载计算图,再加载权重

    2. 自己创建计算图,再加载权重

    总结一句话就是 先有变量(创建图就必须创建变量),然后创建 Saver 对象,通过 Saver 对象加载权重

    先加载计算图,再加载权重

    加载计算图用下面的方法

    def import_meta_graph(meta_graph_or_file,
                          clear_devices=False,
                          import_scope=None,
                          **kwargs):
      """Recreates a Graph saved in a `MetaGraphDef` proto."""

    加载计算图后,通过 sess.graph 获取图,通过 get_tensor_by_name 等方法获取保存的变量

    简单举例

    ## 加载了图中的各个节点,相当于引入变量
    saver = tf.train.import_meta_graph('data/test.ckpt.meta')       # 从 meta 文件直接加载图,返回一个存储器
    print(type(saver))      # <class 'tensorflow.python.training.saver.Saver'>
    
    ### 如果没有变量,直接创建存储器,会报错的
    # saver = tf.train.Saver()      ### 报错 ValueError: No variables to save
    
    sess = tf.Session()
    saver.restore(sess, 'data/test.ckpt')   ### 获取图权重
    # print(sess.run('w1'))           ### 直接这样报错
    
    graph = sess.graph  ### 获取加载的图
    print(sess.run(graph.get_tensor_by_name('scope1/w1:0')))        # 1     ### 获取图的 Tensor
    # print(sess.run(graph.get_tensor_by_name('w1:0')))       ### 报错 KeyError: "The name 'w1:0' refers to a Tensor which does not exist. The operation, 'w1', does not exist in the graph."
    sess.close()

    自己创建计算图,再加载权重

    注意,自己创建的计算图要与保存的计算图一致,包括 网络结构、作用域、变量名等

    ##### 上面是直接加载图,引入变量,从而创建存储器
    ##### 这里我们不加载,自己创建一个图,从而创建存储器
    ### 注意,图的结构要与 checkpoint 中的一致,作用域、变量名 都要一样
    with tf.name_scope('scope1'):
        w1 = tf.Variable(333, name='w1')
        v1 = tf.Variable(2, name='v1')
    
    with tf.name_scope('scope2'):
        w2 = tf.Variable(3, name='w2')
        v2 = tf.Variable(4, name='v2')
    
    init = tf.global_variables_initializer()
    
    saver = tf.train.Saver()
    
    sess = tf.Session()
    sess.run(init)
    saver.restore(sess, 'data/test.ckpt')
    
    graph = tf.get_default_graph()      ### 获取默认图
    print(sess.run(graph.get_tensor_by_name('scope1/w1:0')))        # 1

    这里我们创建计算图是把 w1 赋值 333,而加载的 w1 仍然是保存时的 1,说明加载成功了

    加载预训练权重进行 finetune

    finetune 我会单独写一篇博客,这里不细讲,它大致可以分两种:

    1. 加载部分权重,其他权重正常初始化,然后优化所有参数;

    2. 加载部分权重,其他权重正常初始化,然后优化非加载的参数,相当于固定部分权重,优化另一部分权重;

    需要加载的权重 肯定通过 预训练模型获取,而其他权重则既可以通过预训练的模型获取,也可以自己创建,这点在上一章已经讲清楚了,下面我们为了方便,就直接加载计算图了;

    固定部分权重是个难点,它的思路有两点:

    1. 先加载这部分权重,然后把这部分权重经过前向计算,得到一个新的 Input_new,然后把这个 Input 作为后续网络的输入,反向传播时传到 Input 肯定就停止了;

    2. 分别加载需要训练的参数 train_var 和不需要训练的参数 fixed_var,然后在 优化器的 optimizer.minimize 方法中指定 var_list 为 train_var

    def minimize(self, loss, global_step=None, var_list=None,
                   gate_gradients=GATE_OP, aggregation_method=None,
                   colocate_gradients_with_ops=False, name=None,
                   grad_loss=None):
        """Add operations to minimize `loss` by updating `var_list`.
          var_list: Optional list or tuple of `Variable` objects to update to
            minimize `loss`.  Defaults to the list of variables collected in
            the graph under the key `GraphKeys.TRAINABLE_VARIABLES`."""

    创造计算图,先加载 fixed_var,再添加新的网络层

    demo 如下:只是伪代码哦

    tf.reset_default_graph()        ### 这句暂时可忽略
    
    # 构建计算图
    images = tf.placeholder(tf.float32,(None,224,224,3))
    with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope(is_training=False)):
        logits, endpoints = mobilenet_v2.mobilenet(images,depth_multiplier=1.4)
    
    with tf.variable_scope("finetune_layers"):
        mobilenet_tensor = tf.get_default_graph().get_tensor_by_name("MobilenetV2/expanded_conv_14/output:0")       # 获取目标张量,取出mobilenet中指定层的张量
    
        # 将张量作为新的 Input 向新层传递
        x = tf.layers.Conv2D(filters=256,kernel_size=3,name="conv2d_1")(mobilenet_tensor)
        x = tf.nn.relu(x,name="relu_1")
        x = tf.layers.Conv2D(filters=256,kernel_size=3,name="conv2d_2")(x)
        x = tf.layers.Conv2D(10,3,name="conv2d_3")(x)
        predictions = tf.reshape(x, (-1,10))

    分别获取 train_var 和 fixed_var,在 minimize 中指定 var_list

    demo 如下:只是伪代码哦

    #### 不重要的函数
    def get_var_list(target_tensor=None):
        '''获取指定变量列表 var_list 的函数;
           具体怎么干的,无需关心,只需要知道它的作用是获取一批权重
        '''
        if target_tensor==None:
            target_tensor = r"MobilenetV2/expanded_conv_14/output:0"
        target = target_tensor.split("/")[1]
        all_list = []
        all_var = []
    
        for var in tf.global_variables():
            if var != []:
                all_list.append(var.name)
                all_var.append(var)
        try:
            all_list = list(map(lambda x:x.split("/")[1],all_list))
            # 查找对应变量作用域的索引
            ind = all_list[::-1].index(target)
            ind = len(all_list) -  ind - 1
            print(ind)
            del all_list
            return all_var[:ind+1]
        except:
            print("target_tensor is not exist!")
            
    #### 下面这一堆仔细看
    x_train = np.random.random(size=(141,224,224,3))
    y_train = to_categorical(label_fake,10)
    
    y_label = tf.placeholder(tf.int32, (None,10))
    
    ### 收集变量作用域 finetune_layers 内的可训练变量,作为 train_var
    train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="finetune_layers")
    
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_label,logits=logits)
    ### 定义优化方法,用 var_list 指定需要更新的权重,此时仅更新 train_var 权重
    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss,var_list=train_var)
    
    epochs = 10
    batch_size = 16
    
    # 目标张量名称,要获取变量列表 fixed_var
    target_tensor = "MobilenetV2/expanded_conv_14/output:0"
    fixed_var = get_var_list(target_tensor)
    saver = tf.train.Saver(var_list=fixed_var)
    
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        writer = tf.summary.FileWriter(r"./logs", sess.graph)
        ## 初始化 train_var, 使用初始化指定函数
        sess.run(tf.variables_initializer(var_list=train_var))
        saver.restore(sess,tf.train.latest_checkpoint("./model_ckpt/mobilenet_v2"))
    
        for i in range(2000):
            start = (i*batch_size) % x_train.shape[0]
            end = min(start+batch_size, x_train.shape[0])
            _, merge, losses = sess.run([train_step,merge_all,loss], feed_dict={images:x_train[start:end], y_label:y_train[start:end]})
            if i%100==0: writer.add_summary(merge, i)

    重点就下面 3 句

    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss,var_list=train_var)
    sess.run(tf.variables_initializer(var_list=train_var))    
    saver.restore(sess,tf.train.latest_checkpoint("./model_ckpt/mobilenet_v2"))

    深入理解一下:

    1. 如果初始化了 train_var 后,又加载了所有预训练变量,也就是说 train_var 的初始值也是 预训练的,而不是平常的 全0、全1、高斯分布等,这是可以的;

    2. 如果先 加载所有预训练变量,然后初始化 train_var,也是可以的,因为 fixed_var 没有重新初始化,还是 预训练的值,而 train_var 初始值是多少没那么重要;

    3. 如果初始化 train_var,加载 fixed_var,则谁先谁后无所谓;

    4. 如果是 先用 tf.global_variables_initializer() 初始化全部参数,再加载全部预训练参数,也是可以的;

    5. 如果先加载全部预训练参数,在用 tf.global_variables_initializer() 初始化全部参数,是不可以的,因为 fixed_var 也被初始化了;

    上述代码采用的是第 3 种,指定了只加载 fixed_var

    saver = tf.train.Saver(var_list=fixed_var)

    参考资料:

    https://zhuanlan.zhihu.com/p/42183653

  • 相关阅读:
    python接口自动化(十三)--cookie绕过验证码登录(详解)
    python接口自动化(十二)--https请求(SSL)(详解)
    python接口自动化(十一)--发送post【data】(详解)
    python接口自动化(十)--post请求四种传送正文方式(详解)
    python接口自动化(九)--python中字典和json的区别(详解)
    python接口自动化(八)--发送post请求的接口(详解)
    python接口自动化(七)--状态码详解对照表(详解)
    python接口自动化(六)--发送get请求接口(详解)
    python接口自动化(五)--接口测试用例和接口测试报告模板(详解)
    Redis的简介
  • 原文地址:https://www.cnblogs.com/yanshw/p/12432595.html
Copyright © 2011-2022 走看看