zoukankan      html  css  js  c++  java
  • Transformers 中使用 TorchScript | 四

    作者|huggingface
    编译|VK
    来源|Github

    注意:这是我们使用TorchScript进行实验的开始,我们仍在探索可变输入大小模型的功能。它是我们关注的焦点,我们将在即将发布的版本中加深我们的分析,提供更多代码示例,更灵活的实现以及将基于python的代码与已编译的TorchScript进行基准测试的比较。

    根据Pytorch的文档:“TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法”。Pytorch的两个模块JIT和TRACE允许开发人员导出他们的模型,这些模型可以在其他程序中重用,例如面向效率的C++程序。

    我们提供了一个接口,该接口允许将transformers模型导出到TorchScript,以便他们可在与基于Pytorch的python程序不同的环境中重用。在这里,我们解释如何使用我们的模型,以便可以导出它们,以及将这些模型与TorchScript一起使用时要注意的事项。

    导出模型需要两件事:

    • 虚拟化输入以执行模型正向传播。
    • 需要使用torchscript标志实例化该模型。

    这些必要性意味着开发人员应注意几件事。这些在下面详细说明。

    含义

    TorchScript标志和解绑权重

    该标志是必需的,因为该存储库中的大多数语言模型都在它们的Embedding层及其Decoding层具有绑定权重关系。TorchScript不允许导出绑定权重的模型,因此,有必要事先解绑权重。

    这意味着以torchscript标志实例化的模型使得Embedding层和Decoding层分开,这意味着不应该对他们进行同时训练,导致意外的结果。

    对于没有语言模型头(Language Model head)的模型,情况并非如此,因为那些模型没有绑定权重。这些型号可以在没有torchscript标志的情况下安全地导出。

    虚拟(dummy)输入和标准长度

    虚拟输入用于进行模型前向传播。当输入值在各层中传播时,Pytorch跟踪在每个张量上执行的不同操作。然后使用这些记录的操作创建模型的“迹"。

    迹是相对于输入的尺寸创建的。因此,它受到虚拟输入尺寸的限制,并且不适用于任何其他序列长度或批次大小。尝试使用其他尺寸时,会出现如下错误,如:

    The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2

    因此,建议使用至少与最大输入大小相同的虚拟输入大小来跟踪模型。在推理期间对于模型的输入,可以执行填充来填充缺少的值。作为模型
    将以较大的输入大小来进行跟踪,但是,不同矩阵的尺寸也将很大,从而导致更多的计算。

    建议注意每个输入上完成的操作总数,并密切关注各种序列长度对应性能的变化。

    在Python中使用TorchScript

    以下是使用Python保存,加载模型以及如何使用"迹"进行推理的示例。

    保存模型

    该代码段显示了如何使用TorchScript导出BertModel。在这里实例化BertModel,根据BertConfig类,然后以文件名traced_bert.pt保存到磁盘

    from transformers import BertModel, BertTokenizer, BertConfig
    import torch
    
    enc = BertTokenizer.from_pretrained("bert-base-uncased")
    
    # 标记输入文本
    text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
    tokenized_text = enc.tokenize(text)
    
    # 输入标记之一进行掩码
    masked_index = 8
    tokenized_text[masked_index] = '[MASK]'
    indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
    segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
    
    # 创建虚拟输入
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    dummy_input = [tokens_tensor, segments_tensors]
    
    # 使用torchscript标志初始化模型
    # 标志被设置为True,即使没有必要,因为该型号没有LM Head。
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, torchscript=True)
    
    # 实例化模型
    model = BertModel(config)
    
    # 模型设置为评估模式
    model.eval()
    
    # 如果您要​​使用from_pretrained实例化模型,则还可以设置TorchScript标志
    model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
    
    # 创建迹
    traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
    torch.jit.save(traced_model, "traced_bert.pt")
    

    载入模型

    该代码段显示了如何加载以前以名称traced_bert.pt保存到磁盘的BertModel
    我们重新使用之前初始化的dummy_input

    loaded_model = torch.jit.load("traced_model.pt")
    loaded_model.eval()
    
    all_encoder_layers, pooled_output = loaded_model(dummy_input)
    

    使用跟踪模型进行推理

    使用跟踪模型进行推理就像使用其__call__ 方法一样简单:

    traced_model(tokens_tensor, segments_tensors)
    

    原文链接:https://huggingface.co/transformers/torchscript.html

    欢迎关注磐创AI博客站:
    http://panchuang.net/

    OpenCV中文官方文档:
    http://woshicver.com/

    欢迎关注磐创博客资源汇总站:
    http://docs.panchuang.net/

  • 相关阅读:
    TypesScript+Webpack
    TypeScript 类型
    git操作
    kafka
    java: cannot find symbol symbol: variable log
    Angular结构型指令,模块和样式
    Angular 自定义拖拽指令
    Angular changeDetction
    Angular 依赖注入
    RXJS Observable的冷,热和Subject
  • 原文地址:https://www.cnblogs.com/panchuangai/p/12567841.html
Copyright © 2011-2022 走看看