  • How to preload embedding vectors in PaddlePaddle

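    When training on only a small amount of data, the model overfits easily, so the full dataset needs to be processed as well. I use word vectors trained with word2vec. How, then, do you load those pretrained word vectors?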
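
    The script reads word2vec's plain-text output. The first line is a "vocab_size vector_size" header, and each following line holds a word and then vector_size space-separated numbers. The stdlib-only sketch below isolates that parsing step so it can be checked without PaddlePaddle. The name `parse_word2vec_text` is hypothetical, and raising `ValueError` on a bad row is an assumption; the script uses `assert` instead.

```python
import io
from typing import Iterable, List, Tuple


def parse_word2vec_text(lines: Iterable[str]) -> Tuple[List[str], List[List[float]]]:
    """Hypothetical helper: parse word2vec text format.

    Line 1 is 'vocab_size vector_size'; each next line is 'word v1 ... vN'.
    """
    it = iter(lines)
    vocab_size, vector_size = map(int, next(it).split())
    words: List[str] = []
    vectors: List[List[float]] = []
    for _ in range(vocab_size):
        # Drop the newline and any trailing space (C word2vec writes one), then split
        fields = next(it).rstrip("\n").strip(" ").split(" ")
        word, values = fields[0], fields[1:]
        if len(values) != vector_size:
            raise ValueError(f"expected {vector_size} values for {word!r}, got {len(values)}")
        words.append(word)
        # Convert to floats here; keeping the raw strings would break later math
        vectors.append([float(v) for v in values])
    return words, vectors


sample = io.StringIO("2 3\nhello 0.1 0.2 0.3\nworld 1 2 3 \n")
words, vectors = parse_word2vec_text(sample)
print(words)    # ['hello', 'world']
print(vectors)  # [[0.1, 0.2, 0.3], [1.0, 2.0, 3.0]]
```

    The row order of the file becomes the word ID order. The ID fed to the embedding layer for a word must therefore equal that word's line index in the file.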

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    """
    -------------------------------------------------
       Version     :       None
       File Name   :       paddle_load_w2v
       Description :       None
       Author      :       gongxijun
       Email       :      
       date        :       2019-12-04
    -------------------------------------------------
       Change Activity:
                       2019-12-04:
    -------------------------------------------------
    """
    from __future__ import absolute_import
    from __future__ import print_function
    from __future__ import unicode_literals
    
    __author__ = 'gongxijun'
    import paddle
    import paddle.fluid as fluid
    import paddle.fluid.layers as layers
    import paddle.fluid.nets as nets
    import numpy as np
    import math
    import codecs
    
    
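    def load_parameter(file_name):
        """Read a word2vec text-format file.

        The first line holds "vocab_size vector_size"; each following line is
        a word followed by vector_size space-separated values.
        """
        embeddings = []
        words = []
        with codecs.open(file_name, 'r', encoding="utf8") as f:
            header = f.readline()
            vocab_size, vector_size = map(int, header.split())
            for _ in range(vocab_size):
                # Strip the newline and surrounding spaces, then split into word + values
                word_list = f.readline().rstrip("\n").strip(" ").split(' ')
                word = word_list[0]
                vector = word_list[1:]
                words.append(word if len(word) > 0 else "unk")
                assert len(vector) == vector_size, "{} {}".format(len(vector), vector_size)
                # Parse as float32; np.array(vector) alone would produce an array of strings
                embeddings.append(np.array(vector, dtype='float32'))
        assert len(words) == len(embeddings)
        return words, embeddings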
    
    
    # Vocabulary size and vector dimension; must match the header line of the word2vec file
    word_dict_len = 74378
    word_dim = 128
    
    
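    def get_embedding(name, shape, is_sparse=True, dtype='int64'):
        """
        :param name: name of the input data layer; the embedding parameter is named "embedding_<name>"
        :param shape: (vocab_size, embedding_dim); must have exactly 2 elements
        :param is_sparse: bool, whether the embedding table uses sparse gradient updates
        :param dtype: dtype of the input word IDs
        :return: output of an fc layer applied to the embedding
        """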
        alias_id = layers.data(name=name, shape=[1], dtype=dtype)
        assert len(shape) == 2, '{} must equal 2'.format(len(shape))
        alias_emb = layers.embedding(input=alias_id, size=shape,
                                     param_attr=fluid.param_attr.ParamAttr(name="embedding_{}".format(name)),
                                     is_sparse=is_sparse)
        alias_fc = layers.fc(input=alias_emb, size=shape[1])
        return alias_fc
    
    
    words_emb = get_embedding("words", shape=(word_dict_len, word_dim))
    
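    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Run the startup program first so the embedding parameter exists in the global scope
    exe.run(fluid.default_startup_program())
    # Look up the parameter by the name given in ParamAttr ("embedding_" + "words")
    embedding_param = fluid.global_scope().find_var(
        "embedding_words").get_tensor()
    words, embeddings = load_parameter("/Users/gongxijun/data/item2vec.txt")
    # Stack the rows into a single float32 array shaped like the embedding table
    embeddings = np.array(embeddings, dtype='float32')
    assert embeddings.shape == (word_dict_len, word_dim), embeddings.shape
    # Overwrite the randomly initialized table with the pretrained vectors
    embedding_param.set(embeddings, place)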
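
    The script assumes two things about item2vec.txt. First, it has exactly word_dict_len rows. Second, its row order matches the word IDs produced by the data pipeline. If the training data uses its own dictionary instead, the pretrained rows need to be realigned so that row i holds the vector for ID i. The stdlib sketch below does this. `build_embedding_matrix`, `word_to_id` and `pretrained` are hypothetical names. Filling uncovered words with uniform values in [-0.1, 0.1] is an arbitrary assumption, not something PaddlePaddle prescribes.

```python
import random
from typing import Dict, List


def build_embedding_matrix(
    word_to_id: Dict[str, int],
    pretrained: Dict[str, List[float]],
    dim: int,
    seed: int = 0,
    scale: float = 0.1,
) -> List[List[float]]:
    """Hypothetical helper: return a len(word_to_id) x dim matrix whose row i
    is the pretrained vector of the word with ID i."""
    if sorted(word_to_id.values()) != list(range(len(word_to_id))):
        raise ValueError("word IDs must be exactly 0 .. len(word_to_id) - 1")
    rng = random.Random(seed)  # fixed seed so the fill is reproducible
    matrix: List[List[float]] = [[0.0] * dim for _ in range(len(word_to_id))]
    for word, idx in word_to_id.items():
        vec = pretrained.get(word)
        if vec is None:
            # Word not covered by word2vec: small uniform random init (assumed choice)
            matrix[idx] = [rng.uniform(-scale, scale) for _ in range(dim)]
        else:
            if len(vec) != dim:
                raise ValueError(f"{word!r} has {len(vec)} values, expected {dim}")
            matrix[idx] = list(vec)
    return matrix


vocab = {"<unk>": 0, "hello": 1, "world": 2}
pretrained_vectors = {"hello": [0.1, 0.2], "world": [1.0, 2.0]}
matrix = build_embedding_matrix(vocab, pretrained_vectors, dim=2)
print(matrix[1], matrix[2])  # [0.1, 0.2] [1.0, 2.0]
```

    In the script above, `pretrained` could come from `dict(zip(words, embeddings))`. The resulting matrix would then be converted with `np.array(matrix, dtype='float32')` before it is passed to `embedding_param.set(...)`. The embedding layer was declared with shape (word_dict_len, word_dim), so `len(word_to_id)` should equal word_dict_len.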
  • Original article: https://www.cnblogs.com/gongxijun/p/11988475.html