zoukankan      html  css  js  c++  java
  • tensorflow finuetuning 例子

    最近研究了下如何使用tensorflow进行finetuning,相比于caffe,tensorflow的finetuning麻烦一些,记录如下:

    1.原理

    finetuning原理很简单,利用一个在数据A集上已训练好的模型作为初始值,改变其部分结构,在另一数据集B上(采用小学习率)训练的过程叫做finetuning。

    一般来讲,符合如下情况会采用finetuning

    • 数据集A和B有相关性
    • 数据集A较大
    • 数据集B较小

    2.关键代码

    在数据集A上训练的时候,和普通的tensorflow训练过程完全一致。但是在数据集B上进行finetuning时,需要先从之前训练好的checkpoint中恢复模型参数,这个地方比较关键,

    需要注意只恢复需要恢复的参数,其他参数不要恢复,否则会因为找不到的声明而报错。以mnist为例子,如果我想先训练一个0-7的8类分类器,网络结构如下:

    conv1-conv2-fc8(其他不带权重的pooling、softmaxloss层忽略)

    然后我想用这个训练出的模型参数,在0-9的10类分类器上做finetuning,网络结构如下:

    conv1-conv2-fc10

    那么在从checkpoint中恢复模型参数时,我只能恢复conv1-conv2,如果连fc8都恢复了,就会因为找不到fc8的定义而报错

    以上描述对应的代码如下:

    1     if tf.train.latest_checkpoint('ckpts') is not None:
    2         trainable_vars = tf.trainable_variables()
    3         res_vars = [t for t in trainable_vars if t.name.startswith('conv')]
    4         saver = tf.train.Saver(var_list=res_vars)
    5         saver.restore(sess, tf.train.latest_checkpoint('ckpts'))
    6     else:
    7         saver = tf.train.Saver()

    3.demo

    利用mnist写了一个简单的finetuning例子,大家可以试试,事实证明,利用一个相关的已有模型做finuetuning比从0开始训练收敛的更快并且收敛到的准确率更高,

    点我下载

  • 相关阅读:
    leetcode刷题笔记 217题 存在重复元素
    leetcode刷题笔记 二百零六题 反转链表
    leetcode刷题笔记 二百零五题 同构字符串
    20201119日报
    np.percentile 和df.quantile 分位数
    建模技巧
    np.where() 条件索引和SQL的if用法一样,或者是给出满足条件的坐标集合
    np.triu_indices_from() 返回方阵的上三角矩阵的索引
    ax.set_title() 和 plt.title(),以及df,plot(title='')
    信用卡模型(三)
  • 原文地址:https://www.cnblogs.com/hrlnw/p/9299810.html
Copyright © 2011-2022 走看看