class BPNet(nn.Module):
    """Fully connected network with five hidden layers and a linear output layer.

    Each stage is wrapped in its own ``nn.Sequential`` and exposed as
    ``layer1`` .. ``layer6``; ``forward`` applies them in that order.

    Args:
        in_dim: Number of input features per sample.
        n_hidden_1: Output width of ``layer1``.
        n_hidden_2: Output width of ``layer2``.
        n_hidden_3: Output width of ``layer3``.
        n_hidden_4: Output width of ``layer4``.
        n_hidden_5: Output width of ``layer5``.
        out_dim: Number of output features per sample.
    """

    def __init__(self, in_dim, n_hidden_1, n_hidden_2,
                 n_hidden_3, n_hidden_4, n_hidden_5, out_dim):
        super().__init__()
        # NOTE(review): layer1, layer2 and layer5 carry no activation, so they
        # feed the following Linear directly -- confirm this is intentional.
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.BatchNorm1d(n_hidden_2),
        )
        self.layer3 = nn.Sequential(
            nn.Linear(n_hidden_2, n_hidden_3),
            nn.BatchNorm1d(n_hidden_3),
            nn.ReLU(inplace=True),
        )
        # The only stage with dropout (p=0.1).
        self.layer4 = nn.Sequential(
            nn.Linear(n_hidden_3, n_hidden_4),
            nn.BatchNorm1d(n_hidden_4),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )
        self.layer5 = nn.Sequential(
            nn.Linear(n_hidden_4, n_hidden_5),
            nn.BatchNorm1d(n_hidden_5),
        )
        # Output projection; no activation is applied to the result.
        self.layer6 = nn.Sequential(
            nn.Linear(n_hidden_5, out_dim),
        )

    def forward(self, x):
        """Run ``x`` through ``layer1`` .. ``layer6`` and return the result."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        return x
# Instantiate the network.
net = BPNet(
    in_dim=5,
    n_hidden_1=20,
    n_hidden_2=250,
    n_hidden_3=500,
    n_hidden_4=250,
    n_hidden_5=50,
    out_dim=2,
)
# Concise version: build the hidden layers from a size table instead of writing each one out.
# Named layer-size tables: each key maps to a list of layer widths.
cfg = {'1': [20, 200, 500, 200, 50]}
class BPNet(nn.Module):
    """Fully connected network whose hidden layers come from a size table.

    The hidden widths are read from the module-level ``cfg`` mapping under
    ``name``. Every hidden stage is ``Linear -> BatchNorm1d -> ReLU``; a
    single ``Linear`` classifier maps the last hidden width to ``out_dim``.

    Args:
        name: Key into ``cfg`` selecting the list of hidden-layer widths.
        in_dim: Number of input features per sample (default 5).
        out_dim: Number of output features per sample (default 2).

    Raises:
        KeyError: If ``name`` is not a key of ``cfg``.
    """

    def __init__(self, name, in_dim=5, out_dim=2):
        super().__init__()
        hidden_sizes = cfg[name]
        self.features = self._make_layers(hidden_sizes, in_dim)
        # With no hidden layers the classifier consumes the raw input width.
        last_dim = hidden_sizes[-1] if hidden_sizes else in_dim
        self.classifier = nn.Sequential(
            nn.Linear(last_dim, out_dim),
        )

    def forward(self, x):
        """Apply the hidden stack, flatten per sample, then classify."""
        out = self.features(x)
        # Collapse everything after the first dimension into one axis.
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg, in_dim=5):
        """Build the hidden stack for the given list of layer widths.

        Args:
            cfg: Iterable of hidden-layer widths, in order. This parameter
                shadows the module-level ``cfg`` table inside this method.
            in_dim: Input width of the first ``Linear`` (default 5).

        Returns:
            An ``nn.Sequential`` holding ``Linear``, ``BatchNorm1d`` and
            in-place ``ReLU`` for each width, chained so that each stage's
            input width equals the previous stage's output width.
        """
        layers = []
        for width in cfg:
            layers += [
                nn.Linear(in_dim, width),
                nn.BatchNorm1d(width),
                nn.ReLU(inplace=True),
            ]
            in_dim = width
        return nn.Sequential(*layers)
# Build the network from the layer-size table stored under key '1'.
net = BPNet(name='1')