zoukankan      html  css  js  c++  java
  • 【516】keras 源码分析之 Dense

    参考:keras源码分析之Layer

    参考:keras源码分析之Dense


      本文主要讲解一下 Dense 层的源码,Dense 层即最常用的全连接层,代码很简单,主要是重写了 build 与 call 方法,在我们自定义 Layer 时,也可以参考该层的实现。但是不需要这么复杂,只要写出必要的部分就可以了,参见下一篇博客。

    1. Layer 类的相关说明

    参考:TensorFlow函数:tf.layers.Layer —— W3Cschool TensorFlow 官方文档

    参考:关于 Keras 网络层 —— keras 中文文档

      基础层类。这是所有层都继承的类,实现了通用的基础结构功能。层是实现常见神经网络操作的类,例如卷积、批量规范等。这些操作需要管理变量、损失和更新,以及将 TensorFlow 操作应用于输入张量。用户只需实例化它,然后将其视为可调用的。

      我们建议 Layer 的子代实现以下方法:

    • __init__ ():在成员变量中保存配置
    • build():当我们知道输入和 dtype 的形状时,从 __call__ 调用一次。应该有对 add_variable() 的调用,然后调用高级的 build() (设置为 self.built = True,这在用户想要在第一个 __call__ 之前手动调用 build() 时很好)。
    • * call():确认 build() 已被调用一次后调用 __call__。实际上应该执行将层应用于输入张量的逻辑(应该作为第一个参数传入)。

    2. Dense 源码解读

    2.1 __init__ 函数重写

      构造方法没什么好说的,就是一些简单的赋值。

    from keras.layers import Layer
    
    class Dense(Layer):
        def __init__(self, units,
                     activation=None,
                     use_bias=True,
                     kernel_initializer='glorot_uniform',
                     bias_initializer='zeros',
                     kernel_regularizer=None,
                     bias_regularizer=None,
                     activity_regularizer=None,
                     kernel_constraint=None,
                     bias_constraint=None,
                     **kwargs):
            if 'input_shape' not in kwargs and 'input_dim' in kwargs:
                kwargs['input_shape'] = (kwargs.pop('input_dim'),)
            super(Dense, self).__init__(**kwargs)
            self.units = units
            self.activation = activations.get(activation)
            self.use_bias = use_bias
            self.kernel_initializer = initializers.get(kernel_initializer)
            self.bias_initializer = initializers.get(bias_initializer)
            self.kernel_regularizer = regularizers.get(kernel_regularizer)
            self.bias_regularizer = regularizers.get(bias_regularizer)
            self.activity_regularizer = regularizers.get(activity_regularizer)
            self.kernel_constraint = constraints.get(kernel_constraint)
            self.bias_constraint = constraints.get(bias_constraint)
            self.input_spec = InputSpec(min_ndim=2)
            self.supports_masking = True
    

      

    2.2 build 函数重写

      build 方法中定义了两个 Variable 即权重,最后把 built 参数置为 True。

        def build(self, input_shape):
            assert len(input_shape) >= 2
            # 维度取 input_shape 的最后一维
            # 正好进行后面的叉乘
            input_dim = input_shape[-1]
    
            # 设置权重矩阵,维度为 (input_dim, self.units),用于叉乘   
            self.kernel = self.add_weight(shape=(input_dim, self.units),
                                          initializer=self.kernel_initializer,
                                          name='kernel',
                                          regularizer=self.kernel_regularizer,
                                          constraint=self.kernel_constraint)
            if self.use_bias:
                # 设置偏置
                self.bias = self.add_weight(shape=(self.units,),
                                            initializer=self.bias_initializer,
                                            name='bias',
                                            regularizer=self.bias_regularizer,
                                            constraint=self.bias_constraint)
            else:
                self.bias = None
            self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
            self.built = True
    

      

    2.3 call 函数重写

      call 方法把输入值与 build 方法中定义的权重进行了点积的操作,然后与 build 中的偏移量进行相加,最后经过激活函数返回最终的输出结果。

        def call(self, inputs):
            # 具体的 矩阵操作
            output = K.dot(inputs, self.kernel)
            if self.use_bias:
                output = K.bias_add(output, self.bias, data_format='channels_last')
            if self.activation is not None:
                output = self.activation(output)
            return output
    

      

    2.4 compute_output_shape 函数重写

      计算出输出tensor的维度并返回。

        def compute_output_shape(self, input_shape):
            assert input_shape and len(input_shape) >= 2
            assert input_shape[-1]
            output_shape = list(input_shape)
            output_shape[-1] = self.units
            return tuple(output_shape)
    

      

    2.5 get_config 函数重写

      保留一些中间值并以字典的形式返回。

        def get_config(self):
            config = {
                'units': self.units,
                'activation': activations.serialize(self.activation),
                'use_bias': self.use_bias,
                'kernel_initializer': initializers.serialize(self.kernel_initializer),
                'bias_initializer': initializers.serialize(self.bias_initializer),
                'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
                'bias_regularizer': regularizers.serialize(self.bias_regularizer),
                'activity_regularizer':
                    regularizers.serialize(self.activity_regularizer),
                'kernel_constraint': constraints.serialize(self.kernel_constraint),
                'bias_constraint': constraints.serialize(self.bias_constraint)
            }
            base_config = super(Dense, self).get_config()
            return dict(list(base_config.items()) + list(config.items()))
    

      

  • 相关阅读:
    小知识点
    异常关机后idea的注入不能使用
    day42_mysql 数据库操作 数据库的约束
    day41_mysql安装与卸载 mysql配置 SQL语句 DDL:操作数据库,表 DML:增删改表中的记录 DQL:查询表中的记录 DCL:管理用户与授权
    day39_ECMAScript BOM DOM
    day38_JS
    day35_HTML inpot标签 form表单
    day33_Stream(JDK1.8后的接口,本身不是函数式接口)
    day32_ 优化文件上传及接收 函数式接口 自定义函数接口 函数式编程 常用函数式接口 Stream流
    day31_网络通信三要素 TCP Socket关键字 ServerSocket
  • 原文地址:https://www.cnblogs.com/alex-bn-lee/p/14219978.html
Copyright © 2011-2022 走看看