  • pytorch --- word2vec implementation -- 《Efficient Estimation of Word Representations in Vector Space》

    The paper is Mikolov et al.'s 《Efficient Estimation of Word Representations in Vector Space》.

    Paper link: https://arxiv.org/abs/1301.3781

    The paper introduces two models (CBOW and skip-gram); the theory is not explained here...
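
    For reference (the paper covers the details), the code below implements the skip-gram model: the vector of each center word is trained to predict its surrounding context words by maximizing the average log-probability

    \[
    \frac{1}{T}\sum_{t=1}^{T}\sum_{-c \le j \le c,\ j \ne 0} \log p(w_{t+j} \mid w_t),
    \qquad
    p(w_O \mid w_I) = \frac{\exp\!\left({v'_{w_O}}^{\top} v_{w_I}\right)}{\sum_{w=1}^{V} \exp\!\left({v'_{w}}^{\top} v_{w_I}\right)},
    \]

    where (as I read the code) the \(v\) vectors correspond to rows of W and the \(v'\) vectors to columns of WT; the full softmax over the vocabulary is affordable here only because the toy vocabulary is tiny.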

    A skim of the code, with comments, from https://github.com/graykode/nlp-tutorial:

    # -*- coding: utf-8 -*-
    # @time : 2019/11/9  12:53
    
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import matplotlib.pyplot as plt
    
    dtype = torch.FloatTensor
    
    # Toy corpus of short sentences
    sentences = [ "i like dog", "i like cat", "i like animal",
                  "dog cat animal", "apple cat dog like", "dog fish milk like",
                  "dog cat eyes like", "i like apple", "apple i hate",
                  "apple i movie book music like", "cat dog hate", "cat dog like"]
    
    word_sequence = " ".join(sentences).split()           # all tokens, in order
    word_list = list(set(word_sequence))                  # unique vocabulary (order is arbitrary)
    word_dict = {w: i for i, w in enumerate(word_list)}   # word -> index
    
    # Word2Vec Parameter
    batch_size = 20     # mini-batch size sampled from the skip-gram pairs
    embedding_size = 2  # 2 dimensions so the embeddings can be plotted directly
    voc_size = len(word_list)
    
    # Sample `size` (input, label) pairs: each input is a one-hot vector of the target word, each label is the index of one context word
    def random_batch(data, size):
        random_inputs = []
        random_labels = []
        random_index = np.random.choice(range(len(data)), size, replace=False)
    
        for i in random_index:
            random_inputs.append(np.eye(voc_size)[data[i][0]])  # target
            random_labels.append(data[i][1])  # context word
    
        return random_inputs, random_labels
    
    # Make skip-gram pairs with a window size of 1
    skip_grams = []
    # Starting from the 2nd token (index=1), pair each word with its left and right neighbours, i.e. add [index=1, index=0] and [index=1, index=2] to skip_grams
    for i in range(1, len(word_sequence) - 1):
        target = word_dict[word_sequence[i]]
        context = [word_dict[word_sequence[i - 1]], word_dict[word_sequence[i + 1]]]
    
        for w in context:
            skip_grams.append([target, w])
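
    # (Added sketch, not in the original repo) quick sanity check of the data:
    # each skip-gram pair is [target_index, context_index]; for the sequence
    # "i like dog i like cat ...", the first two pairs are (like, i) and (like, dog).
    # random_batch then turns the targets into one-hot rows and keeps the labels as indices.
    sample_inputs, sample_labels = random_batch(skip_grams, 3)
    print(np.array(sample_inputs).shape)   # (3, voc_size) one-hot input vectors
    print(sample_labels)                   # 3 context-word indices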
    
    # Model
    class Word2Vec(nn.Module):
        def __init__(self):
            super(Word2Vec, self).__init__()
    
            # W and WT are two independent weight matrices, not transposes of each other
            self.W = nn.Parameter(-2 * torch.rand(voc_size, embedding_size).type(dtype) + 1)   # [voc_size, embedding_size] weight
            self.WT = nn.Parameter(-2 * torch.rand(embedding_size, voc_size).type(dtype) + 1)  # [embedding_size, voc_size] weight
    
        def forward(self, X):
            # X : [batch_size, voc_size]
            hidden_layer = torch.matmul(X, self.W) # hidden_layer : [batch_size, embedding_size]
            output_layer = torch.matmul(hidden_layer, self.WT) # output_layer : [batch_size, voc_size]
            return output_layer
    
    model = Word2Vec()
    
    criterion = nn.CrossEntropyLoss()  # applies log-softmax internally, so the model outputs raw logits
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Training
    for epoch in range(5000):
    
        input_batch, target_batch = random_batch(skip_grams, batch_size)
    
        input_batch = torch.Tensor(np.array(input_batch))
        target_batch = torch.LongTensor(target_batch)
    
        optimizer.zero_grad()
        output = model(input_batch)
    
        # output : [batch_size, voc_size], target_batch : [batch_size] (LongTensor, not one-hot)
        loss = criterion(output, target_batch)
        if (epoch + 1)%1000 == 0:
            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))
    
        loss.backward()
        optimizer.step()
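
    # (Added sketch, not in the original code) a quick qualitative check after training:
    # treat each row of W as a word vector (see the explanation below) and list the words
    # closest to "dog" by cosine similarity; "dog" is just an example word from this corpus.
    with torch.no_grad():
        W, WT = model.parameters()
        emb = W / W.norm(dim=1, keepdim=True)    # L2-normalize each embedding row
        sims = emb @ emb[word_dict["dog"]]       # cosine similarity of every word to "dog"
        top = sims.topk(4).indices.tolist()      # the top hit is "dog" itself
        print("closest to 'dog':", [word_list[i] for i in top])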
    
    # Because the input is [batch_size, voc_size] (each word is a one-hot vector of length voc_size)
    # and W is [voc_size, embedding_size], multiplying a one-hot row by W just selects one row of W:
    # [1,0,0] * [w1,w4
    #            w2,w5
    #            w3,w6]  ->  [w1,w4]
    # so the embedding vector of word i is row i of W, i.e. (W[i][0], W[i][1])
    W, WT = model.parameters()
    for i, label in enumerate(word_list):
        x, y = float(W[i][0]), float(W[i][1])
        plt.scatter(x, y)
        plt.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points', ha='right', va='bottom')
    plt.show()
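
    For comparison (my own sketch, not part of the original repo), the same two-matrix model can be written more idiomatically with nn.Embedding and nn.Linear, which lets the batch be built from raw word indices instead of one-hot vectors:

    # Equivalent skip-gram model: nn.Embedding plays the role of W,
    # nn.Linear (bias-free, weight stored transposed) plays the role of WT.
    class Word2VecEmbedding(nn.Module):
        def __init__(self, voc_size, embedding_size):
            super().__init__()
            self.embed = nn.Embedding(voc_size, embedding_size)
            self.out = nn.Linear(embedding_size, voc_size, bias=False)

        def forward(self, target_idx):
            # target_idx : [batch_size] LongTensor of center-word indices
            hidden = self.embed(target_idx)   # [batch_size, embedding_size]
            return self.out(hidden)           # [batch_size, voc_size] logits for CrossEntropyLoss

    With this version random_batch would simply return the target indices (no np.eye needed), the training loop stays the same, and the learned vectors live in model.embed.weight.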
  • Original post: https://www.cnblogs.com/dhName/p/11825509.html