zoukankan      html  css  js  c++  java
  • Keras class_weight和sample_weight用法

    搬运: https://stackoverflow.com/questions/57610804/when-is-the-timing-to-use-sample-weights-in-keras

    import tensorflow as tf
    import numpy as np
    
    data_size = 100
    input_size=3
    classes=3
    
    x_train = np.random.rand(data_size ,input_size)
    y_train= np.random.randint(0,classes,data_size )
    #sample_weight_train = np.random.rand(data_size)
    x_val = np.random.rand(data_size ,input_size)
    y_val= np.random.randint(0,classes,data_size )
    #sample_weight_val = np.random.rand(data_size )
    
    inputs = tf.keras.layers.Input(shape=(input_size))
    pred=tf.keras.layers.Dense(classes, activation='softmax')(inputs)
    
    model = tf.keras.models.Model(inputs=inputs, outputs=pred)
    
    loss = tf.keras.losses.sparse_categorical_crossentropy
    metrics = tf.keras.metrics.sparse_categorical_accuracy
    
    model.compile(loss=loss , metrics=[metrics], optimizer='adam')
    
    # Make model static, so we can compare it between different scenarios
    for layer in model.layers:
        layer.trainable = False
    
    # base model no weights (same result as without class_weights)
    # model.fit(x=x_train,y=y_train, validation_data=(x_val,y_val))
    class_weights={0:1.,1:1.,2:1.}
    model.fit(x=x_train,y=y_train, class_weight=class_weights, validation_data=(x_val,y_val))
    # which outputs:
    > loss: 1.1882 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1965 - val_sparse_categorical_accuracy: 0.3100
    
    #changing the class weights to zero, to check which loss and metric that is affected
    class_weights={0:0,1:0,2:0}
    model.fit(x=x_train,y=y_train, class_weight=class_weights, validation_data=(x_val,y_val))
    # which outputs:
    > loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1945 - val_sparse_categorical_accuracy: 0.3100
    
    #changing the sample_weights to zero, to check which loss and metric that is affected
    sample_weight_train = np.zeros(100)
    sample_weight_val = np.zeros(100)
    model.fit(x=x_train,y=y_train,sample_weight=sample_weight_train, validation_data=(x_val,y_val,sample_weight_val))
    # which outputs:
    > loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1931 - val_sparse_categorical_accuracy: 0.3100
    

    class_weight: output 变量的权重
    sample_weight: data sample 的权重

  • 相关阅读:
    Linux网络基础配置
    UVA 116 Unidirectional TSP(dp + 数塔问题)
    修改Hosts文件
    倒排索引
    可以把阿里云上面的一些介绍和视频都看看
    练练脑,继续过Hard题目
    explicit的用法
    auto_ptr的使用和注意
    我写的快排程序
    快速排序、查第k大
  • 原文地址:https://www.cnblogs.com/yaos/p/12069527.html
Copyright © 2011-2022 走看看