Related tools:
1. torchsummary: print the output shape of each layer of a torch model
sksq96/pytorch-summary: Model summary in PyTorch similar to `model.summary()` in Keras (github.com)
How to install
pip install torchsummary
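How to use
from torchsummary import summary
# model: an nn.Module that takes 1x28x28 inputs; input_size omits the batch dimension.
# summary() defaults to device="cuda"; pass device="cpu" for a model on the CPU.
summary(model, (1, 28, 28))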
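For each layer, torchsummary prints the output shape and the parameter count. As a minimal sketch of where those numbers come from for a plain Conv2d layer, the helpers below apply the standard PyTorch convolution arithmetic (padding, stride, dilation) by hand. The helper names and the example layer are hypothetical, chosen for illustration only.

```python
def conv2d_output_hw(h, w, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a 2-D convolution with a square kernel (nn.Conv2d shape formula)."""
    out_h = (h + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    out_w = (w + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    return out_h, out_w


def conv2d_params(in_ch, out_ch, kernel, groups=1, bias=True):
    """Learnable parameters: weights plus an optional per-output-channel bias."""
    weights = out_ch * (in_ch // groups) * kernel * kernel
    return weights + (out_ch if bias else 0)


# Hypothetical first layer of a small MNIST CNN: Conv2d(1, 10, kernel_size=5) on a 1x28x28 input.
# torchsummary writes the batch dimension as -1, so this layer shows up as [-1, 10, 24, 24].
h, w = conv2d_output_hw(28, 28, kernel=5)
print([-1, 10, h, w], conv2d_params(1, 10, kernel=5))
```

Running this prints `[-1, 10, 24, 24] 260`, which is the shape and parameter count such a layer would be listed with.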
2. THOP: count the FLOPs and number of parameters of a PyTorch model
Lyken17/pytorch-OpCounter: Count the MACs / FLOPs of your PyTorch model. (github.com)
How to install
pip install thop (now continuously integrated on GitHub Actions)
OR
pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git
How to use
- Basic usage
import torch
from torchvision.models import resnet50
from thop import profile
model = resnet50()
dummy_input = torch.randn(1, 3, 224, 224)  # one 3x224x224 image (batch size 1)
macs, params = profile(model, inputs=(dummy_input, ))
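`profile` returns MACs (multiply-accumulate operations) rather than FLOPs; a common rough conversion is FLOPs ≈ 2 × MACs. As a minimal sketch of the usual MAC convention for a convolution (one multiply-add per weight per output element), the hypothetical helper below reproduces that arithmetic for ResNet-50's first convolution. It leaves out bias additions, since tools and versions differ on whether they count them, so it is not guaranteed to match THOP's numbers exactly.

```python
def conv2d_macs(in_ch, out_ch, kernel, out_h, out_w, groups=1, batch=1):
    """MACs of a 2-D convolution: each output element needs (in_ch / groups) * k * k multiply-adds.

    Bias additions are ignored in this sketch.
    """
    output_elements = batch * out_ch * out_h * out_w
    return output_elements * (in_ch // groups) * kernel * kernel


# ResNet-50 conv1: Conv2d(3, 64, kernel_size=7, stride=2, padding=3) maps 224x224 to 112x112.
macs = conv2d_macs(3, 64, kernel=7, out_h=112, out_w=112)
print(macs, "MACs, roughly", 2 * macs, "FLOPs")
```

Setting `groups` equal to the channel count gives the depthwise case, where each output element only sees `k * k` inputs.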
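- Define the rule for a 3rd-party module
import torch
import torch.nn as nn
from thop import profile
class YourModule(nn.Module):  # your definition
    def forward(self, x):
        return torch.relu(x)
def count_your_model(model, x, y):
    # your rule here: x is the tuple of forward inputs, y is the output.
    # Illustrative rule: add one op per output element to the module's total_ops buffer.
    model.total_ops += torch.DoubleTensor([int(y.numel())])
model = nn.Sequential(nn.Conv2d(3, 16, 3), YourModule())
dummy_input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(dummy_input, ), custom_ops={YourModule: count_your_model})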
- Improve the output readability
Call thop.clever_format to present the output in a more readable form:
from thop import clever_format
macs, params = clever_format([macs, params], "%.3f")
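`clever_format` turns raw counts into short strings with unit suffixes. The helper below is a hypothetical stand-in written for illustration, not THOP's source. It assumes decimal prefixes (K = 10^3, M = 10^6, G = 10^9, T = 10^12) and accepts a printf-style format string like the `"%.3f"` used above.

```python
def human_format(nums, fmt="%.2f"):
    """Format counts with decimal unit suffixes, e.g. 4.1e9 -> '4.10G' (hypothetical helper)."""
    suffixes = [(1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "K")]
    out = []
    for num in nums:
        for scale, suffix in suffixes:
            if num >= scale:
                out.append(fmt % (num / scale) + suffix)
                break
        else:
            # Below 1,000: print the number without a suffix.
            out.append(fmt % num)
    return out


print(human_format([118013952, 25557032], "%.3f"))  # ['118.014M', '25.557M']
```

Checking the largest scale first keeps each value in the biggest unit it reaches, so 4.1e9 becomes `4.10G` rather than `4100.00M`.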
3. ptflops: Flops counter for convolutional networks in the PyTorch framework
How to install
From PyPI:
pip install ptflops
From the GitHub repository:
pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git
How to use
import torch
import torchvision.models as models
from ptflops import get_model_complexity_info
with torch.cuda.device(0):  # assumes a CUDA GPU; drop this line (and dedent) on a CPU-only machine
    net = models.densenet161()
    # (3, 224, 224) excludes the batch dimension; as_strings=True returns formatted strings
    macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
                                             print_per_layer_stat=True, verbose=True)
    print('{:<30} {:<8}'.format('Computational complexity: ', macs))
    print('{:<30} {:<8}'.format('Number of parameters: ', params))
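With `print_per_layer_stat=True`, ptflops reports each layer's parameters and MACs along with its share of the model total. The sketch below reproduces that kind of breakdown by hand for a small fully connected network. The layer sizes and helper names are made up for illustration, and it assumes the simplified convention that `nn.Linear(in, out)` costs `in * out` MACs per sample (bias additions ignored, which ptflops may count) and has `in * out + out` parameters.

```python
def linear_macs(in_features, out_features, batch=1):
    """MACs of a fully connected layer: one multiply-add per weight per sample (bias ignored)."""
    return batch * in_features * out_features


def linear_params(in_features, out_features, bias=True):
    """Weights plus an optional bias vector."""
    return in_features * out_features + (out_features if bias else 0)


# Hypothetical MLP for flattened 28x28 inputs: 784 -> 256 -> 128 -> 10.
layers = [(784, 256), (256, 128), (128, 10)]
total_macs = sum(linear_macs(i, o) for i, o in layers)
total_params = sum(linear_params(i, o) for i, o in layers)

# Per-layer breakdown with each layer's percentage of the totals.
for i, o in layers:
    m, p = linear_macs(i, o), linear_params(i, o)
    print(f"Linear({i}, {o}): {p} params ({100 * p / total_params:.2f}%), "
          f"{m} MACs ({100 * m / total_macs:.2f}%)")
print("Total:", total_macs, "MACs,", total_params, "params")
```

The breakdown makes it easy to see that the first layer dominates both counts, which is the kind of hotspot the per-layer report is meant to reveal.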