1 torch.nn.Linear(in_features, out_features, bias=True)
- 用于设置网络中的全连接层。
- 输入与输出均是二维张量。
- 输入与输出形状是 [batch_size, size]。
- 之前一般使用view或nn.Flatten()将4维张量变为2维张量。
Parameters
1 in_features – size of each input sample 2 out_features – size of each output sample 3 bias – If set to False, the layer will not learn an additive bias. Default: True
Shape
Applies a linear transformation to the incoming data: y = xA^T + b
Examples
[32, 512] ——> [32, 128]
1 import torch 2 import torch.nn as nn 3 4 input = torch.randn(32, 512) 5 linear = nn.Linear(512, 128) 6 print(linear(input).size())
[32, 128] ——> [32, 32]
1 import torch 2 import torch.nn as nn 3 4 input = torch.randn(32, 128) 5 linear = nn.Linear(128, 32) 6 print(linear(input).size())