zoukankan      html  css  js  c++  java
  • 使用Transformers实现文本分类

    github: https://github.com/haibincoder/NlpSummary/tree/master/torchcode/classification

    1. 使用TextCNN实现文本分类
    2. 使用LSTM实现文本分类
    3. 使用Transformers实现文本分类
    import copy
    
    from torch import nn
    import torch.nn.functional as F
    import torch
    import math
    
    class Config(object):
    
        """配置参数"""
        def __init__(self, vocab_size, embed_dim, label_num, max_length=32):
            self.embedding_pretrained = None
            self.num_classes = label_num                                    # 类别数
            self.vocab_size = vocab_size                                       # 词表大小,在运行时赋值
            self.embed_dim = embed_dim                                      # 字向量维度
            self.num_head = 5
            self.dropout = 0.1
            self.hidden = 512
            self.num_encoder = 2
            self.max_length = max_length
            self.lr = 1e-3
    
    
    class Model(nn.Module):
        def __init__(self, config):
            super().__init__()
            if config.embedding_pretrained is not None:
                self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
            else:
                self.embedding = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.vocab_size - 1)
    
            self.encoder = Encoder(config.embed_dim, config.num_head, config.hidden, config.dropout)
            self.encoders = nn.ModuleList([
                copy.deepcopy(self.encoder)
                for _ in range(config.num_encoder)])
    
            self.fc1 = nn.Linear(config.max_length * config.embed_dim, config.num_classes)
    
        def forward(self, x):
            out = self.embedding(x)
            for encoder in self.encoders:
                out = encoder(out)
            out = out.view(out.size(0), -1)
            # out = torch.mean(out, 1)
            out = self.fc1(out)
            return out
    
    
    class Encoder(nn.Module):
        def __init__(self, embed_dim, num_head, hidden, dropout=0.0):
            super().__init__()
            self.attention = MultiHeadAttention(embed_dim, num_head, dropout)
            self.feed_forward = Position_wise_Feed_Forward(embed_dim, hidden, dropout)
        def forward(self, x):
            out = self.attention(x)
            out = self.feed_forward(out)
            return out
    
    
    class MultiHeadAttention(nn.Module):
        def __init__(self, embed_dim, num_head, dropout=0.0):
            super().__init__()
            self.num_head = num_head
            assert embed_dim % num_head == 0, 'head num error'
            self.dim_head = embed_dim // num_head
            self.fc_q = nn.Linear(embed_dim, embed_dim)
            self.fc_k = nn.Linear(embed_dim, embed_dim)
            self.fc_v = nn.Linear(embed_dim, embed_dim)
            self.attention = Attention()
            self.fc = nn.Linear(embed_dim, embed_dim)
            self.dropout = nn.Dropout(dropout)
            self.layer_norm = nn.LayerNorm(embed_dim)
    
        def forward(self, x):
            batch_size = x.size(0)
            Q = self.fc_q(x)
            K = self.fc_k(x)
            V = self.fc_v(x)
            Q = Q.view(batch_size * self.num_head, -1, self.dim_head)
            K = K.view(batch_size * self.num_head, -1, self.dim_head)
            V = V.view(batch_size * self.num_head, -1, self.dim_head)
            context = self.attention(Q, K, V)
            context = context.view(batch_size, -1, self.dim_head * self.num_head)
            out = self.fc(context)
            out = self.dropout(out)
            # 残差
            out = out + x
            out = self.layer_norm(out)
            return out
    
    '''
    attention计算
    '''
    class Attention(nn.Module):
        def __init__(self):
            super().__init__()
    
        def forward(self, query, key, value):
            k = query.size(-1)
            result = torch.matmul(query, key.transpose(-2, -1))
            score = result / math.sqrt(k)
            softmax_result = torch.softmax(score, dim=-1)
            result = torch.matmul(softmax_result, value)
            return result
    
    '''
    这里对应transformers encoder的Feed forward
    '''
    class Position_wise_Feed_Forward(nn.Module):
        def __init__(self, dim_model, hidden, dropout=0.0):
            super().__init__()
            self.fc1 = nn.Linear(dim_model, hidden)
            self.fc2 = nn.Linear(hidden, dim_model)
            self.dropout = nn.Dropout(dropout)
            self.layer_norm = nn.LayerNorm(dim_model)
    
        def forward(self, x):
            out = self.fc1(x)
            out = F.relu(out)
            out = self.fc2(out)
            out = self.dropout(out)
            out = out + x  # 残差连接
            out = self.layer_norm(out)
            return out
    
    时间会记录下一切。
  • 相关阅读:
    Kafka日志及Topic数据清理
    python
    kotlin集合操作
    tomcat 下配置 可 调试
    linux 安装nexus3
    启动 idea 编译报错 kotlin
    nginx 增加 lua模块
    logstash配合filebeat监控tomcat日志
    redis 高级特性 不要太好用
    SpringBoot与Docker1
  • 原文地址:https://www.cnblogs.com/bincoding/p/14416517.html
Copyright © 2011-2022 走看看