zoukankan      html  css  js  c++  java
  • 【算法】Bert预训练源码阅读

    Bert预训练源码

    主要代码

    地址:https://github.com/google-research/bert

    1. create_pretraning_data.py:原始文件转换为训练数据格式
    2. tokenization.py:汉字,单词切分,复合词处理,create_pretraning_data中调用
    3. modeling.py: 模型结构
    4. run_pretraing.py: 运行预训练

    tokenization.py

    作用:句子切分,特殊符号处理。
    主要类:BasicTokenizer, WordpieceTokenizer, FullTokenizer

    1. BasicTokenizer.tokenize: 文本转为unicode, 去除特殊符号,汉字前后加空格,按空格切分单词,去掉文本重音,按标点符号切割单词。最后生成一个list
    2. WordpieceTokenizer.tokenize: 长度过长的单词标记为UNK,复合词切分,找不到的词标记为UNK
    3. FullTokenizer:先后调用BasicTokenizer和WordpieceTokenizer

    create_pretraning_data.py

    输入:词典, 原始文本(空行分割不同文章,一行一句)
    输出:训练数据
    作用:生成训练数据,句子对组合,单词mask等
    入口函数main

    1. 加载词典,加载原始文本

    2. create_training_instances
      读取原始文本文件,做unicode转换,中文,标点,特殊符号处理,空格切分,复合词切分。转换为[[[first doc first sentence],[first doc second sentence],[first doc third sentence]],[[second doc first sentence],[]],....] 这样的结构
      去除空文章,文章顺序打乱
      输入的原始文本会重复使用dupe_factor次

    3. 对每一篇文章生成训练数据create_instances_from_document
      训练语句长度限制max_seq_length,0.1的概率生成长度较小的训练语句,增加鲁棒性
      句子对(A,B)随机组合
      对于一篇文章,按顺序获取n行句子,其长度总和限制为target_seq_length,
      随机选取n行中的前m行作为A
      0.5的概率,B是n行中后面剩余的部分;其他情况,B是随机选取的其他文章内容,开始位置是随机的
      文章中没有使用的部分继续组合(A, B)
      添加CLS,SEP分隔符,生成句子向量
      对句子对中的单词做随机mask (create_masked_lm_predictions), 随机取num_to_predict个单词做mask,0.8的概率标记为MASK,0.1的概率标记为原始单词,0.1的概率标记为随机单词
      封装,句子对,句子id,是否为随机下一句,mask的下标位置,mask对应的原始单词

    4. 训练数据序列化,存入文件。单词转为id,句子长度不足的后面补0。

    modeling.py

    BertConfig: 配置
    BertModel: 模型主体

    建模主体过程:

    1. 获取词向量 [batch_size, seq_length, embedding_size]
    2. 添加句向量,添加位置向量,在最后一个维度上做归一化,整体做dropout
    3. transformer
      全连接映射 [B*F, embedding_size]->[B*F, N*H]
      (dropout(softmax(QK^T))V), 其中mask了原本没有数据的部分
      全连接,dropout,残差处理,归一化,全连接,dropout,残差处理,归一化
      上述循环多层
      取最终[CLS]对应的向量做句向量

    run_pretraining.py

    作用:生成目标函数,加载已有参数,迭代训练
    主要函数:model_fn_builder

    1. 评估mask单词的预测准确性,整体loss为mask处预测对的分数的平均值
    2. 评估next_sentence预测准确性,loss为预测对的概率值
    3. 总损失为上面两个损失相加
  • 相关阅读:
    从《兄弟连》到团队管理
    将来
    [译] TypeScript入门指南(JavaScript的超集)
    基于cocos2dx迷宫游戏
    SVN版本管理教程
    arcgis for android 本地缓存
    vs2010变的特别卡解决办法
    cocos2d-x自适应屏幕
    cocos2d-x使用CCScale9Sprite
    cocos2dx开发入门文档
  • 原文地址:https://www.cnblogs.com/dplearning/p/10397935.html
Copyright © 2011-2022 走看看