zoukankan      html  css  js  c++  java
  • 搭建resnet

    构建resnet的基本单元(basicblock)如下图所示,可以看到,输入张量input在基本单元中会经过两个通路,一个通路由两个卷积层构成,另一个通路是跳接。

    resnet的整体结构可以参考 resnet34,可以看到resnet34也是由上述basicblock串联得到,需要注意的是:

    (1) 在某些basicblock中,第一个卷积层会对input做shape上的变换,即对输入图片的长宽减一半、通道加一倍

    (2) 如果input在卷积层发生了shape上的变换,则对应的跳接线也需要将input进行shape上的变换(通过大小为(1,1)的卷积核)

    (3) 在所有basicblock中,第二个卷积层相对第一个卷积层不对数据做shape上的变换

    (4) 在所有basicblock中,卷积核的大小保持(3,3)不变,改变图片长宽依靠的是卷积步长

    接下来先来用tensorflow搭建一个basicblock类,这个类通过它的参数stride来判定跳接线是“实线”还是“虚线”。实例化basicblock类时只需指定卷积网络的三个基本要素—— 卷积核大小、卷积核数量、卷积步长,即可得到一个basicblock对象。下面的代码构建了BasicBlock类,并使用BasicBlock搭建了一个简单的resnet

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras import Model, layers
    from tensorflow.keras.callbacks import TensorBoard
    
    class BasicBlock(Model):
        def __init__(self, filter_num, stride=1, filter_size=(3,3)):
            super(BasicBlock,self).__init__()
            # 卷积通路
            self.conv1 = layers.Conv2D(filters=filter_num, kernel_size=filter_size, strides=(stride,stride), padding='same')
            self.relu = layers.Activation('relu')
            self.conv2 = layers.Conv2D(filters=filter_num, kernel_size=filter_size, strides=(1,1), padding='same')
            # 跳接线
            if stride == 1:
                self.sortcut = lambda x: x
            else:
                self.sortcut = layers.Conv2D(filters=filter_num, kernel_size=(1,1), strides=(stride,stride), padding='same')
    
        def call(self, inputs):
            out = self.conv1(inputs)
            out = self.relu(out)
            out = self.conv2(out)
            identity = self.sortcut(inputs)
            out = layers.add([out,identity])
            return out
    
    # 用BasicBlock搭建一个简单的网络
    inputs = layers.Input(shape=(28,28,3))
    
    out = BasicBlock(filter_num=8, stride=2).call(inputs)
    out = BasicBlock(filter_num=8, stride=1).call(out)
    out = BasicBlock(filter_num=8, stride=1).call(out)
    
    out = BasicBlock(filter_num=16, stride=2).call(out)
    out = BasicBlock(filter_num=16, stride=1).call(out)
    out = BasicBlock(filter_num=16, stride=1).call(out)
    
    out = layers.Flatten()(out)
    out = layers.Dense(1)(out)
    model = Model(inputs=inputs, outputs=out)
    
    model.compile(optimizer='adam',loss='mse')
    Tensorboard = TensorBoard(log_dir='.\logs', histogram_freq=1)
    x = np.random.rand(100,28,28,3)
    y = np.random.rand(100,1)
    model.fit(x,y,epochs=100,callbacks=[Tensorboard],verbose=2)
    

    截取该resnet中的两个BasicBlock,如下图所示

  • 相关阅读:
    JSP error: Only a type can be imported
    关于JAVA插入Mysql数据库中文乱码问题解决方案
    MySQL SQL优化——分片搜索
    myeclipse 调试JSP页面
    jsp:usebean 常用注意事项
    spring XML格式
    VB 要求对象
    VB 对象变量或with块变量未设置
    Spring依赖注入
    Spring 读取XML配置文件的两种方式
  • 原文地址:https://www.cnblogs.com/bill-h/p/14139471.html
Copyright © 2011-2022 走看看