关于该类:
torch.nn.Linear(in_features, out_features, bias=True)
可以对输入数据进行线性变换:
$y = x A^T + b$
in_features: 输入数据的大小。
out_features: 输出数据的大小。
bias: 是否添加一个可学习的 bias,即上式中的 $b$。
该线性变换,只对输入的 tensor 的最后一维进行:
例如我们有一个Linear层如下:
m = nn.Linear(20, 30)
示例1:
input = torch.randn(2, 5, 8, 20) output = m(input) print(output.size())
结果:
torch.Size([2, 5, 8, 30])
示例2:
input = torch.randn(20) output = m(input) print(output.size())
结果:
torch.Size([30])