  • Text classification with TextCNN

    github: https://github.com/haibincoder/NlpSummary/tree/master/torchcode/classification

    1. Text classification with TextCNN
    2. Text classification with LSTM
    3. Text classification with Transformers
    # model
    # coding: UTF-8
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import numpy as np
    
    
    class Config(object):

        """Configuration parameters"""
        def __init__(self, vocab_size, embed_dim, label_num):
            self.model_name = 'TextCNN'
            self.embedding_pretrained = None                                # pretrained embedding matrix (None = train from scratch)
            self.dropout = 0.2
            self.num_classes = label_num                                    # number of classes
            self.n_vocab = vocab_size                                       # vocabulary size, assigned at runtime
            self.embed_dim = embed_dim                                      # character embedding dimension
            self.filter_sizes = (2, 3, 4)                                   # convolution kernel heights (n-gram sizes)
            self.num_filters = 256                                          # number of kernels per size (output channels)
            self.lr = 1e-3
    
    
    class Model(nn.Module):
        def __init__(self, config):
            super(Model, self).__init__()
            if config.embedding_pretrained is not None:
                self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
            else:
                # padding_idx=0 matches the zero-padding applied in get_data below
                self.embedding = nn.Embedding(config.n_vocab, config.embed_dim, padding_idx=0)
            # equivalent ModuleList formulation:
            # self.convs = nn.ModuleList(
            #     [nn.Conv2d(1, config.num_filters, (k, config.embed_dim)) for k in config.filter_sizes])
            self.conv1 = nn.Conv2d(1, config.num_filters, (2, config.embed_dim))
            self.conv2 = nn.Conv2d(1, config.num_filters, (3, config.embed_dim))
            self.conv3 = nn.Conv2d(1, config.num_filters, (4, config.embed_dim))
            self.dropout = nn.Dropout(config.dropout)
            self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes)

        def conv_and_pool(self, x, conv):
            # (batch, 1, seq_len, embed_dim) -> conv+relu -> (batch, num_filters, seq_len - k + 1)
            x = F.relu(conv(x)).squeeze(3)
            # max-pool over time: (batch, num_filters, L) -> (batch, num_filters)
            x = F.max_pool1d(x, x.size(2)).squeeze(2)
            return x

        def forward(self, input):
            out = self.embedding(input)   # (batch, seq_len, embed_dim)
            out = out.unsqueeze(1)        # add channel dim: (batch, 1, seq_len, embed_dim)
            # out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
            conv1 = self.conv_and_pool(out, self.conv1)
            conv2 = self.conv_and_pool(out, self.conv2)
            conv3 = self.conv_and_pool(out, self.conv3)
            out = torch.cat((conv1, conv2, conv3), 1)   # (batch, num_filters * 3)

            out = self.dropout(out)
            out = self.fc(out)            # (batch, num_classes)
            return out
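
    A quick sanity check of the model is to push a random batch through it and confirm the output shape is (batch, num_classes). A minimal sketch, run after the class definitions above (the toy batch size and sequence length here are made up):

    # shape smoke test for the TextCNN above (toy sizes, not part of the original script)
    config = Config(vocab_size=5000, embed_dim=300, label_num=10)
    model = Model(config)
    fake_batch = torch.randint(0, config.n_vocab, (8, 32))  # (batch=8, seq_len=32) word ids
    logits = model(fake_batch)
    print(logits.shape)  # torch.Size([8, 10])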
    
    from importlib import import_module

    import torch
    from sklearn import metrics
    from sklearn.model_selection import train_test_split
    from torch.utils.data import Dataset, DataLoader
    import torch.nn.functional as F
    from tqdm import tqdm
    import numpy as np
    
    vocab_size = 5000
    batch_size = 128
    max_length = 32
    embed_dim = 300
    label_num = 10
    epoch = 5
    train_path = '../../data/THUCNews/data/train.txt'
    dev_path = '../../data/THUCNews/data/dev.txt'
    vocab_path = '../../data/THUCNews/data/vocab.txt'
    
    output_path = 'output/'
    
    
    def get_data(path):
        # vocab.txt: one "word<TAB>id" pair per line
        vocabs = {}
        with open(vocab_path, 'r', encoding='utf-8') as input_vocab:
            for item in input_vocab:
                word, wordid = item.replace('\n', '').split('\t')
                vocabs[word] = int(wordid)
        x = []
        y = []
        # data file: one "sentence<TAB>label" pair per line
        with open(path, 'r', encoding='utf-8') as input_data:
            for item in input_data:
                sen, label = item.replace('\n', '').split('\t')
                tmp = []
                for item_char in sen:
                    if item_char in vocabs:
                        tmp.append(vocabs[item_char])
                    else:
                        tmp.append(1)  # id 1 is assumed to be the <UNK> token
                    if len(tmp) >= max_length:
                        break
                x.append(tmp)
                y.append(int(label))

        # pad every sequence to max_length with id 0 (assumed <PAD>)
        for item in x:
            if len(item) < max_length:
                item += [0] * (max_length - len(item))

        label_num = len(set(y))
        # x_train, x_test, y_train, y_test = train_test_split(np.array(x), np.array(y), test_size=0.2)
        x = np.array(x)
        print(x.shape)
        y = np.array(y)
        return x, y, label_num
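
    get_data assumes two tab-separated files: vocab.txt with one "word<TAB>id" pair per line, and a data file with one "sentence<TAB>label" pair per line. A hypothetical toy example of that layout (contents invented for illustration; ids 0 and 1 being <PAD> and <UNK> is inferred from the padding/OOV code above):

    # build tiny stand-in files matching the layout get_data parses
    with open('toy_vocab.txt', 'w', encoding='utf-8') as f:
        f.write('<PAD>\t0\n<UNK>\t1\n今\t2\n天\t3\n好\t4\n')
    with open('toy_train.txt', 'w', encoding='utf-8') as f:
        f.write('今天好\t1\n天天\t0\n')
    # pointing vocab_path at toy_vocab.txt, get_data('toy_train.txt') returns
    # x with shape (2, max_length) and y = [1, 0]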
    
    
    class DealDataset(Dataset):
        def __init__(self, x_train, y_train, device):
            self.x_data = torch.from_numpy(x_train).long().to(device)
            self.y_data = torch.from_numpy(y_train).long().to(device)
            self.len = x_train.shape[0]
    
        def __getitem__(self, index):
            return self.x_data[index], self.y_data[index]
    
        def __len__(self):
            return self.len
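
    Note that DealDataset copies the whole corpus to the target device up front, which is simple and fast for a small dataset like this one but pins everything in GPU memory. A common alternative (a sketch, not the author's code) keeps the tensors on the CPU and moves one batch at a time inside the training loop:

    # hypothetical variant: keep data on CPU, transfer per batch
    class CpuDataset(Dataset):
        def __init__(self, x_train, y_train):
            self.x_data = torch.from_numpy(x_train).long()
            self.y_data = torch.from_numpy(y_train).long()

        def __getitem__(self, index):
            return self.x_data[index], self.y_data[index]

        def __len__(self):
            return self.x_data.shape[0]

    # in the training loop:
    #     for datas, labels in dataloader:
    #         datas, labels = datas.to(device), labels.to(device)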
    
    
    def evaluate(model, dataloader_dev):
        model.eval()
        predict_all = np.array([], dtype=int)
        labels_all = np.array([], dtype=int)
        with torch.no_grad():
            for datas, labels in dataloader_dev:
                output = model(datas)
                predic = torch.max(output.data, 1)[1].cpu()
                predict_all = np.append(predict_all, predic)
                labels_all = np.append(labels_all, labels.cpu())
                if len(predict_all) > 1000:
                    break  # cap evaluation at ~1000 samples to keep dev checks fast
        acc = metrics.accuracy_score(labels_all, predict_all)
        return acc
    
    
    if __name__ == "__main__":
        debug = False
        # module name to import (TextCNN, TextLSTM, Transformer); each module must define Config and Model
        model_name = 'TextCNN'
        module = import_module(model_name)
        config = module.Config(vocab_size, embed_dim, label_num)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = module.Model(config).to(device)
        if debug:
            # shape: batch_size * max_length, integer word ids in [0, 200)
            inputs = torch.randint(0, 200, (batch_size, max_length)).to(device)
            # shape: batch_size, integer labels in [0, 2), aligned with inputs
            labels = torch.randint(0, 2, (batch_size, 1)).squeeze(1).to(device)
            print(model(inputs))
        else:
            x_train, y_train, label_num = get_data(train_path)
            dataset = DealDataset(x_train, y_train, device)
            dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

            x_dev, y_dev, _ = get_data(dev_path)
            dataset_dev = DealDataset(x_dev, y_dev, device)
            dataloader_dev = DataLoader(dataset=dataset_dev, batch_size=batch_size, shuffle=True)

            optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
            model.train()
            best_acc = 0
            for i in range(epoch):
                index = 0
                for datas, labels in tqdm(dataloader):
                    model.zero_grad()
                    output = model(datas)
                    loss = F.cross_entropy(output, labels)
                    loss.backward()
                    optimizer.step()
                    index += 1
                    if index % 50 == 0:
                        # every 50 batches, report accuracy on the current training batch and the dev set
                        true = labels.data.cpu()
                        predic = torch.max(output.data, 1)[1].cpu()
                        train_acc = metrics.accuracy_score(true, predic)
                        dev_acc = evaluate(model, dataloader_dev)
                        print(f'epoch:{i} batch:{index} loss:{loss.item():.4f} train_acc:{train_acc} dev_acc:{dev_acc}')
                        # if dev_acc > best_acc:
                        #     torch.save(model, f'{output_path}/{model_name}/model.pt')
                        model.train()
    
            print('train finish')
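
    The import_module call is what makes the trainer model-agnostic: any file named TextCNN.py, TextLSTM.py, or Transformer.py in the same directory can be plugged in, as long as it exposes a Config(vocab_size, embed_dim, label_num) class and a Model(config) class. A hypothetical skeleton of that contract (not the repo's actual TextLSTM):

    # TextLSTM.py -- minimal sketch of a pluggable module (assumed structure)
    import torch.nn as nn

    class Config(object):
        def __init__(self, vocab_size, embed_dim, label_num):
            self.n_vocab = vocab_size
            self.embed_dim = embed_dim
            self.num_classes = label_num
            self.lr = 1e-3

    class Model(nn.Module):
        def __init__(self, config):
            super(Model, self).__init__()
            self.embedding = nn.Embedding(config.n_vocab, config.embed_dim)
            self.lstm = nn.LSTM(config.embed_dim, 128, batch_first=True)
            self.fc = nn.Linear(128, config.num_classes)

        def forward(self, x):
            out = self.embedding(x)        # (batch, seq_len, embed_dim)
            out, _ = self.lstm(out)        # (batch, seq_len, 128)
            return self.fc(out[:, -1, :])  # classify from the last time step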
    
    