  • TensorFlow: restoring only part of a pretrained model's parameters

       

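    # All variables in the current graph
    all_vars = tf.global_variables()

    # Keep only the variables to load from the pretrained checkpoint; skip the
    # layers matched by these name fragments (presumably parts of the new model
    # that the pretrained checkpoint does not contain)
    net_var = [var for var in all_vars
               if 'bi-lstm_secondLayer' not in var.name
               and 'word_embedding1s' not in var.name
               and 'proj_secondLayer' not in var.name]

    # A Saver built from a variable list saves/restores only those variables.
    # Run tf.global_variables_initializer() before restore(): the skipped
    # variables keep their initial values, and initializing afterwards would
    # overwrite the restored weights.
    saver_pre = tf.train.Saver(net_var)
    saver_pre.restore(self.sess, tf.train.latest_checkpoint(self.config.dir_model_storepath_pre))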
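The selection step above is plain name filtering. Below is a minimal, framework-free sketch of it, assuming only that each variable object exposes a `.name` attribute (as TensorFlow variables do). `select_restorable`, the `FakeVar` stand-in, and the name `proj_secondLayer/W:0` are hypothetical, for illustration only.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.Variable: only the .name attribute is used here.
FakeVar = namedtuple("FakeVar", ["name"])


def select_restorable(variables, excluded_substrings):
    """Return the variables whose name contains none of the excluded substrings."""
    return [v for v in variables
            if not any(s in v.name for s in excluded_substrings)]


# Same name fragments as in the code above
EXCLUDED = ("bi-lstm_secondLayer", "word_embedding1s", "proj_secondLayer")

graph_vars = [
    FakeVar("words/_word_embeddings:0"),
    FakeVar("bi-lstm/bidirectional_rnn/fw/lstm_cell/kernel:0"),
    FakeVar("bi-lstm_secondLayer/bidirectional_rnn/fw/lstm_cell/kernel:0"),
    FakeVar("proj_secondLayer/W:0"),  # hypothetical variable name
]

restorable = select_restorable(graph_vars, EXCLUDED)
print([v.name for v in restorable])
# → ['words/_word_embeddings:0', 'bi-lstm/bidirectional_rnn/fw/lstm_cell/kernel:0']
```

In the TensorFlow code, `variables` would be `tf.global_variables()` and the result would be passed to `tf.train.Saver(...)`. Using a tuple of fragments keeps the exclusion list in one place instead of chaining `and` conditions.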

       

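    # Alternative (disabled by the triple-quoted string): pass tf.train.Saver a
    # dict {name in checkpoint: variable in current graph}. Keys are checkpoint
    # names without the ':0' suffix; this also allows loading a variable whose
    # name differs between the two models.
    '''
    with tf.variable_scope('bi-lstm', reuse=True):
        # Fetch the existing BiLSTM weights of the current graph by name
        fwk = tf.get_variable('bidirectional_rnn/fw/lstm_cell/kernel')
        fwb = tf.get_variable('bidirectional_rnn/fw/lstm_cell/bias')
        bwk = tf.get_variable('bidirectional_rnn/bw/lstm_cell/kernel')
        bwb = tf.get_variable('bidirectional_rnn/bw/lstm_cell/bias')

    saver_pre = tf.train.Saver({'words/_word_embeddings': self._word_embeddings,
                                'bi-lstm/bidirectional_rnn/fw/lstm_cell/kernel': fwk,
                                'bi-lstm/bidirectional_rnn/fw/lstm_cell/bias': fwb,
                                'bi-lstm/bidirectional_rnn/bw/lstm_cell/kernel': bwk,
                                'bi-lstm/bidirectional_rnn/bw/lstm_cell/bias': bwb})

    # Print the variable names of the current graph (useful for writing the keys above)
    for x in tf.trainable_variables():
        print(x.name)

    #mysaver = tf.train.import_meta_graph(self.config.dir_model_storepath_pre_graph)

    saver_pre.restore(self.sess, tf.train.latest_checkpoint(self.config.dir_model_storepath_pre))
    '''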
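Instead of listing every variable by hand, the `{checkpoint name: variable}` dict can be built by rewriting a scope prefix, for example if a layer's scope was renamed between the pretrained and the new model. Below is a minimal sketch, assuming variables expose `.name` with TensorFlow's `:0` output suffix. `build_restore_map`, `FakeVar`, and the scope name `encoder` are hypothetical.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.Variable: only the .name attribute is used here.
FakeVar = namedtuple("FakeVar", ["name"])


def build_restore_map(variables, graph_scope, ckpt_scope):
    """Map checkpoint names to variables for every variable under graph_scope.

    Variable names carry an output suffix such as ':0'; the keys of the dict
    given to tf.train.Saver are checkpoint names, which have no such suffix.
    """
    prefix = graph_scope + "/"
    mapping = {}
    for v in variables:
        if not v.name.startswith(prefix):
            continue  # variable is outside the scope being loaded
        op_name = v.name.rsplit(":", 1)[0]  # drop the ':0' suffix
        ckpt_name = ckpt_scope + "/" + op_name[len(prefix):]
        mapping[ckpt_name] = v
    return mapping


graph_vars = [
    FakeVar("encoder/bidirectional_rnn/fw/lstm_cell/kernel:0"),
    FakeVar("encoder/bidirectional_rnn/fw/lstm_cell/bias:0"),
    FakeVar("proj/W:0"),  # hypothetical variable outside the scope
]

restore_map = build_restore_map(graph_vars, "encoder", "bi-lstm")
print(sorted(restore_map))
# → ['bi-lstm/bidirectional_rnn/fw/lstm_cell/bias', 'bi-lstm/bidirectional_rnn/fw/lstm_cell/kernel']
```

With real variables, `graph_vars` would be `tf.global_variables()` and `restore_map` would go to `tf.train.Saver(restore_map)` before calling `restore`.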

  • Original post: https://www.cnblogs.com/wuxiangli/p/10330907.html