from torchvision.models import resnet18 from utils import Flatten trained_model = resnet18(pretrained = True) model = nn.Sequential(*list(trained_model.children())[:-1]), #[b,512,1,1] Flatten(), #[b,512,1,1] ===>[b,512*1*1] nn.Linear(512,5) ) x = torch.randn(2,3,224,224) print(model(x).shape)
pytorch使用torch.nn.Sequential快速搭建神经网络
Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小。保留第一维的维度,其他相乘为一个数. 链接
pytorch函数之nn.Linear
class torch.nn.Linear(in_features,out_features,bias = True )[来源]
对传入数据应用线性变换:y = A x+ b
参数:
in_features - 每个输入样本的大小
out_features - 每个输出样本的大小
bias - 如果设置为False,则图层不会学习附加偏差。默认值:True
代码:
m = nn.Linear(20, 30) input = autograd.Variable(torch.randn(128, 20)) output = m(input) print(output.size())
输出:
torch.Size([128, 30])
output.size()=矩阵size(128,20)*矩阵size(20,30)=(128,30)