zoukankan      html  css  js  c++  java
  • pytorch 的LSTM batch_first=True 和 False的性能对比

    pytorch 的LSTM batch_first=True 和 False的性能略有区别,不过区别不大。

    下面这篇文章试验结论是batch_first= True要比batch_first = False更快。但是我自己跑结论却是相反,batch_first = False更快。

    运行多次的结果:

    2.3414649963378906    2.0364670753479004

    2.188401699066162     2.2298429012298584

    2.25323224067688       2.202291488647461

    2.2564923763275146   2.1362855434417725

    2.3355021476745605   2.1648573875427246

    2.367983818054199     2.4390225410461426

    2.3107049465179443   2.3457281589508057

    2.261659622192383     2.1843318939208984

    2.2949719429016113   2.1492083072662354

    看到大部分情况后者更快(batch_first = False更快)。

    下面是知乎上一篇文章的结果:

    https://zhuanlan.zhihu.com/p/50484629?from_voters_page=true

    经过实测,发现batch_first= True要比batch_first = False更快(不知道为啥pytorch要默认是batchfirst= False,同时网上很多地方都在说batch_first= False性能更好)

    x_1 = torch.randn(100,200,512)
    x_2 = x_1.transpose(0,1)

    model_1 = torch.nn.LSTM(batch_first=True,hidden_size=1024,input_size=512)
    model_2 = torch.nn.LSTM(batch_first=False,hidden_size=1024,input_size=512)

    start_time_1 = time.time()


    result_1 = model_1(x_1)
    end_time_1 = time.time()

    result_2 = model_2(x_2)
    end_time_2 = time.time()

    print(end_time_1 - start_time_1,end_time_2 - end_time_1)

  • 相关阅读:
    有趣的放大镜
    特效代码
    向数据库添加学生信息。存放在REQUEST对象里
    机房servlet过滤器
    冒泡排序法
    验证码 随机生成器 详解
    生成器 种子
    生日
    在字符串里寻找某字符出现的个数
    课堂随笔
  • 原文地址:https://www.cnblogs.com/jiangkejie/p/13376281.html
Copyright © 2011-2022 走看看