  • simpletransformers: a library for building Transformer models quickly and easily

    Project address: https://hub.fastgit.org/ThilinaRajapakse/simpletransformers

    1. Create a virtual environment. Note: the project README does not say that Python 3.7 is required, but any other version produced errors.
    conda create -n st python==3.7
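    Once the environment is activated, the interpreter version can be confirmed before installing anything else. The snippet below is a minimal sketch; the helper name `is_required_python` and the constant `REQUIRED` are illustrative and are not part of simpletransformers.

```python
import sys
from typing import Sequence, Tuple

# The version this post reports as working; an assumption taken from step 1.
REQUIRED: Tuple[int, int] = (3, 7)


def is_required_python(
    version_info: Sequence[int] = sys.version_info,
    required: Tuple[int, int] = REQUIRED,
) -> bool:
    """Return True when the major.minor version matches the required one."""
    return tuple(version_info[:2]) == tuple(required)


if __name__ == "__main__":
    current = "{}.{}".format(sys.version_info[0], sys.version_info[1])
    if is_required_python():
        print("Python " + current + " matches the version used in this post.")
    else:
        print("Python " + current + " differs from 3.7; errors may occur.")
```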
    
    2. Install the GPU build of PyTorch
    pip3 install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
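    After the install finishes, it is worth checking that PyTorch actually sees the GPU, since the demo below passes `use_cuda=True`. This is a hedged sketch: `describe_torch` is a hypothetical helper, and it returns a report with default values instead of failing when torch is not importable.

```python
import importlib.util


def describe_torch() -> dict:
    """Report whether torch is importable and whether CUDA is usable."""
    info = {
        "installed": False,
        "version": None,
        "cuda_build": None,
        "cuda_available": False,
    }
    # find_spec returns None when no top-level module named torch exists.
    if importlib.util.find_spec("torch") is None:
        return info
    try:
        import torch
    except Exception:
        # A broken install is reported the same way as a missing one.
        return info
    info["installed"] = True
    info["version"] = torch.__version__
    info["cuda_build"] = torch.version.cuda  # None on CPU-only builds
    info["cuda_available"] = bool(torch.cuda.is_available())
    return info


if __name__ == "__main__":
    print(describe_torch())
```

    If `cuda_available` comes back `False`, one option is to pass that value as `use_cuda` when constructing the model so the demo runs on the CPU instead of raising an error.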
    
    3. Install the missing packages one by one
    pip install pandas
    pip install tqdm
    pip install simpletransformers
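    To see which packages are still missing without waiting for an `ImportError` in the middle of the demo, something like the following can be run first. It is only a sketch: `missing_modules` and `REQUIRED_MODULES` are names made up for this example, and the list simply mirrors the packages installed above.

```python
import importlib.util
from typing import Iterable, List

# Top-level modules the demo imports directly or indirectly (assumed list).
REQUIRED_MODULES = ("torch", "pandas", "tqdm", "simpletransformers")


def missing_modules(names: Iterable[str] = REQUIRED_MODULES) -> List[str]:
    """Return the names that cannot be located on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Still missing: " + ", ".join(missing))
    else:
        print("All required modules are importable.")
```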
    
    4. Run the demo
    import logging
    
    import pandas as pd
    from simpletransformers.seq2seq import (
        Seq2SeqModel,
        Seq2SeqArgs,
    )
    
    
    logging.basicConfig(level=logging.INFO)
    transformers_logger = logging.getLogger("transformers")
    transformers_logger.setLevel(logging.WARNING)
    
    train_data = [
        [
            "Perseus “Percy” Jackson is the main protagonist and the narrator of the Percy Jackson and the Olympians series.",
            "Percy is the protagonist of Percy Jackson and the Olympians",
        ],
        [
            "Annabeth Chase is one of the main protagonists in Percy Jackson and the Olympians.",
            "Annabeth is a protagonist in Percy Jackson and the Olympians.",
        ],
    ]
    
    train_df = pd.DataFrame(
        train_data, columns=["input_text", "target_text"]
    )
    
    eval_data = [
        [
            "Grover Underwood is a satyr and the Lord of the Wild. He is the satyr who found the demigods Thalia Grace, Nico and Bianca di Angelo, Percy Jackson, Annabeth Chase, and Luke Castellan.",
            "Grover is a satyr who found many important demigods.",
        ],
        [
            "Thalia Grace is the daughter of Zeus, sister of Jason Grace. After several years as a pine tree on Half-Blood Hill, she got a new job leading the Hunters of Artemis.",
            "Thalia is the daughter of Zeus and leader of the Hunters of Artemis.",
        ],
    ]
    
    eval_df = pd.DataFrame(
        eval_data, columns=["input_text", "target_text"]
    )
    
    model_args = Seq2SeqArgs()
    model_args.num_train_epochs = 10
    model_args.no_save = True
    model_args.evaluate_generated_text = True
    model_args.evaluate_during_training = True
    model_args.evaluate_during_training_verbose = True
    
    # Initialize model
    model = Seq2SeqModel(
        encoder_decoder_type="bart",
        encoder_decoder_name="facebook/bart-large",
        args=model_args,
        use_cuda=True,
    )
    
    
    def count_matches(labels, preds):
        print(labels)
        print(preds)
        return sum(
            [
                1 if label == pred else 0
                for label, pred in zip(labels, preds)
            ]
        )
    
    
    # Train the model
    model.train_model(
        train_df, eval_data=eval_df, matches=count_matches
    )
    
    # Evaluate the model
    results = model.eval_model(eval_df)
    
    # Use the model for prediction
    print(
        model.predict(
            [
                "Tyson is a Cyclops, a son of Poseidon, and Percy Jackson’s half brother. He is the current general of the Cyclopes army."
            ]
        )
    )
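    The `matches=count_matches` argument in the demo is a user-defined metric that receives the reference texts and the generated texts. Because an exact string comparison is very strict for generated summaries, a softer metric can be useful next to it. The two functions below are a stdlib-only sketch with the same `(labels, preds)` shape; `exact_match_ratio` and `token_f1` are hypothetical names, not functions provided by simpletransformers.

```python
from collections import Counter
from typing import List


def exact_match_ratio(labels: List[str], preds: List[str]) -> float:
    """Fraction of predictions equal to their reference after trimming whitespace."""
    if not labels:
        return 0.0
    hits = sum(
        1 for label, pred in zip(labels, preds) if label.strip() == pred.strip()
    )
    return hits / len(labels)


def token_f1(labels: List[str], preds: List[str]) -> float:
    """Mean token-overlap F1 over lowercased, whitespace-split tokens."""
    scores = []
    for label, pred in zip(labels, preds):
        ref = Counter(label.lower().split())
        hyp = Counter(pred.lower().split())
        # Multiset intersection counts each shared token at most min(ref, hyp) times.
        overlap = sum((ref & hyp).values())
        if overlap == 0:
            scores.append(0.0)
            continue
        precision = overlap / sum(hyp.values())
        recall = overlap / sum(ref.values())
        scores.append(2 * precision * recall / (precision + recall))
    return sum(scores) / len(scores) if scores else 0.0
```

    Assuming additional keyword metrics are handled the same way as `matches`, these could be passed as, for example, `model.eval_model(eval_df, exact=exact_match_ratio, f1=token_f1)`.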
    
    5. Results
  • Original article: https://www.cnblogs.com/mengxiaoleng/p/14562654.html