  • Implementing a ResNet (Residual Network) in TensorFlow 2

    //20201018 update

    Preface:

    A few days ago I finished Chapter 2 of Andrew Ng's Convolutional Neural Networks course and completed the corresponding assignment, so I'm summarizing it here. The assignment implements a ResNet in TensorFlow 2; this post mainly covers the architecture of residual networks and how to implement one. (I'm a beginner, so if anything here is wrong, please point it out!)

    1. Introduction to ResNets (Residual Networks)

      First, very deep neural networks are hard to train because of the vanishing and exploding gradient problems. ResNets are built out of residual blocks (Residual blocks).

      About the residual block (the core of a residual network): simply put, a shortcut is inserted into the ordinary forward-propagation path. The shortcut carries a "copy" of the input past the stacked layers and adds it to those layers' output where the shortcut ends (abstractly, H(x) = F(x) + x). Why it works: because of the extra x term, the derivative dH/dx = dF/dx + 1 always contains a 1, so the gradient cannot shrink toward zero after many layers; this counters the vanishing gradients that make gradient descent slow in deep networks, as the sketch below illustrates.
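
      To sanity-check that "+1" argument, here is a minimal sketch (my own illustration, not part of the original assignment): it compares the gradient of a branch F(x) whose derivative is nearly zero against the gradient of the residual form F(x) + x, using tf.GradientTape:

      import tensorflow as tf

      # A tiny branch F(x) whose derivative (1e-6) mimics a nearly vanished gradient
      def F(x):
          return 1e-6 * x

      x = tf.Variable(3.0)

      with tf.GradientTape(persistent=True) as tape:
          plain = F(x)           # no shortcut
          residual = F(x) + x    # with shortcut: H(x) = F(x) + x

      print(tape.gradient(plain, x).numpy())     # ~1e-06: the gradient has almost vanished
      print(tape.gradient(residual, x).numpy())  # ~1.000001: the "+1" keeps it alive
      del tape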

    2. TensorFlow 2 Implementation (IDE: PyCharm)

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers, Sequential
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    '''
    Custom residual block, subclassing layers.Layer
    '''
    class BasicBlock(layers.Layer):
        def __init__(self, filter_num, stride=1):
            super(BasicBlock, self).__init__()

            # Main path: two 3x3 convolutions, each followed by batch normalization
            self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
            self.bn1 = layers.BatchNormalization()
            self.ac1 = layers.Activation('relu')

            self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
            self.bn2 = layers.BatchNormalization()

            # Shortcut: make the tensor that travels along the shortcut match
            # the size of the tensor on the main path
            if stride != 1:
                self.downsample = Sequential()
                self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=stride))
            else:
                self.downsample = lambda x: x

            self.ac2 = layers.Activation('relu')
    
        def call(self, inputs, training=None):
            out = self.conv1(inputs)
            out = self.bn1(out, training=training)  # batch norm behaves differently in training vs. inference
            out = self.ac1(out)

            out = self.conv2(out)
            out = self.bn2(out, training=training)

            # If stride == 1, identity = inputs; otherwise the 1x1 conv on the
            # shortcut resizes the input so it can be added to the main path
            identity = self.downsample(inputs)

            output = self.ac2(identity + out)
            return output
    
    '''
    The residual network itself, subclassing keras.Model
    '''
    class ResNet(keras.Model):
        def __init__(self, layer_dims, num_classes=6):
            super(ResNet, self).__init__()

            # Stem: initial conv + batch norm + ReLU + max-pooling
            self.prev = Sequential([
                layers.Conv2D(64, (3, 3), strides=(1, 1)),
                layers.BatchNormalization(),
                layers.Activation('relu'),
                layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1))
            ])

            # Four stages of residual blocks; stages 2-4 downsample with stride 2
            self.layer1 = self.build_resblock(64, layer_dims[0])
            self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
            self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
            self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)

            self.avgpool = layers.GlobalAveragePooling2D()
            self.fc = layers.Dense(num_classes)
    
        def call(self, inputs, training=None):
            x = self.prev(inputs, training=training)  # stem
            x = self.layer1(x, training=training)
            x = self.layer2(x, training=training)
            x = self.layer3(x, training=training)
            x = self.layer4(x, training=training)

            x = self.avgpool(x)

            output = self.fc(x)

            return output  # the whole residual network ends here
    
        def build_resblock(self, filter_num, blocks, stride=1):
            # Only the first block in a stage may downsample (stride > 1);
            # the remaining blocks keep stride 1
            res_blocks = Sequential()
            res_blocks.add(BasicBlock(filter_num, stride))

            for _ in range(1, blocks):
                res_blocks.add(BasicBlock(filter_num, stride=1))

            return res_blocks
    
    def resnet18():
        # [2, 2, 2, 2] BasicBlocks per stage gives the 18-layer variant
        return ResNet([2, 2, 2, 2])

    model = resnet18()
    model.build(input_shape=(64, 32, 32, 3))  # batch of 64 RGB images, 32x32
    model.summary()  # summary() prints itself; print(model.summary()) would only print None

      Overview of the basic workflow:

      - First, create the residual-block class (it must inherit from the Layer class), i.e. a custom layer. The key point: if the shape of the feature map changes inside the block, the x traveling along the shortcut no longer matches the x coming out of the main forward path, so whenever the shape changes you must add a Conv layer on the shortcut to adjust the size. At the merge point, simply add the shortcut output to the main-path output.

      - Then create the residual-network class (it must inherit from the Model class), i.e. a custom model. Once it is defined, the call method just chains the stacked residual blocks together according to the layer_dims argument, as shown in the sketch after this list.
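
    To verify the shapes end to end, here is a small sketch (my own addition; the batch size and input sizes are arbitrary) that runs the classes defined above on random data, including a block whose stride-2 shortcut must resize the identity branch:

    import tensorflow as tf

    # Stand-in input: a batch of 4 feature maps with 64 channels
    x = tf.random.normal((4, 32, 32, 64))

    # stride=2 halves height/width and raises channels 64 -> 128; the 1x1
    # conv on the shortcut keeps the addition shape-compatible
    block = BasicBlock(128, stride=2)
    print(block(x).shape)    # (4, 16, 16, 128)

    # Full forward pass through the assembled network
    net = resnet18()
    images = tf.random.normal((4, 32, 32, 3))
    logits = net(images, training=False)
    print(logits.shape)      # (4, 6): num_classes defaults to 6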

    Only the model summary is printed here; the network has not been trained on an actual dataset.

    Running the script prints the layer-by-layer summary.

    That's all.

    I hope this helps.

  • Original post: https://www.cnblogs.com/lavender-pansy/p/13834451.html