torch.nn.Flatten(start_dim=1, end_dim=-1)
Parameters
start_dim – first dim to flatten (default = 1).
end_dim – last dim to flatten (default = -1).
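A minimal sketch of how the two parameters change the result, assuming a small 4-D tensor of shape (2, 3, 4, 5) chosen only for illustration. The variable names are hypothetical; only the resulting shapes matter.

```python
import torch
import torch.nn as nn

# Illustrative input: batch of 2, with trailing dims 3, 4 and 5.
x = torch.randn(2, 3, 4, 5)

# Defaults (start_dim=1, end_dim=-1): keep dim 0, merge dims 1..3 (3*4*5 = 60).
default_shape = tuple(nn.Flatten()(x).shape)

# Merge only dims 1 and 2 (3*4 = 12); the last dim is left untouched.
partial_shape = tuple(nn.Flatten(start_dim=1, end_dim=2)(x).shape)

# start_dim=0 also merges the batch dim (2*3*4*5 = 120).
full_shape = tuple(nn.Flatten(start_dim=0)(x).shape)

print(default_shape, partial_shape, full_shape)
```

The default start_dim=1 is what keeps the batch dimension intact, which is usually what a layer placed between convolutional and linear layers needs.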
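Shape
Input: (*, S_start, ..., S_i, ..., S_end, *), where S_i is the size at dimension i and * means any number of dimensions, including none.
Output: (*, S_start * ... * S_end, *), i.e. the dims from start_dim to end_dim are replaced by a single dim holding their product.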
Examples
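With the default arguments, an input of shape [32, 64, 56, 56] is flattened to [32, 200704]: the batch dim is kept and the remaining dims are merged, since 64 * 56 * 56 = 200704.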
import torch
import torch.nn as nn

# Batch of 32 feature maps, each 64 x 56 x 56
x = torch.randn(32, 64, 56, 56)
flatten = nn.Flatten()  # defaults: start_dim=1, end_dim=-1
print(flatten(x).size())  # torch.Size([32, 200704])
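For comparison, the sketch below checks the module against two other common ways of flattening: the functional torch.flatten and a manual view. It assumes the same input shape as the example; the variable names are illustrative only. Note that torch.flatten defaults to start_dim=0, so it has to be given start_dim=1 to match nn.Flatten.

```python
import torch
import torch.nn as nn

x = torch.randn(32, 64, 56, 56)

module_out = nn.Flatten()(x)               # module form, start_dim=1 by default
func_out = torch.flatten(x, start_dim=1)   # functional form, start_dim set explicitly
view_out = x.view(x.size(0), -1)           # manual reshape that keeps the batch dim

# All three produce the same [32, 200704] tensor.
same_as_functional = torch.equal(module_out, func_out)
same_as_view = torch.equal(module_out, view_out)

# With its own default (start_dim=0), torch.flatten merges the batch dim as well.
functional_default_shape = tuple(torch.flatten(x).shape)

print(same_as_functional, same_as_view, functional_default_shape)
```

The module form is mainly convenient inside nn.Sequential, where a layer object is required; in a hand-written forward method, the functional or view form does the same job.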