PyTorch's LSTM runs at slightly different speeds with batch_first=True and batch_first=False, but the difference is small.
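A minimal sketch of what the flag actually changes, using toy sizes chosen only for illustration. batch_first=True expects input of shape (batch, seq_len, input_size), and batch_first=False expects (seq_len, batch, input_size). If one module's weights are copied into the other and the input is transposed, both compute the same result, so only speed can differ between them.

```python
import torch

torch.manual_seed(0)

# Same hyperparameters; only the expected input layout differs.
lstm_bf = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
lstm_sf = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=False)
# Copy the weights so both modules compute the same function.
lstm_sf.load_state_dict(lstm_bf.state_dict())

x_bf = torch.randn(4, 10, 8)   # (batch, seq_len, input_size)
x_sf = x_bf.transpose(0, 1)    # (seq_len, batch, input_size), a non-contiguous view

with torch.no_grad():
    out_bf, (h_bf, c_bf) = lstm_bf(x_bf)
    out_sf, (h_sf, c_sf) = lstm_sf(x_sf)

print(out_bf.shape)  # torch.Size([4, 10, 16])
print(out_sf.shape)  # torch.Size([10, 4, 16])
# h_n and c_n are (num_layers, batch, hidden_size) regardless of batch_first.
print(h_bf.shape, h_sf.shape)
# The outputs match once the seq-first output is transposed back.
print(torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-6))
```

Note that batch_first does not affect the hidden and cell states, which always put num_layers first and batch second.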
The article below concludes that batch_first=True is faster than batch_first=False. When I ran the test myself, however, I got the opposite result: batch_first=False was faster.
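Results from several runs, in seconds (left column: batch_first=True, right column: batch_first=False, matching the print order in the code below):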
2.3414649963378906 2.0364670753479004
2.188401699066162 2.2298429012298584
2.25323224067688 2.202291488647461
2.2564923763275146 2.1362855434417725
2.3355021476745605 2.1648573875427246
2.367983818054199 2.4390225410461426
2.3107049465179443 2.3457281589508057
2.261659622192383 2.1843318939208984
2.2949719429016113 2.1492083072662354
In most runs (6 of 9) the right column is smaller, i.e. batch_first=False is faster.
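To make "most runs" concrete, here is a small sketch that tallies the timings listed above. It assumes the column order described for the table (batch_first=True first, batch_first=False second).

```python
# (batch_first=True, batch_first=False) timings in seconds, copied from the runs above.
runs = [
    (2.3414649963378906, 2.0364670753479004),
    (2.188401699066162, 2.2298429012298584),
    (2.25323224067688, 2.202291488647461),
    (2.2564923763275146, 2.1362855434417725),
    (2.3355021476745605, 2.1648573875427246),
    (2.367983818054199, 2.4390225410461426),
    (2.3107049465179443, 2.3457281589508057),
    (2.261659622192383, 2.1843318939208984),
    (2.2949719429016113, 2.1492083072662354),
]

# Count runs where batch_first=False took less time.
false_wins = sum(1 for t_true, t_false in runs if t_false < t_true)
mean_true = sum(t for t, _ in runs) / len(runs)
mean_false = sum(t for _, t in runs) / len(runs)

print(f"batch_first=False faster in {false_wins} of {len(runs)} runs")
print(f"mean: True {mean_true:.3f} s, False {mean_false:.3f} s")
```

This reports 6 of 9 runs, with means of about 2.290 s versus 2.210 s. That is a gap of roughly 3.5%, which is small next to the run-to-run spread of about 0.4 s in each column.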
Here are the results from an article on Zhihu:
https://zhuanlan.zhihu.com/p/50484629?from_voters_page=true
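"After actual testing, I found that batch_first=True is faster than batch_first=False. (I'm not sure why PyTorch defaults to batch_first=False, and many places online also claim that batch_first=False performs better.)"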
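import time
import torch

x_1 = torch.randn(100, 200, 512)  # (batch, seq_len, input_size) for batch_first=True
x_2 = x_1.transpose(0, 1)         # (seq_len, batch, input_size) for batch_first=False
model_1 = torch.nn.LSTM(batch_first=True, hidden_size=1024, input_size=512)
model_2 = torch.nn.LSTM(batch_first=False, hidden_size=1024, input_size=512)
start_time_1 = time.time()
result_1 = model_1(x_1)
end_time_1 = time.time()
result_2 = model_2(x_2)
end_time_2 = time.time()
# Prints: batch_first=True time, batch_first=False time (seconds)
print(end_time_1 - start_time_1, end_time_2 - end_time_1)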
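The code above times each model exactly once, records autograd history, and always runs the batch_first=True model first. Any of these could plausibly tilt a single measurement, for example if the first forward call pays a one-time setup cost. A somewhat fairer comparison warms each model up, disables gradients, repeats the forward pass, and takes the median. The sketch below does that. The helper name time_lstm and its parameters are hypothetical. It also makes an assumed design choice: each configuration gets its own contiguous input in the layout it expects, whereas the original feeds model_2 a transposed (non-contiguous) view. It assumes CPU execution. On a GPU you would also need torch.cuda.synchronize() before each clock read, because CUDA kernels run asynchronously.

```python
import statistics
import time

import torch


def time_lstm(batch_first, batch=100, seq_len=200, input_size=512,
              hidden_size=1024, repeats=5, warmup=2):
    """Hypothetical helper: median forward time (s) of one nn.LSTM configuration on CPU."""
    model = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                          batch_first=batch_first)
    # Assumption: a contiguous input in the layout this model expects.
    if batch_first:
        x = torch.randn(batch, seq_len, input_size)
    else:
        x = torch.randn(seq_len, batch, input_size)

    times = []
    with torch.no_grad():  # inference only: no autograd graph is built
        for _ in range(warmup):  # untimed warm-up calls
            model(x)
        for _ in range(repeats):
            start = time.perf_counter()
            model(x)
            times.append(time.perf_counter() - start)
    return statistics.median(times)
```

Usage: call time_lstm(True) and time_lstm(False) several times in alternating order and compare the medians. If the gap stays within the run-to-run spread, the honest conclusion is that neither setting is reliably faster on that machine. For more careful measurements, PyTorch also ships torch.utils.benchmark.Timer.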