import torch
from thop import profile
from torch import nn
class Test(nn.Module):
    """Minimal model consisting of a single ``nn.Linear`` layer.

    Used in this script as a small, easily hand-checked target for
    ``thop.profile``.

    Args:
        input_size: Number of input features (size of the last dimension of
            the tensor passed to ``forward``).
        output_szie: Number of output features. The misspelled name is kept
            on purpose so existing keyword-argument callers keep working.
    """

    def __init__(self, input_size: int, output_szie: int) -> None:
        super().__init__()
        # The attribute name ``out`` determines the state_dict keys
        # ("out.weight", "out.bias"); keep it stable.
        self.out = nn.Linear(input_size, output_szie)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return the result."""
        return self.out(x)
t = Test(10, 2)
x = torch.randn(4, 10)
# thop.profile returns (ops, params). thop counts multiply-accumulates (MACs),
# not FLOPs, and the op count scales with the batch size of ``x``:
#   MACs   = in_features * output elements = 10 * (4 * 2) = 80
#   params = weight + bias                 = 10 * 2 + 2   = 22
macs, params = profile(t, (x,), verbose=False)
print(f"MACs: {macs}, params: {params}")
# To aggregate over several models, accumulate the bound results, e.g.:
#   total_flops += macs
#   model_params_num += params