
    Custom loss functions in TensorFlow 2

    1. Summary

    One-sentence summary:

    Define an ordinary function and pass it to the loss argument of compile.
    import tensorflow as tf
    from tensorflow import keras

    # A hand-written mean-squared-error loss
    def customized_mse(y_true, y_pred):
        return tf.reduce_mean(tf.square(y_pred - y_true))

    model = keras.models.Sequential([
        keras.layers.Dense(30, activation='relu',
                           input_shape=x_train.shape[1:]),  # x_train assumed defined
        keras.layers.Dense(1),
    ])
    model.summary()
    model.compile(loss=customized_mse, optimizer="sgd",
                  metrics=["mean_squared_error"])

    2. Custom loss functions in TensorFlow 2

    Reposted from / based on: TensorFlow 2.x study notes 17: custom layers, models, and loss functions
    https://blog.csdn.net/qq_39507748/article/details/105256541

    1. Custom layers (layer)

    • Subclass tf.keras.layers.Layer
    • Use tf.keras.layers.Lambda

    The following example includes both of these forms:

    import tensorflow as tf
    from tensorflow import keras

    class CustomizedDenseLayer(keras.layers.Layer):
        def __init__(self, units, activation=None, **kwargs):
            # Call super().__init__ first so Keras can set up attribute tracking
            super(CustomizedDenseLayer, self).__init__(**kwargs)
            self.units = units
            self.activation = keras.layers.Activation(activation)

        def build(self, input_shape):
            """Create the trainable parameters the layer needs."""
            # x @ w + b. input_shape: [None, a], kernel: [a, b], output_shape: [None, b]
            self.kernel = self.add_weight(name='kernel',
                                          shape=(input_shape[1],
                                                 self.units),
                                          initializer='uniform',
                                          trainable=True)
            self.bias = self.add_weight(name='bias',
                                        shape=(self.units,),
                                        initializer='zeros',
                                        trainable=True)
            super(CustomizedDenseLayer, self).build(input_shape)

        def call(self, x):
            """Run the forward computation."""
            return self.activation(x @ self.kernel + self.bias)

    # A one-off layer built with Lambda: softplus activation, log(1 + e^x)
    customized_softplus = keras.layers.Lambda(lambda x: tf.nn.softplus(x))

    model = keras.models.Sequential([
        CustomizedDenseLayer(30, activation='relu',
                             input_shape=x_train.shape[1:]),  # x_train assumed defined
        CustomizedDenseLayer(1),
        customized_softplus,
    ])
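
    As a quick sanity check, both custom forms can be called on a small tensor; this is a minimal sketch (the dummy shapes below are assumptions, not from the original post):

    # Run a dummy batch of 2 samples with 8 features through the custom layer
    layer = CustomizedDenseLayer(30, activation='relu')
    print(layer(tf.zeros([2, 8])).shape)  # (2, 30)
    # softplus(x) = log(1 + e^x)
    print(customized_softplus(tf.constant([-10., 0., 10.])).numpy())
    # approximately [0.     0.6931 10.   ]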
    

    2. Custom models (model)

    • Subclass tf.keras.Model
    import tensorflow as tf
    
    class MyModel(tf.keras.Model):
    
      def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
        self.dropout = tf.keras.layers.Dropout(0.5)
    
      def call(self, inputs, training=False):
        # Dropout is only active during training, not at inference time
        x = self.dense1(inputs)
        if training:
          x = self.dropout(x, training=training)
        return self.dense2(x)
    
    model = MyModel()
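
    A subclassed model trains like any other Keras model. A minimal sketch with made-up data (the shapes, labels, and optimizer below are assumptions for illustration):

    import numpy as np

    # Hypothetical data: 32 samples, 10 features, 5 classes
    x = np.random.random((32, 10)).astype("float32")
    y = np.random.randint(0, 5, size=(32,))

    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    model.fit(x, y, epochs=2, batch_size=8)  # fit() invokes call() with training=True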
    

    3. Custom loss functions (loss)

    Define an ordinary function and pass it to the loss argument when calling compile:

    # A hand-written mean-squared-error loss; y_true and y_pred are tensors
    def customized_mse(y_true, y_pred):
        return tf.reduce_mean(tf.square(y_pred - y_true))

    model = keras.models.Sequential([
        keras.layers.Dense(30, activation='relu',
                           input_shape=x_train.shape[1:]),  # x_train assumed defined
        keras.layers.Dense(1),
    ])
    model.summary()
    model.compile(loss=customized_mse, optimizer="sgd",
                  metrics=["mean_squared_error"])
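
    To check that the hand-written loss agrees with Keras' built-in MSE, it can be evaluated on a couple of made-up tensors (the values here are illustrative):

    y_true = tf.constant([1.0, 2.0, 3.0])
    y_pred = tf.constant([1.5, 2.0, 2.0])
    print(customized_mse(y_true, y_pred).numpy())                    # 0.41666666
    print(keras.losses.mean_squared_error(y_true, y_pred).numpy())  # 0.41666666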
    
     
    Original post: https://www.cnblogs.com/Renyi-Fan/p/13443974.html