zoukankan      html  css  js  c++  java
  • 【小白学PyTorch】18 TF2构建自定义模型

    【机器学习炼丹术】的炼丹总群已经快满了,要加入的快联系炼丹兄WX:cyx645016617

    参考目录:

    之前讲过了如何用tensorflow构建数据集,然后这一节课讲解如何用Tensorflow2.0来创建模型。

    TF2.0中创建模型的API基本上都放到了它的Keras中了,Keras可以理解为TF的高级API,里面封装了很多的常见网络层、常见损失函数等。 后续会详细介绍keras的全面功能,本篇文章讲解如何构建模型。

    1 创建自定义网络层

    import tensorflow as tf
    import tensorflow.keras as keras
    
    class MyLayer(keras.layers.Layer):
        def __init__(self, input_dim=32, output_dim=32):
            super(MyLayer, self).__init__()
    
            w_init = tf.random_normal_initializer()
            self.weight = tf.Variable(
                initial_value=w_init(shape=(input_dim, output_dim), dtype=tf.float32),
                trainable=True) # 如果是false则是不参与梯度下降的变量
    
            b_init = tf.zeros_initializer()
            self.bias = tf.Variable(initial_value=b_init(
                shape=(output_dim), dtype=tf.float32), trainable=True)
    
        def call(self, inputs):
            return tf.matmul(inputs, self.weight) + self.bias
    
    
    x = tf.ones((3,5))
    my_layer = MyLayer(input_dim=5,
                       output_dim=10)
    out = my_layer(x)
    print(out.shape)
    >>> (3, 10)
    

    这个就是定义了一个TF的网络层,其实可以看出来和PyTorch定义的方式非常的类似:

    • 这个类要继承tf.keras.layers.Layer,这个pytorch中要继承torch.nn.Module类似;
    • 网络层的组件在__def__中定义,和pytorch的模型类相同;
    • call()和pytorch中的forward()的类似。

    上面代码中实现的是一个全连接层的定义,其中可以看到使用tf.random_normal_initializer()来作为参数的初始化器,然后用tf.Variable来产生网络层中的权重变量,通过trainable=True这个参数说明这个权重变量是一个参与梯度下降的可以训练的变量。

    我通过tf.ones((3,5))产生一个shape为[3,5]的一个全是1的张量,这里面第一维度的3表示有3个样本,第二维度的5就是表示要放入全连接层的数据(全连接层的输入是5个神经元);然后设置的全连接层的输出神经元数量是10,所以最后的输出是(3,10)。

    2 创建一个完整的CNN

    import tensorflow as tf
    import tensorflow.keras as keras
    
    class CBR(keras.layers.Layer):
        def __init__(self,output_dim):
            super(CBR,self).__init__()
            self.conv = keras.layers.Conv2D(filters=output_dim, kernel_size=4, padding='same', strides=1)
            self.bn = keras.layers.BatchNormalization(axis=3)
            self.ReLU = keras.layers.ReLU()
    
        def call(self, inputs):
            inputs = self.conv(inputs)
            inputs = self.ReLU(self.bn(inputs))
            return inputs
    
    class MyNet(keras.Model):
        def __init__ (self,input_dim=3):
            super(MyNet,self).__init__()
            self.cbr1 = CBR(16)
            self.maxpool1 = keras.layers.MaxPool2D(pool_size=(2,2))
            self.cbr2 = CBR(32)
            self.maxpool2 = keras.layers.MaxPool2D(pool_size=(2,2))
    
        def call(self, inputs):
            inputs = self.maxpool1(self.cbr1(inputs))
            inputs = self.maxpool2(self.cbr2(inputs))
            return inputs
    
    model = MyNet(3)
    data = tf.random.normal((16,224,224,3))
    output = model(data)
    print(output.shape)
    >>> (16, 56, 56, 32)
    

    这个是构建了一个非常简单的卷积网络,结构是常见的:卷积层+BN层+ReLU层。可以发现这里继承的一个tf.keras.Model这个类。

    2.1 keras.Model vs keras.layers.Layer

    Model比Layer的功能更多,反过来说,Layer的功能更精简专一。

    • Layer:仅仅用作张量的操作,输入一个张量,输出也要求是一个张量,对张量的操作都可以用Layer来封装;
    • Model:一个更加复杂的结构,由多个Layer组成。 Model的话,可以使用.fit(),.evaluate().predict()等方法来快速训练。保存和加载模型也是在Model这个级别进行的。

    现在说一说上面的代码和pytorch中的区别,作为一个对比学习、也作为一个对pytorch的回顾:

    • 卷积层Conv2D中,Keras中不用输入输入的通道数,filters就是卷积后的输出特征图的通道数;而PyTorch的卷积层是需要输入两个通道数的参数,一个是输入特征图的通道数,一个是输出特征图的通道数;
    • keras.layers.BatchNormalization(axis=3)是BN层,这里的axis=3说明第三个维度(从0开始计数)是通道数,是需要作为批归一化的维度(这个了解BN算法的朋友应该可以理解吧,不了解的话去重新看我之前剖析BN层算法的那个文章吧,在文章末尾有相关链接)。pytorch的图像的四个维度是:

    [【样本数量,通道数,width,height】 ]

    而tensorflow是:

    [【样本数量,width,height,通道数】 ]

    总之,学了pytorch之后,再看keras的话,对照的keras的API,很多东西都直接就会了,两者的API越来越相似了。

    上面最后输出是(16, 56, 56, 32),输入的是(224 imes 224)的维度,然后经过两个最大池化层,就变成了(56 imes 56)了。

    到此为止,我们现在应该是可以用keras来构建模型了。

  • 相关阅读:
    How to alter department in PMS system
    Can't create new folder in windows7
    calculate fraction by oracle
    Long Wei information technology development Limited by Share Ltd interview summary.
    ORACLE BACKUP AND RECOVERY
    DESCRIBE:When you mouse click right-side is open an application and click left-side is attribution.
    ORACLE_TO_CHAR Function
    电脑BOIS设置
    JSP点击表头排序
    jsp+js实现可排序表格
  • 原文地址:https://www.cnblogs.com/PythonLearner/p/13767150.html
Copyright © 2011-2022 走看看