zoukankan      html  css  js  c++  java
  • 图网络学习笔记——ERNIE实践

    前言

    本次图网络学习笔记主要是百度PGL团队的七日优质免费课程,主要参考自这里:图神经网络7日打卡营.

    本次课程包含了很多内容:图游走来算法,图神经算法,尤其是NLP的预训练模型为基础的应用的ERNIESAGE模型取得了很好的效果,这里作为基础先记录一下ERNIE的实践过程。
    学习原理可参考:PPT.

    主要实践内容

    实践内容参考自持续学习语义理解框架ERNIE.
    首先要安装pip install paddle-ernie
    下面是部分代码与补充解释。

    1. 模型初始化
    # 导包
    import numpy as np
    import pandas as pd
    from sklearn.metrics import f1_score, accuracy_score
    from sklearn.preprocessing import LabelEncoder
    from sklearn.model_selection import train_test_split
    
    import paddle as P
    import paddle.fluid as F
    import paddle.fluid.layers as L
    import paddle.fluid.dygraph as D
    
    from ernie.tokenizing_ernie import ErnieTokenizer
    from ernie.modeling_ernie import ErnieModelForSequenceClassification
    
    
    # 数据格式为:词汇 + 类别 的DF格式 列
    BATCH = 32 # 根据显存确定
    MAX_SEQLEN = 64
    EPOCH = 10
    lr = 5e-5 # learning_rate
    n_class = 2 # 类别 假设为2分类
    
    D.guard().__enter__() # 启动动态图模式,这行是必须的!!!
    
    # 使用下载到本地的模型,默认是从网路下载,可能由于网络波动不一定每次都可以,下面是1.0版本
    ernie_path = r'pre-training_modelERNIEmodel-ernie1.0.1' 
    tokenizer = ErnieTokenizer.from_pretrained(pretrain_dir_or_url=ernie_path)
    ernie = ErnieModelForSequenceClassification.from_pretrained(pretrain_dir_or_url=ernie_path, num_labels=n_class)
    optimizer = F.optimizer.Adam(lr, parameter_list=ernie.parameters())
    
    1. 接下来就是数据处理:
    
    # 这里是makedata 和 get_batch_data 根据自己的需要设置自己的格式,参考源文档即可
    # 需要注意的是浮点型设置为float32,整型为int64,不然在数据类型上会报错
    

    构造训练验证集,这里与源文档不一样

    # 处理数据
    X = trains['file']
    y = trains['label']
    X_train,X_val,y_train,y_val=train_test_split(X.values,y.values,test_size=1/10,random_state=0)
    X_y_train = list(zip(X_train,y_train))
    X_y_val = list(zip(X_val, y_val))
    train_data = make_data(X_y_train)
    val_data = make_data(X_y_val)
    
    

    接下来就是自己训练,如果训练完成后,需要保存模型,而不是每次预测都训练一遍,这里经过熬夜阅读源码发现,

    if state_dict_path.with_suffix('.pdparams').exists():
                m, _ = D.load_dygraph(state_dict_path.as_posix())
                for k, v in model.state_dict().items():
                    if k not in m:
                        log.warn('param:%s not set in pretrained model, skip' % k)
                        m[k] = v # FIXME: no need to do this in the future
                model.set_dict(m)
    

    可以使用:
    首先保存为pdparams文件:F.save_dygraph(ernie.state_dict(),state_dir);
    然后下次加载的时候,有两种方式:

    • 用自己的权重文件替换掉本地的路径下的权重文件;
    • 先预加载与第一部分1)中一样,加载模型,然后使用自己的模型文件再次运行代码 ernie.set_dict(m)

    这样模型就可以根据自己的下游任务如图分类,文本分类,推荐系统的任务,来保存模型,供下次预测。其实为了搞懂这里花了好久时间才看到源码这样加载参数文件的(自己太菜了),然后使用自己的方法和权重文件。

    小结

    对图神经的学习,主要有游走类,端到端的图神经包括图采样等,其中有个感触很大的是ERNIE在预训练模型领域的杰出表现,甚至超过了大火的BERT,以及基于ERNIE实现的ERNIESAGE图模型的效果,因此对ERNIE实践一番,顺便学学源码。
    本文主要讲了三点:
    1)数据处理时的数据类型要注意;
    2) 结合下游任务的模型权重的保存;
    3) 权重保存后的使用方法。

    其他部分的讲解参考Paddle的官方链接应该就没有问题了。

  • 相关阅读:
    java.lang.ClassNotFoundException:org.springframework.web.context.ContextLoaderListener问题解决
    开发人员系统功能设计常用办公软件分享
    微信自定义菜单url默认80端口问题解决
    Servlet再度学习
    JSP九大内置对象
    linux下安装apache(httpd-2.4.3版本)各种坑
    Ajax原理学习
    Shell脚本了解
    生成Webservice的两种方式(Axis2,CXf2.x)
    Webservice发布
  • 原文地址:https://www.cnblogs.com/sxzhou/p/14057374.html
Copyright © 2011-2022 走看看