zoukankan      html  css  js  c++  java
  • PyTorch 自定义创建 BP 神经网络

    class BPNet(nn.Module):
        """Six-stage fully-connected (BP) network with explicitly given widths.

        Each stage is an ``nn.Sequential`` registered as ``layer1`` .. ``layer6``:

        * layer1: Linear(in_dim -> n_hidden_1)
        * layer2: Linear(n_hidden_1 -> n_hidden_2) -> BatchNorm1d
        * layer3: Linear(n_hidden_2 -> n_hidden_3) -> BatchNorm1d -> ReLU
        * layer4: Linear(n_hidden_3 -> n_hidden_4) -> BatchNorm1d -> ReLU -> Dropout
        * layer5: Linear(n_hidden_4 -> n_hidden_5) -> BatchNorm1d
        * layer6: Linear(n_hidden_5 -> out_dim)

        Args:
            in_dim: number of input features.
            n_hidden_1, n_hidden_2, n_hidden_3, n_hidden_4, n_hidden_5:
                widths of the five hidden stages.
            out_dim: number of outputs produced by ``layer6``.
            dropout: drop probability of the Dropout in ``layer4``
                (defaults to 0.1, the previously hard-coded value).
        """

        def __init__(self, in_dim, n_hidden_1, n_hidden_2,
                     n_hidden_3, n_hidden_4, n_hidden_5, out_dim, dropout=0.1):
            # Zero-argument super() binds to *this* class via __class__. The old
            # super(BPNet, self) looked the name up globally at call time, and this
            # file rebinds BPNet to a second class later on, which made
            # constructing this class afterwards raise TypeError.
            super().__init__()
            # NOTE(review): layer1, layer2 and layer5 end without an activation, so
            # layer1 -> layer2 and layer5 -> layer6 stack linear maps with no
            # non-linearity in between. Kept as-is to preserve the architecture;
            # confirm this is intended.
            self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1))
            self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2))
            self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, n_hidden_3), nn.BatchNorm1d(n_hidden_3), nn.ReLU(True))
            self.layer4 = nn.Sequential(nn.Linear(n_hidden_3, n_hidden_4), nn.BatchNorm1d(n_hidden_4), nn.ReLU(True), nn.Dropout(dropout))
            self.layer5 = nn.Sequential(nn.Linear(n_hidden_4, n_hidden_5), nn.BatchNorm1d(n_hidden_5))
            self.layer6 = nn.Sequential(nn.Linear(n_hidden_5, out_dim))

        def forward(self, x):
            """Run ``x`` through layer1 .. layer6 in order and return the output."""
            # assumes x is (batch, in_dim) -- TODO confirm with callers
            for stage in (self.layer1, self.layer2, self.layer3,
                          self.layer4, self.layer5, self.layer6):
                x = stage(x)
            return x
    # Instantiate the explicit-width network: 5 inputs -> 20/250/500/250/50 hidden -> 2 outputs.
    net = BPNet(
        in_dim=5,
        n_hidden_1=20,
        n_hidden_2=250,
        n_hidden_3=500,
        n_hidden_4=250,
        n_hidden_5=50,
        out_dim=2,
    )


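    简洁写法（用 cfg 配置列表循环生成隐藏层。注意它与上面的版本并不等价：每个隐藏层都是 Linear + BatchNorm1d + ReLU，没有 Dropout，隐藏层宽度为 [20, 200, 500, 200, 50]）：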
    # Architecture presets for the concise BPNet: each key maps to the list of
    # hidden-layer widths, applied in order.
    cfg = {'1': [20, 200, 500, 200, 50]}
    
    class BPNet(nn.Module):
        """Fully-connected (BP) network whose hidden widths come from a preset.

        Every hidden layer is ``Linear -> BatchNorm1d -> ReLU`` (built by
        ``_make_layers`` into ``self.features``); a single ``Linear`` in
        ``self.classifier`` maps the last hidden width to ``out_dim``.

        Args:
            name: key into the module-level ``cfg`` dict selecting the list of
                hidden-layer widths, or an explicit list/tuple of widths.
            in_dim: number of input features (defaults to 5, the previously
                hard-coded value).
            out_dim: number of classifier outputs (defaults to 2, the
                previously hard-coded value).
        """

        def __init__(self, name, in_dim=5, out_dim=2):
            # Zero-argument super() is immune to the global name BPNet being rebound.
            super().__init__()
            if isinstance(name, (list, tuple)):
                hidden_dims = list(name)
            else:
                hidden_dims = cfg[name]
            self.features = self._make_layers(hidden_dims, in_dim)
            # With no hidden layers the classifier sees the raw input features.
            last_dim = hidden_dims[-1] if hidden_dims else in_dim
            self.classifier = nn.Sequential(
                nn.Linear(last_dim, out_dim)
            )

        def forward(self, x):
            """Apply the hidden stack, flatten per sample, then the classifier."""
            # assumes x is (batch, in_dim) -- TODO confirm with callers
            out = self.features(x)
            # Keep the batch dimension and flatten everything else.
            out = out.view(out.size(0), -1)
            out = self.classifier(out)
            return out

        def _make_layers(self, widths, in_dim=5):
            """Build Linear -> BatchNorm1d -> ReLU for each width, chained from in_dim."""
            layers = []
            for width in widths:
                layers += [nn.Linear(in_dim, width),
                           nn.BatchNorm1d(width),
                           nn.ReLU(inplace=True)]
                in_dim = width
            return nn.Sequential(*layers)
    # Build the concise network from the '1' preset in cfg.
    net = BPNet(name='1')
  • 相关阅读:
    2015第14周四
    2015第14周三
    2015第14周二
    2015第14周一
    2015第13周日
    2015第13周六
    2015第13周五
    2015第13周四
    2015第13周三
    2015第13周二
  • 原文地址:https://www.cnblogs.com/ymzm204/p/12289057.html
Copyright © 2011-2022 走看看