zoukankan      html  css  js  c++  java
  • Bert层数剪枝

    模型精简的流程如下:pretrian model -> retrain with new data(fine tuning) -> pruning -> retrain -> model

    对bert进行层数剪枝,保留第一层和第十二层参数,再用领域数据微调。代码如下:

    """
        test
    """
    import tensorflow as tf
    import os
    
    sess = tf.Session()
    last_name = 'bert_model.ckpt'
    model_path = 'bert_model/chinese_L-12_H-768_A-12'
    imported_meta = tf.train.import_meta_graph(os.path.join(model_path, last_name + '.meta'))
    imported_meta.restore(sess, os.path.join(model_path, last_name))
    init_op = tf.local_variables_initializer()
    sess.run(init_op)
    
    bert_dict = {}
    # 获取待保存的层数节点
    for var in tf.global_variables():
        # print(var)
        # 提取第0层和第11层和其它的参数,其余1-10层去掉,存储变量名的数值
        if var.name.startswith('bert/encoder/layer_') and not var.name.startswith(
                'bert/encoder/layer_0') and not var.name.startswith('bert/encoder/layer_11'):
            pass
        else:
            bert_dict[var.name] = sess.run(var).tolist()
    
    # print('bert_dict:{}'.format(bert_dict))
    # 真是保存的变量信息
    need_vars = []
    for var in tf.global_variables():
        if var.name.startswith('bert/encoder/layer_') and not var.name.startswith(
                'bert/encoder/layer_0/') and not var.name.startswith('bert/encoder/layer_1/'):
            pass
        elif var.name.startswith('bert/encoder/layer_1/'):
            # 寻找11层的var name,将11层的参数给第一层使用
            new_name = var.name.replace("bert/encoder/layer_1", "bert/encoder/layer_11")
            op = tf.assign(var, bert_dict[new_name])
            sess.run(op)
            need_vars.append(var)
            print(var)
        else:
            need_vars.append(var)
            print('####',var)
    
    # 保存model
    saver = tf.train.Saver(need_vars)
    saver.save(sess, os.path.join('bert_model/chinese_L-12_H-768_A-12_pruning', 'bert_pruning_2_layer.ckpt'))

     要修改对应的配置文件参数:

    效果总结

    在bert_base版本二分类模型的F1值达到97%,经过该方法裁剪后F1达到93.99%,损失在3个点左右,符合预期,还是可以投入工程使用的

  • 相关阅读:
    使用node.js如何爬取网站数据
    关于@font-face的使用
    webpack通过postcss-loader添加浏览器前缀
    点击弹出 +1放大效果 -- jQuery插件
    网站CSS选择器性能讨论
    修改 上传图片按钮input-file样式。。
    insertAdjacentHTML方法示例
    css背景色半透明的最佳实践
    js实现选中文字 分享功能
    js实现滑动的弹性导航
  • 原文地址:https://www.cnblogs.com/demo-deng/p/13372797.html
Copyright © 2011-2022 走看看