zoukankan      html  css  js  c++  java
  • Resnet-34框架

    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
     
    class ResidualBlock(nn.Module):
        
        '''
        实现子module: Residual Block
        '''
        
        def __init__(self,inchannel,outchannel,stride=1,shortcut=None):
            
            super(ResidualBlock,self).__init__()
            
            self.left=nn.Sequential(
                nn.Conv2d(inchannel,outchannel,3,stride,1,bias=False),
                nn.BatchNorm2d(outchannel),
                nn.ReLU(inplace=True),
                nn.Conv2d(outchannel,outchannel,3,1,1,bias=False),
                nn.BatchNorm2d(outchannel)
            )
            self.right=shortcut
        
        def forward(self,x):
            
            out=self.left(x)
            residual=x if self.right is None else self.right(x)
            out+=residual
            return F.relu(out)
        
    class ResNet(nn.Module):
        
        '''
        实现主module:ResNet34
        ResNet34 包含多个layer,每个layer又包含多个residual block
        用子module来实现residual block,用_make_layer函数来实现layer
        '''
        
        def __init__(self,num_classes=1000):
            
            super(ResNet,self).__init__()
            
            # 前几层图像转换
            self.pre=nn.Sequential(
                nn.Conv2d(3,64,7,2,3,bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(3,2,1)
            )
            
            # 重复的layer,分别有3,4,6,3个residual block
            self.layer1=self._make_layer(64,64,3)
            self.layer2=self._make_layer(64,128,4,stride=2)
            self.layer3=self._make_layer(128,256,6,stride=2)
            self.layer4=self._make_layer(256,512,3,stride=2)
            
            #分类用的全连接
            self.fc=nn.Linear(512,num_classes)
        
        def _make_layer(self,inchannel,outchannel,bloch_num,stride=1):
            
            '''
            构建layer,包含多个residual block
            '''
            shortcut=nn.Sequential(
                nn.Conv2d(inchannel,outchannel,1,stride,bias=False),
                nn.BatchNorm2d(outchannel)
            )
            layers=[]
            layers.append(ResidualBlock(inchannel,outchannel,stride,shortcut))
            for i in range(1,bloch_num):
                layers.append(ResidualBlock(outchannel,outchannel))
            return nn.Sequential(*layers)
        
        def forward(self,x):
            
            x=self.pre(x)
            
            x=self.layer1(x)
            x=self.layer2(x)
            x=self.layer3(x)
            x=self.layer4(x)
            
            x=F.avg_pool2d(x,7)
            x=x.view(x.size(0),-1)
            return self.fc(x)
    
    if __name__ == '__main__':
        model=ResNet()
        # input=t.autograd.Variable(t.randn(1,3,224,224))
        input=t.autograd.Variable(t.randn(1,8,4,4))
        o=model(input)
        print(o)
    
    
    如果有一天我们淹没在茫茫人海中庸碌一生,那一定是我们没有努力活得丰盛
  • 相关阅读:
    每个部门都有自己的游戏规则
    ssh作为代理,反向登录没有固定公网ip的局域网内的某远程服务器
    x11vnc 作为远程桌面服务器时vnc客户端键盘无法长按连续输入字符
    vim 编译使用ycm启动问题 fixed
    ubuntu设置普通用户也能执行docker命令
    git常见使用
    切图的必要步骤
    css居中
    清除浮动
    Spring-AOP(2)
  • 原文地址:https://www.cnblogs.com/yeran/p/10577714.html
Copyright © 2011-2022 走看看