zoukankan      html  css  js  c++  java
  • TensorFlow 使用预训练好的模型的一部分参数

       

    # Restore only the subset of variables that exists in the pre-trained
    # checkpoint.  Any variable whose name contains one of the substrings
    # below is excluded from the Saver, so the restore below does not
    # touch it (these appear to be layers added on top of the pre-trained
    # model).
    # NOTE(review): excluded variables are never restored here, so they must
    # be initialized elsewhere (e.g. by running an initializer before this
    # restore) — confirm against the caller.
    # NOTE(review): tf.global_variables() also returns non-trainable
    # variables such as optimizer slots / step counters; every variable kept
    # in the list must exist in the checkpoint under the same name, or
    # restore() will fail — confirm.
    excluded_name_parts = (
        'bi-lstm_secondLayer',
        'word_embedding1s',
        'proj_secondLayer',
    )
    # Not named `vars`: that would shadow the Python builtin.
    all_vars = tf.global_variables()
    # Matching is by substring, anywhere in the full variable name.
    net_var = [
        var for var in all_vars
        if not any(part in var.name for part in excluded_name_parts)
    ]

    # A Saver built from an explicit var_list saves/restores only those
    # variables.
    saver_pre = tf.train.Saver(net_var)

    # latest_checkpoint() returns None when the directory holds no
    # checkpoint; fail with a message naming the directory instead of
    # letting restore() complain about a None path.
    ckpt_dir = self.config.dir_model_storepath_pre
    ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt_path is None:
        raise ValueError(
            'No pre-trained checkpoint found in {!r}'.format(ckpt_dir))
    saver_pre.restore(self.sess, ckpt_path)

    # Alternative approach (disabled): build the Saver from an explicit
    # {name-in-checkpoint: variable} dict instead of a list, fetching the
    # existing LSTM variables with reuse=True.
    #
    # with tf.variable_scope('bi-lstm', reuse=True):
    #     fwk = tf.get_variable('bidirectional_rnn/fw/lstm_cell/kernel')
    #     fwb = tf.get_variable('bidirectional_rnn/fw/lstm_cell/bias')
    #     bwk = tf.get_variable('bidirectional_rnn/bw/lstm_cell/kernel')
    #     bwb = tf.get_variable('bidirectional_rnn/bw/lstm_cell/bias')
    #
    # saver_pre = tf.train.Saver({
    #     'words/_word_embeddings': self._word_embeddings,
    #     'bi-lstm/bidirectional_rnn/fw/lstm_cell/kernel': fwk,
    #     'bi-lstm/bidirectional_rnn/fw/lstm_cell/bias': fwb,
    #     'bi-lstm/bidirectional_rnn/bw/lstm_cell/kernel': bwk,
    #     'bi-lstm/bidirectional_rnn/bw/lstm_cell/bias': bwb,
    # })
    # for x in tf.trainable_variables():
    #     print(x.name)
    #
    # # mysaver = tf.train.import_meta_graph(self.config.dir_model_storepath_pre_graph)
    #
    # saver_pre.restore(self.sess, tf.train.latest_checkpoint(self.config.dir_model_storepath_pre))

  • 相关阅读:
    (转)Unity对Lua的编辑器扩展
    unity timeline
    unity拖尾粒子问题
    unity shader 波动圈
    linux教程
    Unity Shader 基础
    ugui拖拽
    unity shader 热扭曲 (屏幕后处理)
    英文取名神器
    lua正则表达式替换字符串
  • 原文地址:https://www.cnblogs.com/wuxiangli/p/10330907.html
Copyright © 2011-2022 走看看