zoukankan      html  css  js  c++  java
  • Transformers 词汇表 | 二

    作者|huggingface
    编译|VK
    来源|Github

    词汇表每种模型都不同,但与其他模型相似。因此,大多数模型使用相同的输入,此处将在用法示例中进行详细说明。

    输入ID

    输入id通常是传递给模型作为输入的唯一必需参数。它们是标记索引,标记的数字表示构建将被模型用作输入的序列。

    每个tokenizer的工作方式不同,但基本机制保持不变。这是一个使用BERTtokenizer(WordPiecetokenizer)的示例:

    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    
    sequence = "A Titan RTX has 24GB of VRAM"
    

    tokenizer负责将序列拆分为tokenizer词汇表中可用的标记。

    #继续上一个脚本
    tokenized_sequence = tokenizer.tokenize(sequence)
    assert tokenized_sequence == ['A', 'Titan', 'R', '##T', '##X', 'has', '24', '##GB', 'of', 'V', '##RA', '##M']
    

    然后可以将这些标记转换为模型可以理解的ID。有几种方法可以使用,推荐使用的是encodeencode_plus,它们实现了最佳性能。

    #继续上一个脚本
    encode_sequence = tokenizer.encode(sequence)
    assert encoded_sequence == [101, 138, 18696, 155, 1942, 3190, 1144, 1572, 13745, 1104, 159, 9664, 2107, 102]
    

    encodeencode_plus方法自动添加“特殊标记”,这是模型使用的特殊ID。

    注意力掩码

    注意掩码是将序列批处理在一起时使用的可选参数。此参数向模型指示应该注意哪些标记,哪些不应该注意。

    例如,考虑以下两个序列:

    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    
    sequence_a = "This is a short sequence."
    sequence_b = "This is a rather long sequence. It is at least longer than the sequence A."
    
    encoded_sequence_a = tokenizer.encode(sequence_a)
    assert len(encoded_sequence_a) == 8
    
    encoded_sequence_b = tokenizer.encode(sequence_b)
    assert len(encoded_sequence_b) == 19
    

    这两个序列的长度不同,因此不能按原样放在同一张量中。需要将第一个序列填充到第二个序列的长度,或者将第二个序列截短到第一个序列的长度。

    在第一种情况下,ID列表将通过填充索引扩展:

    #继续上一个脚本
    padded_sequence_a = tokenizer.encode(sequence_a, max_length=19, pad_to_max_length=True)
    
    assert padded_sequence_a == [101, 1188, 1110, 170, 1603, 4954,  119, 102,    0,    0,    0,    0,    0,    0,    0,    0,   0,   0,   0]
    assert encoded_sequence_b == [101, 1188, 1110, 170, 1897, 1263, 4954, 119, 1135, 1110, 1120, 1655, 2039, 1190, 1103, 4954, 138, 119, 102]
    

    然后可以将它们转换为PyTorch或TensorFlow中的张量。注意掩码是一个二进制张量,指示填充索引的位置,以便模型不会注意它们。对于BertTokenizer,1表示应注意的值,而0表示填充值。

    方法encode_plus()可用于直接获取注意力掩码:

    #继续上一个脚本
    sequence_a_dict = tokenizer.encode_plus(sequence_a, max_length=19, pad_to_max_length=True)
    
    assert sequence_a_dict['input_ids'] == [101, 1188, 1110, 170, 1603, 4954, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    assert sequence_a_dict['attention_mask'] == [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    

    标记类型ID

    一些模型的目的是进行序列分类或问题解答。这些要求将两个不同的序列编码在相同的输入ID中。它们通常由特殊标记分隔,例如分类器标记和分隔符标记。例如,BERT模型按如下方式构建其两个序列输入:

    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    
    # [CLS] SEQ_A [SEP] SEQ_B [SEP]
    
    sequence_a = "HuggingFace is based in NYC"
    sequence_b = "Where is HuggingFace based?"
    
    encoded_sequence = tokenizer.encode(sequence_a, sequence_b)
    assert tokenizer.decode(encoded_sequence) == "[CLS] HuggingFace is based in NYC [SEP] Where is HuggingFace based? [SEP]"
    

    对于某些模型而言,这足以了解一个序列在何处终止以及另一序列在何处开始。但是,其他模型(例如BERT)具有附加机制,即段ID。标记类型ID是一个二进制掩码,用于标识模型中的不同序列。

    我们可以利用encode_plus()为我们输出标记类型ID:

    #继续上一个脚本
    encoded_dict = tokenizer.encode_plus(sequence_a, sequence_b)
    
    assert encoded_dict['input_ids'] == [101, 20164, 10932, 2271, 7954, 1110, 1359, 1107, 17520, 102, 2777, 1110, 20164, 10932, 2271, 7954, 1359, 136, 102]
    assert encoded_dict['token_type_ids'] == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    

    第一个序列,即用于问题的“上下文”,其所有标记均由0表示,而问题的所有标记均由1表示。某些模型(例如XLNetModel)使用由2表示的附加标记。

    位置ID

    模型使用位置ID来识别哪个标记在哪个位置。与将每个标记的位置嵌入其中的RNN相反,转换器不知道每个标记的位置。为此创建了位置ID。

    它们是可选参数。如果没有位置ID传递给模型,则它们将自动创建为绝对位置嵌入。

    [0, config.max_position_embeddings - 1]范围内选择绝对位置嵌入。一些模型使用其他类型的位置嵌入,例如正弦位置嵌入或相对位置嵌入。

    欢迎关注磐创博客资源汇总站:
    http://docs.panchuang.net/

    欢迎关注PyTorch官方中文教程站:
    http://pytorch.panchuang.net/

    OpenCV中文官方文档:
    http://woshicver.com/

  • 相关阅读:
    Python3.7 练习题(-) 如何使用Python生成200个优惠卷(激活码)
    Could not find a version that satisfies the requirement PIL
    python中如何对待易过期的cookies
    python代码在linux服务器一般的开头
    mysql innodb引擎 一次线上死锁分析排查步骤
    centos 6.5 gogs迁移外部仓库报错
    mysql 存儲emjoy表情是報錯Incorrect string value:
    python开发技巧---列表、字典、集合值的过滤
    zabbix学习-如何部署一个agent客户端
    zabbix学习-zabbix安装
  • 原文地址:https://www.cnblogs.com/panchuangai/p/12567845.html
Copyright © 2011-2022 走看看