zoukankan      html  css  js  c++  java
  • 深度学习与Pytorch入门实战(十六)情感分类实战(基于IMDB数据集)

    笔记摘抄

    提前安装torchtext和scapy,运行下面语句(压缩包地址链接:https://pan.baidu.com/s/1_syic9B-SXKQvkvHlEf78w 提取码:ahh3):

    pip install torchtext
    
    pip install scapy
    
    pip install 你的地址en_core_web_md-2.2.5.tar.gz  
    
    • 在torchtext中使用spacy时,由于field的默认属性是tokenizer_language='en'

    • 当使用 en_core_web_md 时要改 field.py文件中 创建的field属性为tokenizer_language='en_core_web_md',且data.Field()中的参数也要改为tokenizer_language='en_core_web_md'

    1. 加载数据

    分类任务中,我们所需要接触到的数据有文本字符串和两种情感,"pos"或者"neg"。

    • Field的参数制定了数据会被怎样处理。

    • 我们使用TEXT field来定义如何处理电影评论,使用LABEL field来处理两个情感类别。

    • 我们的TEXT field带有tokenize='spacy',这表示我们会用spaCy tokenizer来tokenize英文句子。如果我们不特别声明tokenize这个参数,那么默认的分词方法是使用空格。

    • 安装spaCy

    pip install -U spacy
    python -m spacy download en
    

    1.1 分割训练集测试集

    import numpy as np
    import torch
    from torch import nn, optim
    from torchtext import data, datasets
    
    # 为CPU设置随机种子
    torch.manual_seed(123)
    
    # 两个Field对象定义字段的处理方法(文本字段、标签字段)
    TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_md')  # 分词
    LABEL = data.LabelField(dtype=torch.float)
    
    • TorchText支持很多常见的自然语言处理数据集。

    • 下面的代码会自动下载IMDb数据集,然后分成train/test两个torchtext.datasets类别。数据被前面的Fields处理。IMDb数据集一共有50000电影评论,每个评论都被标注为正面的或负面的。

    # from torchtext import data, datasets
    # IMDB共50000影评,包含正面和负面两个类别。数据被前面的Field处理
    # 按照(TEXT, LABEL) 分割成 训练集,测试集
    train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
    
    print('len of train data:', len(train_data))        # 25000
    print('len of test data:', len(test_data))          # 25000
    
    # torchtext.data.Example : 用来表示一个样本,数据+标签
    print(train_data.examples[15].text)                 # 文本:句子的单词列表
    print(train_data.examples[15].label)                # 标签: 积极
    
    len of train data: 25000
    len of test data: 25000
    ['Like', 'one', 'of', 'the', 'previous', 'commenters', 'said', ',', 'this', 'had', 'the', 'foundations', 'of', 'a', 'great', 'movie', 'but', 'something', 'happened', 'on', 'the', 'way', 'to', 'delivery', '.', 'Such', 'a', 'waste', 'because', 'Collette', "'s", 'performance', 'was', 'eerie', 'and', 'Williams', 'was', 'believable', '.', 'I', 'just', 'kept', 'waiting', 'for', 'it', 'to', 'get', 'better', '.', 'I', 'do', "n't", 'think', 'it', 'was', 'bad', 'editing', 'or', 'needed', 'another', 'director', ',', 'it', 'could', 'have', 'just', 'been', 'the', 'film', '.', 'It', 'came', 'across', 'as', 'a', 'Canadian', 'movie', ',', 'something', 'like', 'the', 'first', 'few', 'seasons', 'of', 'X', '-', 'Files', '.', 'Not', 'cheap', ',', 'just', 'hokey', '.', 'Also', ',', 'it', 'needed', 'a', 'little', 'more', 'suspense', '.', 'Something', 'that', 'makes', 'you', 'jump', 'off', 'your', 'seat', '.', 'The', 'movie', 'reached', 'that', 'moment', 'then', 'faded', 'away', ';', 'kind', 'of', 'like', 'a', 'false', 'climax', '.', 'I', 'can', 'see', 'how', 'being', 'too', 'suspenseful', 'would', 'have', 'taken', 'away', 'from', 'the', '"', 'reality', '"', 'of', 'the', 'story', 'but', 'I', 'thought', 'that', 'part', 'was', 'reached', 'when', 'Gabriel', 'was', 'in', 'the', 'hospital', 'looking', 'for', 'the', 'boy', '.', 'This', 'movie', 'needs', 'to', 'have', 'a', 'Director', "'s", 'cut', 'that', 'tries', 'to', 'fix', 'these', 'problems', '.']
    pos
    
    • 由于我们现在只有train/test这两个分类,所以我们需要创建一个新的validation set。我们可以使用.split()创建新的分类。

    • 默认的数据分割是 70、30,如果我们声明split_ratio,可以改变split之间的比例,split_ratio=0.8表示80%的数据是训练集,20%是验证集。

    • 我们还声明random_state这个参数,确保我们每次分割的数据集都是一样的。

    import random
    SEED = 1234
    train_data, valid_data = train_data.split(random_state=random.seed(SEED))
    

    检查一下现在每个部分有多少条数据。

    print(f'Number of training examples: {len(train_data)}')
    print(f'Number of validation examples: {len(valid_data)}')
    print(f'Number of testing examples: {len(test_data)}')
    
    Number of training examples: 17500
    Number of validation examples: 7500
    Number of testing examples: 25000
    

    1.2 创建vocabulary

    • vocabulary把每个单词一一映射到一个数字。

    • 使用10k个单词来构建单词表(用max_size这个参数可以设定)

    • 所有其他的单词都用<unk>来表示。

    • 词典中应当有10002个单词,且有两个label,可以通过TEXT.vocabTEXT.label查询,可以直接用stoi(stringtoint) 或者 itos(inttostring) 来查看单词表。

    TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d') # unk_init=torch.Tensor.normal_
    LABEL.build_vocab(train_data)
    
    print(len(TEXT.vocab))             # 10002
    print(TEXT.vocab.itos[:12])        # ['<unk>', '<pad>', 'the', ',', '.', 'and', 'a', 'of', 'to', 'is', 'in', 'I']
    print(TEXT.vocab.stoi['and'])      # 5
    print(LABEL.vocab.stoi)            # defaultdict(None, {'neg': 0, 'pos': 1})
    
    ['<unk>', '<pad>', 'the', ',', '.', 'and', 'a', 'of', 'to', 'is', 'in', 'I']
    5
    defaultdict(<function _default_unk_index at 0x7f44ffd3ed90>, {'neg': 0, 'pos': 1})
    
    • 当我们把句子传进模型的时候,我们是按照一个个 batch 穿进去的。

    • 也就是说,我们一次传入了好几个句子,而且每个batch中的句子必须是相同的长度。为了确保句子的长度相同,TorchText会把短的句子pad到和最长的句子等长。

    • 下面我们来看看训练数据集中最常见的单词。

    print(TEXT.vocab.freqs.most_common(20))
    
    [('the', 201455), (',', 192552), ('.', 164402), ('a', 108963), ('and', 108649), ('of', 100010), ('to', 92873), ('is', 76046), ('in', 60904), ('I', 54486), ('it', 53405), ('that', 49155), ('"', 43890), ("'s", 43151), ('this', 42454), ('-', 36769), ('/><br', 35511), ('was', 34990), ('as', 30324), ('with', 29691)]
    

    1.3 创建iteratiors

    • 每个itartion都会返回一个batch的examples。

    • 每个iterator中各有两部分:词(.text)和标签(.label),其中 text 全部转换成数字了

    • BucketIterator会把长度差不多的句子放到同一个batch中,确保每个batch中不出现太多的padding。

    • 这里因为pad比较少,所以把也当做了模型的输入进行训练。

    • 如果有GPU,还可以指定每个iteration返回的tensor 都在GPU上。

    batchsz = 30
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
                                    (train_data, valid_data, test_data),
                                    batch_size = batchsz,
                                    device = device,
                                    repeat = False
                                   )
    
    # for i, _ in enumerate(train_iterator):
    #     print(i)
    batch = next(iter(train_iterator))
    print(batch.text)
    print(batch.text.shape)
    print(batch.label.shape)
    
    tensor([[  25,   66, 1215,  ...,  471,   11, 1267],
            [ 132,    9, 2348,  ...,   42,  465,  298],
            [  19,    6, 1703,  ...,    3,  142, 1678],
            ...,
            [   1,    1,    1,  ...,    1,    1,    1],
            [   1,    1,    1,  ...,    1,    1,    1],
            [   1,    1,    1,  ...,    1,    1,    1]])
    torch.Size([1058, 64])   # 一个batch,64条数据,1058个参数
    torch.Size([64])
    
    print(TEXT.pad_token)
    PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]
    print(PAD_IDX)
    mask = batch.text == PAD_IDX
    print(mask)
    
    <pad>
    1
    tensor([[False, False, False,  ..., False, False, False],
            [False, False, False,  ..., False, False, False],
            [False, False, False,  ..., False, False, False],
            ...,
            [ True,  True,  True,  ...,  True,  True,  True],
            [ True,  True,  True,  ...,  True,  True,  True],
            [ True,  True,  True,  ...,  True,  True,  True]])
    

    2. 定义模型

    class RNN(nn.Module):
    
      def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(RNN, self).__init__()
    
        # [0-10001] => [100]
        # 参数1:embedding个数(单词数), 参数2:embedding的维度(词向量维度)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # [100] => [256]
        # 双向LSTM,所以下面FC层使用 hidden_dim*2
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2,
                           bidirectional=True, dropout=0.5) 
        # [256*2] => [1]
        self.fc = nn.Linear(hidden_dim*2, 1)
        self.dropout = nn.Dropout(0.5)
    
      def forward(self, x):
        """
        x: [seq_len, b] vs [b, 3, 28, 28]
        """
        # [seq_len, b, 1] => [seq_len, b, 100]
        embedding = self.dropout(self.embedding(x))
    
        # output: [seq, b, hid_dim*2]
        # hidden/h: [num_layers*2, b, hid_dim]
        # cell/c: [num_layers*2, b, hid_dim]
        output, (hidden, cell) = self.rnn(embedding)
        # [num_layers*2, b, hid_dim] => 2 of [b, hid_dim] => [b, hid_dim*2]
        # 双向,所以要把最后两个输出连接
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        # [b, hid_dim*2] => [b, 1]
        hidden = self.dropout(hidden)
        out = self.fc(hidden)
    
        return out
    
    • 使用 预训练过的embedding 来替换随机初始化

    • Tip:.copy_() 这种 带着下划线的函数 均代表 替换inplace

    rnn = RNN(len(TEXT.vocab), 100, 256)                          #词个数,词嵌入维度,输出维度
    
    pretrained_embedding = TEXT.vocab.vectors
    print('pretrained_embedding:', pretrained_embedding.shape)    # torch.Size([10002, 100])
    
    # 使用预训练过的embedding来替换随机初始化
    rnn.embedding.weight.data.copy_(pretrained_embedding)
    print('embedding layer inited.')
    
    pretrained_embedding: torch.Size([10002, 100])
    embedding layer inited.
    

    3. 训练模型

    • 首先定义模型和损失函数。
    optimizer = optim.Adam(rnn.parameters(), lr=1e-3)
    
    # BCEWithLogitsLoss是针对二分类的CrossEntropy
    criteon = nn.BCEWithLogitsLoss()
    

    如果使用GPU加速,改成:

    # 优化函数
    optimizer = optim.Adam(rnn.parameters(), lr=1e-3)
    
    # BCEWithLogitsLoss是针对二分类的CrossEntropy
    criteon = nn.BCEWithLogitsLoss().to(device)
    
    rnn = rnn.to(device)
    
    RNN(
      (embedding): Embedding(10002, 100)
      (rnn): LSTM(100, 256, num_layers=2, dropout=0.5, bidirectional=True)
      (fc): Linear(in_features=512, out_features=1, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    
    • 定义一个函数用于计算准确率
    def binary_acc(preds, y):
    
        preds = torch.round(torch.sigmoid(preds))
        correct = torch.eq(preds, y).float()
        acc = correct.sum() / len(correct)
        return acc
    
    • 定义一个训练函数
    def train(rnn, iterator, optimizer, criteon):
        epoch_loss = 0
        epoch_acc = 0
        avg_acc = []
        rnn.train()   # 表示进入训练模式
    
        for i, batch in enumerate(iterator):
            # [seq, b] => [b, 1] => [b]
            # batch.text 就是上面forward函数的参数text,压缩维度是为了和batch.label维度一致
            pred = rnn(batch.text).squeeze(1)
    
            loss = criteon(pred, batch.label)
            # 计算每个batch的准确率
            acc = binary_acc(pred, batch.label).item()
            avg_acc.append(acc)
    
            optimizer.zero_grad()  # 清零梯度准备计算
            loss.backward()        # 反向传播
            optimizer.step()       # 更新训练参数
    
            if i % 10 == 0:
                print(i, acc)
            
            epoch_loss += loss.item()
            epoch_acc += acc.item()
            
        avg_acc = np.array(avg_acc).mean()
        print('avg acc:', avg_acc)
        
        return epoch_loss / len(iterator), epoch_acc / len(iterator)   
    

    4. 评估模型

    • 定义一个评估函数,和训练函数高度重合

    • 区别是要把rnn.train()改为rnn.val(),不需要反向传播过程。

    def evaluate(rnn, iterator, criteon):
        avg_acc = []
        epoch_loss = 0
        epoch_acc = 0
        rnn.eval()         # 表示进入测试模式
    
        with torch.no_grad():
            for batch in iterator:
                pred = rnn(batch.text).squeeze(1)      # [b, 1] => [b]
                loss = criteon(pred, batch.label)
                acc = binary_acc(pred, batch.label).item()
                avg_acc.append(acc)
    
                epoch_loss += loss.item()
                epoch_acc += acc.item()
    
        avg_acc = np.array(avg_acc).mean()
        print('test acc:', avg_acc)
    
        return epoch_loss / len(iterator), epoch_acc / len(iterator)
    

    5. 运行

    
    best_valid_loss = float('inf')
    for epoch in range(10):
        # 训练模型
        train_loss, train_acc = train(rnn, train_iterator, optimizer, criteon)
        # 评估模型
        valid_loss, valid_acc = evaluate(rnn, valid_iterator, criteon)
    
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'wordavg-model.pt')
        
    
    view result
    0 0.8666667342185974
    10 0.9666666984558105
    20 0.8000000715255737
    30 0.8666667342185974
    40 0.8666667342185974
    50 0.8000000715255737
    60 0.9333333969116211
    70 0.7666667103767395
    80 0.9000000357627869
    90 0.8666667342185974
    100 0.9000000357627869
    110 0.7666667103767395
    120 0.8000000715255737
    130 0.9666666984558105
    140 0.8666667342185974
    150 0.9000000357627869
    160 0.9000000357627869
    170 0.9000000357627869
    180 0.8000000715255737
    190 0.8000000715255737
    200 0.9333333969116211
    210 0.9000000357627869
    220 0.9333333969116211
    230 0.8666667342185974
    240 0.9000000357627869
    250 0.7666667103767395
    260 0.9333333969116211
    270 0.9000000357627869
    280 0.8000000715255737
    290 0.8666667342185974
    300 0.9333333969116211
    310 0.7666667103767395
    320 0.9000000357627869
    330 0.9666666984558105
    340 0.9666666984558105
    350 0.8333333730697632
    360 0.9000000357627869
    370 0.8000000715255737
    380 0.9000000357627869
    390 0.8666667342185974
    400 0.8333333730697632
    410 0.9000000357627869
    420 0.9333333969116211
    430 0.8333333730697632
    440 0.8666667342185974
    450 0.8000000715255737
    460 0.9333333969116211
    470 0.8666667342185974
    480 0.9333333969116211
    490 0.9333333969116211
    500 0.9000000357627869
    510 0.8333333730697632
    520 0.8666667342185974
    530 0.9333333969116211
    540 0.9333333969116211
    550 0.7666667103767395
    560 0.8333333730697632
    570 0.9333333969116211
    580 0.9000000357627869
    590 0.9333333969116211
    600 0.9000000357627869
    610 0.8333333730697632
    620 0.7333333492279053
    630 0.8333333730697632
    640 0.8333333730697632
    650 0.9000000357627869
    660 0.9333333969116211
    670 0.8000000715255737
    680 0.9000000357627869
    690 0.9000000357627869
    700 0.9000000357627869
    710 0.9333333969116211
    720 0.8000000715255737
    730 0.9333333969116211
    740 0.9666666984558105
    750 0.9666666984558105
    760 0.9333333969116211
    770 0.8666667342185974
    780 0.8666667342185974
    790 0.8666667342185974
    800 0.9666666984558105
    810 0.9000000357627869
    820 0.9000000357627869
    830 0.9333333969116211
    avg acc: 0.8855715916454078
    test acc: 0.8775779855051201
    0 0.9000000357627869
    10 0.9666666984558105
    20 0.9000000357627869
    30 0.9000000357627869
    40 0.9666666984558105
    50 0.9666666984558105
    60 0.7666667103767395
    70 0.8666667342185974
    80 0.9333333969116211
    90 0.9000000357627869
    100 0.9333333969116211
    110 0.8666667342185974
    120 0.9000000357627869
    130 0.9000000357627869
    140 0.8666667342185974
    150 0.8333333730697632
    160 0.8333333730697632
    170 0.9333333969116211
    180 0.8333333730697632
    190 0.9000000357627869
    200 0.8666667342185974
    210 1.0
    220 1.0
    230 0.9666666984558105
    240 0.9000000357627869
    250 0.8000000715255737
    260 0.9333333969116211
    270 0.9666666984558105
    280 0.9333333969116211
    290 0.9666666984558105
    300 0.9000000357627869
    310 0.9333333969116211
    320 0.9333333969116211
    330 0.9666666984558105
    340 0.9666666984558105
    350 0.9666666984558105
    360 0.9333333969116211
    370 0.9666666984558105
    380 0.8333333730697632
    390 0.7333333492279053
    400 0.9000000357627869
    410 0.9000000357627869
    420 0.8000000715255737
    430 0.9333333969116211
    440 0.8666667342185974
    450 0.9333333969116211
    460 0.8333333730697632
    470 0.9333333969116211
    480 0.9333333969116211
    490 0.8000000715255737
    500 0.9666666984558105
    510 0.9000000357627869
    520 1.0
    530 0.9666666984558105
    540 1.0
    550 0.9333333969116211
    560 0.9000000357627869
    570 1.0
    580 0.9000000357627869
    590 0.9000000357627869
    600 0.8666667342185974
    610 0.8333333730697632
    620 0.9000000357627869
    630 0.9000000357627869
    640 0.8666667342185974
    650 0.9000000357627869
    660 0.9666666984558105
    670 0.9333333969116211
    680 0.8666667342185974
    690 0.9000000357627869
    700 0.8666667342185974
    710 0.9333333969116211
    720 0.9666666984558105
    730 0.9666666984558105
    740 0.9666666984558105
    750 0.9000000357627869
    760 0.9000000357627869
    770 0.9000000357627869
    780 0.9333333969116211
    790 0.9333333969116211
    800 0.9333333969116211
    810 0.8666667342185974
    820 0.9000000357627869
    830 0.9000000357627869
    avg acc: 0.9071942910873633
    test acc: 0.8886890964542361
    0 0.9333333969116211
    10 0.9333333969116211
    20 0.9666666984558105
    30 0.9333333969116211
    40 0.9333333969116211
    50 0.8666667342185974
    60 1.0
    70 0.8333333730697632
    80 0.9666666984558105
    90 0.9000000357627869
    100 0.9666666984558105
    110 0.9666666984558105
    120 0.9333333969116211
    130 0.9333333969116211
    140 0.9000000357627869
    150 0.9666666984558105
    160 0.8666667342185974
    170 0.9666666984558105
    180 0.9666666984558105
    190 0.9333333969116211
    200 0.9333333969116211
    210 0.8666667342185974
    220 0.9000000357627869
    230 0.8333333730697632
    240 0.9333333969116211
    250 0.8000000715255737
    260 0.8666667342185974
    270 0.9000000357627869
    280 0.9000000357627869
    290 0.9666666984558105
    300 0.9333333969116211
    310 0.9000000357627869
    320 0.9333333969116211
    330 0.9666666984558105
    340 0.9000000357627869
    350 1.0
    360 0.9666666984558105
    370 0.9333333969116211
    380 0.9333333969116211
    390 0.9666666984558105
    400 0.9666666984558105
    410 0.9666666984558105
    420 1.0
    430 0.9000000357627869
    440 1.0
    450 0.9000000357627869
    460 0.9333333969116211
    470 1.0
    480 0.9000000357627869
    490 0.9333333969116211
    500 0.9000000357627869
    510 0.9000000357627869
    520 0.9333333969116211
    530 0.9333333969116211
    540 0.9666666984558105
    550 0.9666666984558105
    560 0.9666666984558105
    570 0.9666666984558105
    580 0.8333333730697632
    590 0.9666666984558105
    600 0.9333333969116211
    610 0.9333333969116211
    620 0.9333333969116211
    630 1.0
    640 0.9000000357627869
    650 0.8666667342185974
    660 0.9333333969116211
    670 0.8666667342185974
    680 0.9666666984558105
    690 0.9333333969116211
    700 1.0
    710 0.9666666984558105
    720 0.9666666984558105
    730 0.9000000357627869
    740 0.9333333969116211
    750 0.9666666984558105
    760 1.0
    770 0.8666667342185974
    780 0.9000000357627869
    790 0.9333333969116211
    800 0.9666666984558105
    810 0.9000000357627869
    820 0.9666666984558105
    830 0.8000000715255737
    avg acc: 0.9266587171337302
    test acc: 0.8872902161068768
    0 0.9333333969116211
    10 1.0
    20 1.0
    30 0.9666666984558105
    40 0.9666666984558105
    50 1.0
    60 0.9333333969116211
    70 0.9666666984558105
    80 0.8666667342185974
    90 0.9666666984558105
    100 0.9333333969116211
    110 0.8666667342185974
    120 0.9333333969116211
    130 0.9000000357627869
    140 0.8333333730697632
    150 0.9666666984558105
    160 0.9666666984558105
    170 0.8666667342185974
    180 0.9666666984558105
    190 0.9666666984558105
    200 0.9333333969116211
    210 0.9333333969116211
    220 0.9666666984558105
    230 0.9666666984558105
    240 0.9000000357627869
    250 1.0
    260 0.9333333969116211
    270 0.9666666984558105
    280 0.9333333969116211
    290 0.9000000357627869
    300 1.0
    310 0.9333333969116211
    320 0.9666666984558105
    330 0.9666666984558105
    340 0.9333333969116211
    350 0.9333333969116211
    360 0.9333333969116211
    370 0.9333333969116211
    380 1.0
    390 1.0
    400 0.9333333969116211
    410 1.0
    420 0.9333333969116211
    430 0.9666666984558105
    440 0.9333333969116211
    450 0.9333333969116211
    460 0.9666666984558105
    470 0.8333333730697632
    480 1.0
    490 0.9333333969116211
    500 0.9666666984558105
    510 0.9000000357627869
    520 0.9000000357627869
    530 1.0
    540 0.9333333969116211
    550 0.9666666984558105
    560 0.9000000357627869
    570 0.9333333969116211
    580 0.9333333969116211
    590 0.9666666984558105
    600 0.8333333730697632
    610 0.9333333969116211
    620 0.8666667342185974
    630 0.9000000357627869
    640 0.9333333969116211
    650 0.9666666984558105
    660 0.9666666984558105
    670 0.9333333969116211
    680 0.9333333969116211
    690 0.9333333969116211
    700 0.9666666984558105
    710 0.9000000357627869
    720 0.9333333969116211
    730 1.0
    740 0.9666666984558105
    750 0.9333333969116211
    760 0.9666666984558105
    770 0.8333333730697632
    780 0.9666666984558105
    790 0.9000000357627869
    800 0.9000000357627869
    810 0.9000000357627869
    820 0.9666666984558105
    830 0.9666666984558105
    avg acc: 0.9356515197445163
    test acc: 0.890008042184569
    0 1.0
    10 1.0
    20 0.9000000357627869
    30 0.8666667342185974
    40 0.9000000357627869
    50 0.9333333969116211
    60 0.9000000357627869
    70 0.9666666984558105
    80 0.8666667342185974
    90 0.9000000357627869
    100 0.9333333969116211
    110 1.0
    120 0.9666666984558105
    130 0.9666666984558105
    140 1.0
    150 0.9333333969116211
    160 0.9333333969116211
    170 0.9333333969116211
    180 1.0
    190 0.9666666984558105
    200 0.9333333969116211
    210 1.0
    220 0.9666666984558105
    230 1.0
    240 0.9333333969116211
    250 0.8333333730697632
    260 0.9666666984558105
    270 0.9333333969116211
    280 0.9000000357627869
    290 1.0
    300 0.9666666984558105
    310 0.9333333969116211
    320 0.9000000357627869
    330 0.9000000357627869
    340 1.0
    350 0.9666666984558105
    360 1.0
    370 0.9666666984558105
    380 0.9000000357627869
    390 0.9666666984558105
    400 0.9666666984558105
    410 0.9333333969116211
    420 0.9000000357627869
    430 1.0
    440 0.9333333969116211
    450 0.9666666984558105
    460 0.9666666984558105
    470 1.0
    480 1.0
    490 0.9666666984558105
    500 1.0
    510 1.0
    520 1.0
    530 1.0
    540 0.8666667342185974
    550 1.0
    560 0.9333333969116211
    570 0.9333333969116211
    580 0.9666666984558105
    590 0.9666666984558105
    600 0.9333333969116211
    610 0.9000000357627869
    620 0.9333333969116211
    630 0.9666666984558105
    640 0.9666666984558105
    650 0.9333333969116211
    660 0.9333333969116211
    670 0.9000000357627869
    680 0.9333333969116211
    690 0.9000000357627869
    700 0.9333333969116211
    710 0.9666666984558105
    720 0.9666666984558105
    730 0.9333333969116211
    740 0.9333333969116211
    750 1.0
    760 0.9666666984558105
    770 0.9333333969116211
    780 0.9333333969116211
    790 0.9000000357627869
    800 1.0
    810 0.9000000357627869
    820 1.0
    830 0.9000000357627869
    avg acc: 0.9450040338136595
    test acc: 0.8848521674422624
    0 1.0
    10 1.0
    20 0.9666666984558105
    30 0.9666666984558105
    40 1.0
    50 1.0
    60 0.9666666984558105
    70 1.0
    80 0.9666666984558105
    100 0.9666666984558105
    110 0.9666666984558105
    120 0.9333333969116211
    130 0.9666666984558105
    140 0.9666666984558105
    150 1.0
    160 0.9666666984558105
    170 1.0
    180 1.0
    190 0.9666666984558105
    200 0.8666667342185974
    210 1.0
    220 0.8666667342185974
    230 0.9666666984558105
    240 0.9333333969116211
    250 0.8333333730697632
    260 0.9666666984558105
    270 0.9666666984558105
    280 0.9000000357627869
    290 0.9666666984558105
    300 0.9666666984558105
    310 0.9333333969116211
    320 1.0
    330 0.9666666984558105
    340 0.9666666984558105
    350 0.9333333969116211
    360 0.9000000357627869
    370 0.8666667342185974
    380 0.9333333969116211
    390 0.8333333730697632
    400 0.9666666984558105
    410 1.0
    420 0.9666666984558105
    430 0.9666666984558105
    440 1.0
    450 0.9666666984558105
    460 0.9333333969116211
    470 1.0
    480 0.9666666984558105
    490 1.0
    500 0.9666666984558105
    510 0.9333333969116211
    520 0.8666667342185974
    530 0.9666666984558105
    540 1.0
    550 1.0
    560 0.9333333969116211
    570 0.9333333969116211
    580 1.0
    590 0.9666666984558105
    600 0.9666666984558105
    610 0.9666666984558105
    620 0.9666666984558105
    630 0.9666666984558105
    640 0.9333333969116211
    650 0.9000000357627869
    660 0.9333333969116211
    670 1.0
    680 0.9333333969116211
    690 0.9666666984558105
    700 0.9333333969116211
    710 1.0
    720 0.9333333969116211
    730 1.0
    740 0.9666666984558105
    750 0.9666666984558105
    760 0.8666667342185974
    770 0.9000000357627869
    780 0.8000000715255737
    790 0.9666666984558105
    800 0.9666666984558105
    810 0.8666667342185974
    820 1.0
    830 0.9666666984558105
    avg acc: 0.9509592677334802
    test acc: 0.8718625588668621
    0 1.0
    10 1.0
    20 0.9666666984558105
    30 0.9333333969116211
    40 1.0
    50 0.9666666984558105
    60 0.9666666984558105
    70 0.9666666984558105
    80 1.0
    90 0.9333333969116211
    100 1.0
    110 0.9666666984558105
    120 0.9666666984558105
    130 0.9666666984558105
    140 0.9666666984558105
    150 0.9666666984558105
    160 0.9666666984558105
    170 1.0
    180 0.9666666984558105
    190 0.9000000357627869
    200 1.0
    210 1.0
    220 0.9333333969116211
    230 1.0
    240 0.9666666984558105
    250 1.0
    260 0.9666666984558105
    270 0.9666666984558105
    280 0.9333333969116211
    290 0.9333333969116211
    300 0.9666666984558105
    310 0.9666666984558105
    320 0.9666666984558105
    330 0.9333333969116211
    340 1.0
    350 0.9333333969116211
    360 0.9666666984558105
    370 0.9333333969116211
    380 0.9666666984558105
    390 0.9333333969116211
    400 0.9666666984558105
    410 0.9666666984558105
    420 0.9666666984558105
    430 0.9333333969116211
    440 0.9333333969116211
    450 0.9666666984558105
    460 1.0
    470 1.0
    480 0.9666666984558105
    490 0.9333333969116211
    500 0.9666666984558105
    510 0.9333333969116211
    520 0.9666666984558105
    530 0.9666666984558105
    540 1.0
    550 0.9666666984558105
    560 0.9333333969116211
    570 1.0
    580 0.9666666984558105
    590 0.9666666984558105
    600 1.0
    610 0.9000000357627869
    620 0.9333333969116211
    630 0.9333333969116211
    640 0.9333333969116211
    650 0.9666666984558105
    660 0.9000000357627869
    670 0.9000000357627869
    680 1.0
    690 0.9333333969116211
    700 0.9666666984558105
    710 0.8000000715255737
    720 0.9333333969116211
    730 0.8666667342185974
    740 0.9333333969116211
    750 0.9666666984558105
    760 1.0
    770 0.9333333969116211
    780 0.9000000357627869
    790 0.9666666984558105
    800 0.9333333969116211
    810 0.8666667342185974
    820 0.9000000357627869
    830 0.9666666984558105
    avg acc: 0.9605116213111283
    test acc: 0.8822142779827118
    0 0.9666666984558105
    10 0.9666666984558105
    20 1.0
    30 0.9666666984558105
    40 1.0
    50 0.9666666984558105
    60 1.0
    70 0.9000000357627869
    80 1.0
    90 0.9666666984558105
    100 0.9333333969116211
    110 1.0
    120 1.0
    130 0.9666666984558105
    140 0.9666666984558105
    150 1.0
    160 0.9666666984558105
    170 0.9333333969116211
    180 0.9666666984558105
    190 0.9333333969116211
    200 0.9666666984558105
    210 1.0
    220 0.9666666984558105
    230 1.0
    240 0.9666666984558105
    250 1.0
    260 0.9333333969116211
    270 0.9666666984558105
    280 0.9000000357627869
    290 1.0
    300 0.9333333969116211
    310 0.9666666984558105
    320 0.9666666984558105
    330 0.9333333969116211
    340 1.0
    350 0.9333333969116211
    360 0.9666666984558105
    370 1.0
    380 1.0
    390 0.9000000357627869
    400 1.0
    410 1.0
    420 1.0
    430 1.0
    440 1.0
    450 0.9666666984558105
    460 0.9000000357627869
    470 1.0
    480 1.0
    490 0.8666667342185974
    500 1.0
    510 1.0
    520 1.0
    530 0.9666666984558105
    540 0.9000000357627869
    550 1.0
    560 0.9333333969116211
    570 0.9666666984558105
    580 1.0
    590 0.9666666984558105
    600 0.9333333969116211
    610 0.9666666984558105
    620 0.9666666984558105
    630 1.0
    640 0.9000000357627869
    650 0.9666666984558105
    660 1.0
    670 0.9000000357627869
    680 0.9333333969116211
    690 1.0
    700 1.0
    710 1.0
    720 0.9666666984558105
    730 1.0
    740 1.0
    750 1.0
    760 1.0
    770 0.8666667342185974
    780 0.9666666984558105
    790 0.9333333969116211
    800 0.9666666984558105
    810 1.0
    820 1.0
    830 0.9666666984558105
    avg acc: 0.9653077817363419
    test acc: 0.8769784666222634
    0 1.0
    10 0.9666666984558105
    20 1.0
    30 0.9333333969116211
    40 1.0
    50 1.0
    60 1.0
    70 1.0
    80 0.9666666984558105
    90 0.9333333969116211
    100 0.9666666984558105
    110 0.9666666984558105
    120 1.0
    130 0.9666666984558105
    140 1.0
    150 1.0
    160 0.9666666984558105
    170 1.0
    180 0.9333333969116211
    190 1.0
    200 0.9666666984558105
    210 1.0
    220 0.8333333730697632
    230 1.0
    240 1.0
    250 0.9666666984558105
    260 0.9666666984558105
    270 0.9000000357627869
    280 0.9666666984558105
    290 0.9333333969116211
    300 0.9666666984558105
    310 0.9666666984558105
    320 0.9333333969116211
    330 1.0
    340 1.0
    350 0.9333333969116211
    360 0.9666666984558105
    370 0.9666666984558105
    380 0.9666666984558105
    390 1.0
    400 0.9333333969116211
    410 0.9333333969116211
    420 1.0
    430 0.9666666984558105
    440 0.9666666984558105
    450 0.9333333969116211
    460 1.0
    470 0.9666666984558105
    480 1.0
    490 1.0
    500 0.9333333969116211
    510 0.9666666984558105
    520 1.0
    530 0.9333333969116211
    540 0.9666666984558105
    550 0.9333333969116211
    560 0.9333333969116211
    570 0.9333333969116211
    580 1.0
    590 1.0
    600 0.9333333969116211
    610 0.9666666984558105
    620 1.0
    630 1.0
    640 1.0
    650 0.9666666984558105
    660 1.0
    670 1.0
    680 1.0
    690 0.9333333969116211
    700 1.0
    710 0.9333333969116211
    720 1.0
    730 1.0
    740 0.9666666984558105
    750 0.9000000357627869
    760 0.9000000357627869
    770 0.9333333969116211
    780 0.9666666984558105
    790 1.0
    800 1.0
    810 0.9666666984558105
    820 0.9666666984558105
    830 1.0
    avg acc: 0.9697442299885144
    test acc: 0.8815348212667506
    0 0.9666666984558105
    10 0.9666666984558105
    20 0.9666666984558105
    30 0.9666666984558105
    40 0.9666666984558105
    50 0.8666667342185974
    60 1.0
    70 1.0
    80 1.0
    90 1.0
    100 1.0
    110 0.9666666984558105
    120 1.0
    130 1.0
    140 1.0
    150 0.9666666984558105
    160 1.0
    170 0.9333333969116211
    180 1.0
    190 0.9000000357627869
    200 1.0
    210 0.8666667342185974
    220 1.0
    230 1.0
    240 1.0
    250 0.9000000357627869
    260 1.0
    270 1.0
    280 0.9666666984558105
    290 0.9666666984558105
    300 0.9666666984558105
    310 0.9666666984558105
    320 0.9666666984558105
    330 1.0
    340 1.0
    350 0.9333333969116211
    360 0.9666666984558105
    370 1.0
    380 0.9666666984558105
    390 1.0
    400 0.9666666984558105
    410 1.0
    420 1.0
    430 1.0
    440 1.0
    450 1.0
    460 1.0
    470 1.0
    480 0.9666666984558105
    490 1.0
    500 1.0
    510 1.0
    520 0.9666666984558105
    530 0.9666666984558105
    540 0.9666666984558105
    550 0.9000000357627869
    560 0.9000000357627869
    570 0.9666666984558105
    580 1.0
    590 0.9666666984558105
    600 1.0
    610 0.9666666984558105
    620 1.0
    630 0.9666666984558105
    640 0.9666666984558105
    650 1.0
    660 0.9666666984558105
    670 1.0
    680 0.9666666984558105
    690 0.9666666984558105
    700 0.9666666984558105
    710 1.0
    720 0.8666667342185974
    730 1.0
    740 0.9666666984558105
    750 0.9333333969116211
    760 1.0
    770 0.9666666984558105
    780 1.0
    790 0.9666666984558105
    800 1.0
    810 1.0
    820 1.0
    830 0.9333333969116211
    avg acc: 0.9726618941453435
    test acc: 0.8754996503714463
    

    6. 预测

    • 输出的预测:是('pos':1, 'neg':0)字符串的编号
    for batch in test_iterator:
        # batch_size个预测
        preds = rnn(batch.text).squeeze(1)
        preds = predice_test(preds)
        # print(preds)
    
        i = 0
        for text in batch.text:
            # 遍历一句话里的每个单词
            for word in text:
                print(TEXT.vocab.itos[word], end=' ')
        
            print('')
            # 输出3句话
            if i == 3:
                break
            i = i + 1
    
        i = 0
        for pred in preds:
            idx = int(pred.item())
            print(idx, LABEL.vocab.itos[idx])
            # 输出3个结果(标签)
            if i == 3:
                break
            i = i + 1
        break  
    
    Anyone <unk> Great A If <unk> Without The Brilliant This <unk> This If This Ten Absolutely For A This One Add a Just This I More What Brilliant Read <unk> 
    who Classic story great you hires a <unk> . movie it is you is minutes fantastic pure touching is of this mesmerizing love is hope suspenseful a and the <unk> 
    gives Waters , film 've a doubt mixed <unk> is with the like quite of ! <unk> movie a the little film the a this , script moving book <unk> 
    this ! great in ever psychopath , with along terrible all greatest <unk> possibly people Whatever vampire . good funniest gem that interplay great group more , performances , interpretation 
    1 pos
    1 pos
    1 pos
    1 pos
    

    可以写成

    import spacy
    nlp = spacy.load('en_core_web_md')
    
    def predict_sentiment(sentence):
        tokenized = [tok.text for tok in nlp.tokenizer(sentence)]
    #     print(tokenized)   # ['This', 'film', 'is', 'terrible']
        indexed = [TEXT.vocab.stoi[t] for t in tokenized]
        tensor = torch.LongTensor(indexed).to(device)
        text = tensor.unsqueeze(0)  
        prediction = torch.sigmoid(rnn(tensor))
        return prediction.item()
    
    predict_sentiment("This film is great")  # 1.0
    

    导入模型,并测试

    model.load_state_dict(torch.load('wordavg-model.pt'))
    test_loss, test_acc = evaluate(rnn, test_iterator, criterion)
    print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
    
  • 相关阅读:
    次小生成树
    [bzoj5329] P4606 [SDOI2018]战略游戏
    CF487E Tourists
    P3225 [HNOI2012]矿场搭建
    CF #636 (Div. 3) 对应题号CF1343
    P3469 [POI2008]BLO-Blockade
    大假期集训模拟赛12
    大假期集训模拟赛11
    大假期集训模拟赛10
    小奇画画——BFS
  • 原文地址:https://www.cnblogs.com/douzujun/p/13374396.html
Copyright © 2011-2022 走看看