zoukankan      html  css  js  c++  java
  • tensorflow2.0——Resnet网络设计代码

    import tensorflow as tf
    
    
    class BasicBlock(tf.keras.layers.Layer):
        """Residual block made of two 3x3 convolutions plus an additive shortcut.

        Main path: Conv3x3 -> BatchNorm -> ReLU -> Conv3x3 -> BatchNorm.
        The main-path result is added to the shortcut and passed through a
        final ReLU.

        Args:
            filter_num: Number of filters for both 3x3 convolutions (and for the
                1x1 shortcut convolution, when one is created).
            stride: Stride of the first 3x3 convolution. Any value other than 1
                also installs a 1x1 convolution with the same stride on the
                shortcut so both branches are downsampled the same way.
        """

        def __init__(self, filter_num, stride=1):
            super(BasicBlock, self).__init__()

            # First conv carries the (optional) spatial downsampling.
            self.conv1 = tf.keras.layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
            self.bn1 = tf.keras.layers.BatchNormalization()
            self.relu = tf.keras.layers.Activation('relu')

            # Second conv always keeps the spatial size.
            self.conv2 = tf.keras.layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
            self.bn2 = tf.keras.layers.BatchNormalization()

            if stride != 1:
                # Strided 1x1 projection so the shortcut matches the main path.
                self.downsample = tf.keras.Sequential()
                self.downsample.add(tf.keras.layers.Conv2D(filter_num, (1, 1), strides=stride))
            else:
                # Identity shortcut: the addition in call() then requires the
                # input to already have `filter_num` channels.
                self.downsample = lambda x: x

        def call(self, inputs, training=None):
            """Run the block on `inputs`.

            Args:
                inputs: Input feature map.
                training: Forwarded to both BatchNormalization layers so they
                    use the mode requested by the caller.

            Returns:
                ReLU(main_path(inputs) + shortcut(inputs)).
            """
            out = self.conv1(inputs)
            # Forward `training` explicitly instead of relying on implicit
            # propagation from the enclosing layer.
            out = self.bn1(out, training=training)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out, training=training)

            identity = self.downsample(inputs)
            output = tf.keras.layers.add([out, identity])
            output = tf.nn.relu(output)

            return output
    
    
    class ResNet(tf.keras.Model):
        """ResNet-style classifier: stem -> four residual stages -> pool -> Dense.

        Args:
            layer_dims: Sequence whose first four entries give the number of
                BasicBlocks in stages 1-4, e.g. ``[2, 2, 2, 2]`` means four
                stages of two BasicBlocks each.
            num_classes: Number of units in the final Dense layer, i.e. the
                number of output classes. The Dense layer has no activation, so
                the model returns unnormalized scores.
        """

        def __init__(self, layer_dims, num_classes=100):
            super(ResNet, self).__init__()
            # Stem: Conv3x3 (Keras default padding) -> BN -> ReLU -> 2x2 max-pool
            # with stride 2.
            self.stem = tf.keras.Sequential([
                tf.keras.layers.Conv2D(64, (3, 3), strides=(1, 1)),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Activation('relu'),
                tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=[2, 2], padding='same')
            ])

            # Stage 1 keeps the resolution; stages 2-4 each start with a
            # stride-2 block and use twice as many filters as the stage before.
            self.layer1 = self.build_resblock(64, layer_dims[0])
            self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
            self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
            self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)
            # Average every 512-channel feature map over its spatial positions,
            # leaving one value per channel: [b, 512].
            self.avgpool = tf.keras.layers.GlobalAveragePooling2D()
            self.fc = tf.keras.layers.Dense(num_classes)

        def call(self, inputs, training=None):
            """Compute class scores for a batch of images.

            Args:
                inputs: Batch of input images.
                training: Forwarded to the stem and to every residual stage so
                    their BatchNormalization layers use the requested mode.

            Returns:
                Tensor of shape [b, num_classes].
            """
            x = self.stem(inputs, training=training)

            x = self.layer1(x, training=training)
            x = self.layer2(x, training=training)
            x = self.layer3(x, training=training)
            x = self.layer4(x, training=training)

            # [b, c]
            x = self.avgpool(x)
            # [b, num_classes] (100 with the default num_classes)
            x = self.fc(x)
            return x

        def build_resblock(self, filter_num, blocks, stride=1):
            """Build one residual stage as a Sequential of BasicBlocks.

            Args:
                filter_num: Number of filters for every block in the stage.
                blocks: Total number of BasicBlocks in the stage. The first
                    block is always added, so values below 1 still yield one.
                stride: Stride of the first block only; the rest use stride 1.

            Returns:
                A ``tf.keras.Sequential`` containing the stage's blocks.
            """
            res_blocks = tf.keras.Sequential()
            # Only the first block of a stage may downsample.
            res_blocks.add(BasicBlock(filter_num, stride=stride))

            for _ in range(1, blocks):
                res_blocks.add(BasicBlock(filter_num, stride=1))

            return res_blocks
    
    def resnet18():
        """Return a ResNet with two BasicBlocks in each of its four stages."""
        return ResNet(layer_dims=[2, 2, 2, 2])
    
    def resnet34():
        """Return a ResNet whose four stages hold 3, 4, 6 and 3 BasicBlocks."""
        return ResNet(layer_dims=[3, 4, 6, 3])
  • 相关阅读:
    git
    Java命令行参数解析
    Java调用本地命令
    理解JavaScript继承
    python selenium自动化(三)Chrome Webdriver的兼容
    python selenium自动化(二)自动化注册流程
    python selenium自动化(一)点击页面链接测试
    使用python selenium进行自动化functional test
    JUnit中测试异常抛出的方法
    爬坑 http协议的options请求
  • 原文地址:https://www.cnblogs.com/cxhzy/p/13758763.html
Copyright © 2011-2022 走看看