  • pytorch -- CNN text classification -- 《Convolutional Neural Networks for Sentence Classification》

    The paper 《Convolutional Neural Networks for Sentence Classification》 performs text classification with a CNN.

    Paper link: https://arxiv.org/abs/1408.5882

    Model diagram: (image omitted; see Figure 1 of the paper)

    For a full explanation of the model, see the paper; code with comments is available at https://github.com/graykode/nlp-tutorial
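    In brief, the model maps each word to an embedding vector, slides convolution filters of a few window widths over the resulting sentence matrix, applies max-over-time pooling to keep one value per filter, and feeds the concatenated pooled values to a fully connected softmax layer. Below is a minimal shape check of the convolution-plus-pooling step (an illustrative sketch, not from the original post; the sizes match the hyperparameters used in the script that follows):

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 1, 3, 2)       # [batch, channel, sequence_length=3, embedding_size=2]
conv = nn.Conv2d(1, 3, (2, 2))    # 3 filters, each covering a window of 2 words
h = F.relu(conv(x))               # -> [1, 3, 2, 1]: 3 - 2 + 1 = 2 window positions
pooled = F.max_pool2d(h, (2, 1))  # max-over-time pooling -> [1, 3, 1, 1]
print(h.shape, pooled.shape)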

# -*- coding: utf-8 -*-
# @time : 2019/11/9  13:55

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Text-CNN hyperparameters
embedding_size = 2        # dimension of each word embedding
sequence_length = 3       # every sentence has exactly 3 words
num_classes = 2           # 0 or 1
filter_sizes = [2, 2, 2]  # n-gram window sizes of the convolution filters
num_filters = 3           # number of filters per window size

# 3-word sentences (= sequence_length is 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.

# Build the vocabulary: every distinct word gets an integer id
word_list = " ".join(sentences).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}
vocab_size = len(word_dict)

# Encode each sentence as a sequence of word ids
inputs = [[word_dict[n] for n in sen.split()] for sen in sentences]
targets = labels[:]  # class indices, as expected by nn.CrossEntropyLoss

input_batch = torch.LongTensor(inputs)    # [batch_size, sequence_length]
target_batch = torch.LongTensor(targets)  # [batch_size]


class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()

        self.num_filters_total = num_filters * len(filter_sizes)
        # Embedding matrix: one embedding_size-dimensional vector per vocabulary word
        self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1))
        # One convolution per filter size. Defining the convolutions here (rather than
        # inside forward) registers their weights in model.parameters(), so they get trained.
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (filter_size, embedding_size), bias=True)
            for filter_size in filter_sizes
        ])
        # Output layer: concatenated pooled features -> class scores
        self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1))
        self.Bias = nn.Parameter(0.1 * torch.ones([num_classes]))

    def forward(self, X):
        embedded_chars = self.W[X]  # [batch_size, sequence_length, embedding_size]
        embedded_chars = embedded_chars.unsqueeze(1)  # add channel(=1): [batch, 1, sequence_length, embedding_size]

        pooled_outputs = []
        for filter_size, conv in zip(filter_sizes, self.convs):
            # conv output : [batch_size, num_filters, sequence_length - filter_size + 1, 1]
            h = F.relu(conv(embedded_chars))
            # max-over-time pooling collapses the remaining positions to one value per filter
            mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))
            # pooled : [batch_size, 1, 1, num_filters] after moving channels to the last dim
            pooled = mp(h).permute(0, 3, 2, 1)
            pooled_outputs.append(pooled)

        # concatenate along the last dim: [batch_size, 1, 1, num_filters * len(filter_sizes)]
        h_pool = torch.cat(pooled_outputs, dim=3)
        h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total])  # [batch_size, num_filters_total]

        output = torch.mm(h_pool_flat, self.Weight) + self.Bias  # [batch_size, num_classes]
        return output


model = TextCNN()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training
for epoch in range(5000):
    optimizer.zero_grad()
    output = model(input_batch)

    # output : [batch_size, num_classes], target_batch : [batch_size] (class indices, not one-hot)
    loss = criterion(output, target_batch)
    if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))

    loss.backward()
    optimizer.step()

# Test
test_text = 'sorry hate you'
tests = [[word_dict[n] for n in test_text.split()]]
test_batch = torch.LongTensor(tests)

# Predict
predict = model(test_batch).data.max(1, keepdim=True)[1]
if predict[0][0] == 0:
    print(test_text, "is Bad Mean...")
else:
    print(test_text, "is Good Mean!!")
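    As a side note (not part of the original post), the same model is often written with nn.Embedding, nn.ModuleList and nn.Linear instead of raw nn.Parameter matrices. A sketch under that assumption, reusing the hyperparameters and data defined above:

class TextCNN2(nn.Module):
    def __init__(self):
        super(TextCNN2, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (fs, embedding_size)) for fs in filter_sizes
        ])
        self.fc = nn.Linear(num_filters * len(filter_sizes), num_classes)

    def forward(self, X):
        x = self.embedding(X).unsqueeze(1)          # [batch, 1, sequence_length, embedding_size]
        pooled = [F.max_pool2d(F.relu(conv(x)), (x.size(2) - conv.kernel_size[0] + 1, 1))
                  for conv in self.convs]           # each: [batch, num_filters, 1, 1]
        h = torch.cat(pooled, dim=1).flatten(1)     # [batch, num_filters * len(filter_sizes)]
        return self.fc(h)                           # [batch, num_classes]

    Training and prediction then work exactly as in the loop above, just with model = TextCNN2().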
  • Original article: https://www.cnblogs.com/dhName/p/11826039.html