zoukankan      html  css  js  c++  java
  • 动手学深度学习--TextCNN

    TextCNN--文本情感分析

    将文本当做一维图像,从而可以用一维卷积神经网络来捕捉邻近词之间的关联

    一维卷积层的工作原理

    与⼆维卷积层⼀样,⼀维卷积层使⽤⼀维的互相关运算。在⼀维互相关运算中,卷积窗⼝从输⼊数组的最左⽅开始,按从左往右的顺序,依次在输⼊数组上滑动。当卷积窗⼝滑动到某⼀位置时,窗⼝中的输⼊⼦数组与核数组按元素相乘并求和,得到输出数组中相应位置的元素。

     1 def corr1d(X, K):
     2     w = K.shape[0]
     3     Y = torch.zeros((X.shape[0] - w + 1))
     4     for i in range(Y.shape[0]):
     5         Y[i] = (X[i: i + w] * K).sum()
     6     return Y
     7 
     8 #测试
     9 X, K = torch.tensor([0, 1, 2, 3, 4, 5, 6]), torch.tensor([1, 2])
    10 corr1d(X, K)  #output:  tensor([ 2.,  5.,  8., 11., 14., 17.])

    多输⼊通道的⼀维互相关运算也与多输⼊通道的⼆维互相关运算类似:在每个通道上,将核与相应的输⼊做⼀维互相关运算,并将通道之间的结果相加得到输出结果。

     1 def corr1d_multi_in(X, K):
     2     # 首先沿着X和K的第0维(通道维)遍历并计算一维互相关结果,然后将所有结果堆叠起来沿第0维累加
     3     return torch.stack([corr1d(x, k) for x, k in zip(X, K)]).sum(dim=0)
     4 
     5 # 测试
     6 X = torch.tensor([[0,1,2,3,4,5,6],
     7                   [1,2,3,4,5,6,7],
     8                   [2,3,4,5,6,7,8]])
     9 K = torch.tensor([[1,2],[3, 4], [-1,-3]])
    10 corr1d_multi_in(X, K)  # output: tensor([ 2.,  8., 14., 20., 26., 32.])
    View Code

    时序最大池化层

    textCNN中使⽤的时序最⼤池化(max-over-time pooling)层实际上对应⼀维全局最⼤池化层:假设输⼊包含多个通道,各通道由不同时间步上的数值组成,各通道的输出即该通道所有时间步中最⼤的数值。因此,时序最⼤池化层的输⼊在各个通道上的时间步数可以不同。由于时序最⼤池化的主要⽬的是抓取时序中最᯿要的特征,它通常能使模型不受⼈为添加字符的影响。

    1 class GlobalMaxPool1d(nn.Module):
    2     def __init__(self):
    3         super(GlobalMaxPool1d, self).__init__()
    4     def forward(self, x):
    5         # x shape: (batch_size, channel, seq_len)
    6         return F.max_pool1d(x, kernel_size=x.shape[2])  # shape:(batch_size, channel, 1)

    TextCNN模型

    textCNN模型主要使⽤了⼀维卷积层和时序最⼤池化层。假设输⼊的⽂本序列由 个词组成,每个词⽤维的词向量表示。那么输⼊样本的宽为 ,⾼为1,输⼊通道数为 。 textCNN的计算主要分为以下⼏步。
      1. 定义多个⼀维卷积核,并使⽤这些卷积核对输⼊分别做卷积计算。宽度不同的卷积核可能会捕捉到不同个数的相邻词的相关性。
      2. 对输出的所有通道分别做时序最⼤池化,再将这些通道的池化输出值连结为向量。
      3. 通过全连接层将连结后的向量变换为有关各类别的输出。这⼀步可以使⽤丢弃层应对过拟合。

     1 class TextCNN(nn.Module):
     2     def __init__(self, vocab, embed_size, kernel_sizes, num_channels):
     3         super(TextCNN, self).__init__()
     4         self.embedding = nn.Embedding(len(vocab), embed_size)
     5         # 不参与训练的嵌入层
     6         self.constant_embedding = nn.Embedding(len(vocab), embed_size)
     7         self.dropout = nn.Dropout(0.5)
     8         self.decoder = nn.Linear(sum(num_channels), 2)
     9         # 时序最大池化层没有权重,所以可共用一个实例
    10         self.pool = GlobalMaxPool1d()
    11         self.convs = nn.ModuleList()  # 创建多个一维卷积层
    12         for c, k in zip(num_channels, kernel_sizes):
    13             self.convs.append(nn.Conv1d(in_channels = 2 * embed_size,
    14                                         out_channels = c,
    15                                         kernel_size = k))
    16             
    17     def forward(self, inputs):
    18         # 将两个形状是(批量大小,词数,词向量维度)的嵌入层的输出按词向量连接
    19         embeddings = torch.cat((
    20             self.embedding(inputs),
    21             self.constant_embedding(inputs)), dim=2)  # (batch_size, seq_len, 2*embed_size)
    22         # 根据Conv1d要求的输入格式,将词向量维,即一维卷积层的通道维变换到前一维
    23         embeddings = embeddings.permute(0, 2, 1)
    24         # 对于每个一维卷积层,在时序最大池化后会得到一个形状为(批量大小,通道大小,1)
    25         # 的Tensor.使用flatten函数去掉最后一维,然后在通道维上连接
    26         encoding = torch.cat([self.pool(F.relu(conv(embedding))).squeeze(-1) for conv in self.convs], dim=1)
    27         # 应用丢弃法后使用全连接层得到输出
    28         outputs = self.decoder(self.dropout(encoding))
    29         return outputs
    30 
    31 embed_size, kernel_sizes, nums_channels = 100, [3, 4, 5], [100, 100, 100]
    32 net = TextCNN(vocab, embed_size, kernel_sizes, nums_channels)
    View Code

    OK,记录一下模型,以上内容都来自《动手学深度学习》这本书。

  • 相关阅读:
    BZOJ5296 [CQOI2018] 破解D-H协议 【数学】【BSGS】
    Codeforces963C Frequency of String 【字符串】【AC自动机】
    Codeforces962F Simple Cycles Edges 【双连通分量】【dfs树】
    Hello World
    Codeforces963C Cutting Rectangle 【数学】
    BZOJ5203 [NEERC2017 Northern] Grand Test 【dfs树】【构造】
    20160422 --Switch…case 总结; 递归算法
    20160421字符串类型;日期时间类型数学类型
    20160420冒泡排序和查找
    20160419 while练习,复习
  • 原文地址:https://www.cnblogs.com/harbin-ho/p/12026082.html
Copyright © 2011-2022 走看看