zoukankan      html  css  js  c++  java
  • Pytorch LSTM 词性判断

    首先,我们定义好一个LSTM网络,然后给出一个句子,每个句子都有很多个词构成,每个词可以用一个词向量表示,这样一句话就可以形成一个序列,我们将这个序列依次传入LSTM,然后就可以得到与序列等长的输出,每个输出都表示的是一种词性,比如名词,动词之类的,还是一种分类问题,每个单词都属于几种词性中的一种。

    我们可以思考一下为什么LSTM在这个问题里面起着重要的作用。如果我们完全孤立的对一个词做词性的判断这样我们需要特别高维的词向量,但是对于LSTM,它有着一个记忆的特性,这样我们就能够通过这个单词前面记忆的一些词语来对其做一个判断,比如前面如果是my,那么他紧跟的词有很大可能就是一个名词,这样就能够充分的利用上文来做这个问题。

    同时我们还可以通过引入字符来增强表达,什么意思呢?也就是说一个单词有一些前缀和后缀,比如-ly这种后缀很大可能是一个副词,这样我们就能够在字符水平得到一个词性判断的更好结果。

    具体怎么做呢?还是用LSTM。每个单词有不同的字母组成,比如 apple 由a p p l e构成,我们同样给这些字符词向量,这样形成了一个长度为5的序列,然后传入另外一个LSTM网络,只取最后输出的状态层作为它的一种字符表达,我们并不需要关心到底提取出来的字符表达是什么样的,在learning的过程中这些都是会被更新的参数,使得最终我们能够正确预测。

      1 import torch
      2 import torch.nn.functional as F
      3 from torch import nn, optim
      4 from torch.autograd import Variable
      5 
      6 training_data = [("The dog ate the apple".split(),
      7                   ["DET", "NN", "V", "DET", "NN"]),
      8                  ("Everybody read that book".split(), ["NN", "V", "DET",
      9                                                        "NN"])]
     10 # 每个单词就用一个数字表示,每种词性也用一个数字表示
     11 word_to_idx = {}
     12 tag_to_idx = {}
     13 for context, tag in training_data:
     14     for word in context:
     15         if word not in word_to_idx:
     16             # 对词进行编码
     17             word_to_idx[word] = len(word_to_idx)
     18     for label in tag:
     19         if label not in tag_to_idx:
     20             # 对词性编码
     21             tag_to_idx[label] = len(tag_to_idx)
     22 alphabet = 'abcdefghijklmnopqrstuvwxyz'
     23 character_to_idx = {}
     24 for i in range(len(alphabet)):
     25     # 对字母编码
     26     character_to_idx[alphabet[i]] = i
     27 
     28 # 字符LSTM
     29 class CharLSTM(nn.Module):
     30     def __init__(self, n_char, char_dim, char_hidden):
     31         super(CharLSTM, self).__init__()
     32         self.char_embedding = nn.Embedding(n_char, char_dim)
     33         self.char_lstm = nn.LSTM(char_dim, char_hidden, batch_first=True)
     34 
     35     def forward(self, x):
     36         x = self.char_embedding(x)
     37         _, h = self.char_lstm(x)
     38         # 取隐层
     39         return h[0]
     40 
     41 
     42 class LSTMTagger(nn.Module):
     43     def __init__(self, n_word, n_char, char_dim, n_dim, char_hidden, n_hidden,
     44                  n_tag):
     45         super(LSTMTagger, self).__init__()
     46         self.word_embedding = nn.Embedding(n_word, n_dim)
     47         self.char_lstm = CharLSTM(n_char, char_dim, char_hidden)
     48         self.lstm = nn.LSTM(n_dim + char_hidden, n_hidden, batch_first=True)
     49         self.linear1 = nn.Linear(n_hidden, n_tag)
     50 
     51     def forward(self, x, word):
     52         char = torch.FloatTensor()
     53         for each in word:
     54             char_list = []
     55             for letter in each:
     56                 # 对词进行字母编码
     57                 char_list.append(character_to_idx[letter.lower()])
     58             char_list = torch.LongTensor(char_list)
     59             char_list = char_list.unsqueeze(0)
     60             if torch.cuda.is_available():
     61                 tempchar = self.char_lstm(Variable(char_list).cuda())
     62             else:
     63                 tempchar = self.char_lstm(Variable(char_list))
     64             tempchar = tempchar.squeeze(0)
     65             char = torch.cat((char, tempchar.cpu().data), 0)
     66         if torch.cuda.is_available():
     67             char = char.cuda()
     68         char = Variable(char)
     69         x = self.word_embedding(x)
     70         x = torch.cat((x, char), 1) # char编码与word编码cat
     71         x = x.unsqueeze(0)
     72         # 取输出层 句长*n_hidden
     73         x, _ = self.lstm(x)
     74         x = x.squeeze(0)
     75         x = self.linear1(x)
     76         y = F.log_softmax(x)
     77         return y
     78 
     79 
     80 model = LSTMTagger(
     81     len(word_to_idx), len(character_to_idx), 10, 100, 50, 128, len(tag_to_idx))
     82 if torch.cuda.is_available():
     83     model = model.cuda()
     84 criterion = nn.CrossEntropyLoss()
     85 optimizer = optim.SGD(model.parameters(), lr=1e-2)
     86 
     87 
     88 def make_sequence(x, dic):
     89     idx = [dic[i] for i in x]
     90     idx = Variable(torch.LongTensor(idx))
     91     return idx
     92 
     93 
     94 for epoch in range(300):
     95     print('*' * 10)
     96     print('epoch {}'.format(epoch + 1))
     97     running_loss = 0
     98     for data in training_data:
     99         word, tag = data
    100         word_list = make_sequence(word, word_to_idx)
    101         tag = make_sequence(tag, tag_to_idx)
    102         if torch.cuda.is_available():
    103             word_list = word_list.cuda()
    104             tag = tag.cuda()
    105         # forward
    106         out = model(word_list, word)
    107         loss = criterion(out, tag)
    108         running_loss += loss.data[0]
    109         # backward 三步常规操作
    110         optimizer.zero_grad()
    111         loss.backward()
    112         optimizer.step()
    113     print('Loss: {}'.format(running_loss / len(data)))
    114 print()
    115 input = make_sequence("Everybody ate the apple".split(), word_to_idx)
    116 if torch.cuda.is_available():
    117     input = input.cuda()
    118 model.eval() #对dropout和batch normalization的操作在训练和测试的时候是不一样
    119 out = model(input, "Everybody ate the apple".split())
    120 print(out)

    首先n_word 和 n_dim来定义单词的词向量维度,n_char和char_dim来定义字符的词向量维度,char_hidden表示CharLSTM输出的维度,n_hidden表示每个单词作为序列输入的LSTM输出维度,最后n_tag表示输出的词性的种类。

    接着开始前向传播,不仅要传入一个编码之后的句子,同时还需要传入原本的单词,因为需要对字符做一个LSTM,所以传入的参数多了一个word_data表示一个句子的所有单词。

    然后就是将每个单词传入CharLSTM,得到的结果和单词的词向量拼在一起形成一个新的输入,将输入传入LSTM里面,得到输出,最后接一个全连接层,将输出维数定义为label的数目。

    特别要注意里面有一些unsqueeze(增维)和squeeze(降维)是因为LSTM的输入要求要带上batch_size(这里是1),torch.cat里面0和1分别表示沿着行和列来拼接。

    预测一下 Everybody ate the apple 这句话每个词的词性,一共有3种词性,DET,NN,V。最后得到的结果为:

     一共有4行,每行里面取最大的,那么第一个词的词性就是NN,第二个词是V,第三个词是DET,第四个词是NN。这个是相符的。

    参考自:https://sherlockliao.github.io/2017/06/05/lstm%20language/

  • 相关阅读:
    synchronized对比cas
    java 数据集合类
    【转载】S2SH
    【转载】Solr4+IKAnalyzer的安装配置
    【转】基于CXF Java 搭建Web Service (Restful Web Service与基于SOAP的Web Service混合方案)
    【转载】solr初体验
    【转载】CSS 盒子模型
    【转载】div层调整zindex属性无效原因分析及解决方法
    【转载】 IE/Firefox每次刷新时自动检查网页更新,无需手动清空缓存的设置方法
    mysql ODBC connector相关问题
  • 原文地址:https://www.cnblogs.com/demian/p/8007199.html
Copyright © 2011-2022 走看看