  • Usage of Keras class_weight and sample_weight

    Reposted from: https://stackoverflow.com/questions/57610804/when-is-the-timing-to-use-sample-weights-in-keras

    import numpy as np
    import tensorflow as tf

    data_size = 100
    input_size = 3
    classes = 3

    # Random features and integer labels (no seed, so values differ between runs)
    x_train = np.random.rand(data_size, input_size)
    y_train = np.random.randint(0, classes, data_size)
    x_val = np.random.rand(data_size, input_size)
    y_val = np.random.randint(0, classes, data_size)

    # shape must be a tuple: (input_size,)
    inputs = tf.keras.layers.Input(shape=(input_size,))
    pred = tf.keras.layers.Dense(classes, activation='softmax')(inputs)
    model = tf.keras.models.Model(inputs=inputs, outputs=pred)

    # Freeze the model so every scenario below is evaluated with the same weights.
    # Do this before compile(): Keras only picks up changes to `trainable`
    # when the model is (re)compiled.
    for layer in model.layers:
        layer.trainable = False

    loss = tf.keras.losses.sparse_categorical_crossentropy
    metrics = tf.keras.metrics.sparse_categorical_accuracy
    model.compile(loss=loss, metrics=[metrics], optimizer='adam')

    # 1) Baseline: all class weights 1 (same result as passing no class_weight)
    class_weights = {0: 1., 1: 1., 2: 1.}
    model.fit(x=x_train, y=y_train, class_weight=class_weights,
              validation_data=(x_val, y_val))
    # Output reported in the original post (random data, exact values will differ):
    # loss: 1.1882 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1965 - val_sparse_categorical_accuracy: 0.3100

    # 2) Set every class weight to 0 to see which loss and metric are affected
    class_weights = {0: 0., 1: 0., 2: 0.}
    model.fit(x=x_train, y=y_train, class_weight=class_weights,
              validation_data=(x_val, y_val))
    # Output reported in the original post:
    # loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1945 - val_sparse_categorical_accuracy: 0.3100

    # 3) Set every sample weight to 0 (training and validation) for the same check;
    #    validation sample weights go in as the third element of validation_data
    sample_weight_train = np.zeros(data_size)
    sample_weight_val = np.zeros(data_size)
    model.fit(x=x_train, y=y_train, sample_weight=sample_weight_train,
              validation_data=(x_val, y_val, sample_weight_val))
    # Output reported in the original post:
    # loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1931 - val_sparse_categorical_accuracy: 0.3100
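
    Why does the training loss fall to 0 while sparse_categorical_accuracy stays put? A minimal NumPy sketch of the arithmetic, assuming the loss is reduced as sum(w_i * l_i) / N (the "sum over batch size" reduction that Keras losses use by default) and that a metric passed via `metrics=` ignores the weights. The probabilities and labels are made-up illustration data, and the helper names are hypothetical, not Keras APIs:

```python
import numpy as np

def sparse_ce(y_true, probs):
    # Per-sample cross-entropy: -log p(true class)
    return -np.log(probs[np.arange(len(y_true)), y_true])

def weighted_mean_loss(per_sample, sample_weight):
    # Assumed reduction: sum(w_i * l_i) / N, with N the number of samples
    return float(np.sum(sample_weight * per_sample) / len(per_sample))

def accuracy(y_true, probs):
    # Unweighted metric: fraction of samples whose argmax matches the label
    return float(np.mean(np.argmax(probs, axis=1) == y_true))

# Hypothetical softmax outputs for 4 samples and 3 classes
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6],
                  [0.3, 0.4, 0.3],
                  [0.5, 0.25, 0.25]])
y = np.array([0, 2, 0, 1])
ce = sparse_ce(y, probs)

for w in (np.ones(4), np.zeros(4)):
    # With all-zero weights the loss is 0; accuracy is the same in both rows
    print(weighted_mean_loss(ce, w), accuracy(y, probs))
```

    If you do want a metric to honor the weights, `compile()` in tf.keras accepts a separate `weighted_metrics` argument for that purpose.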
    

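    class_weight: a weight per output class; each training sample's loss is scaled by the weight of its label's class.
    sample_weight: a weight per individual data sample.
    In the logs above, zeroing either kind of weight drives the training loss to 0 while sparse_categorical_accuracy stays at 0.3300: the weights scale the loss, not the metrics passed through `metrics`.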
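
    In effect, a class_weight dict is a per-class lookup that turns each training label into a per-sample weight. A minimal NumPy sketch of that lookup, assuming integer (sparse) labels as in the experiment; `class_weight_to_sample_weight` is a hypothetical helper, not a Keras API, and requiring keys 0..K-1 is this sketch's own design choice:

```python
import numpy as np

def class_weight_to_sample_weight(y, class_weight):
    """Hypothetical helper: replace each integer label with its class weight."""
    keys = sorted(class_weight)
    # This sketch insists on keys 0..K-1 so the dict can become an index table
    if keys != list(range(len(keys))):
        raise ValueError("class_weight keys must be 0..K-1")
    table = np.array([class_weight[k] for k in keys], dtype=float)
    return table[np.asarray(y, dtype=int)]

labels = [0, 2, 1, 2, 0]
weights = class_weight_to_sample_weight(labels, {0: 1.0, 1: 5.0, 2: 0.5})
print(weights)  # each label replaced by its class weight: 1.0, 0.5, 5.0, 0.5, 1.0
```

    Under that reading, passing `sample_weight=class_weight_to_sample_weight(y_train, class_weights)` to `fit` should weight the training loss the same way as `class_weight=class_weights`; sample_weight is the more general option because it can also vary between samples of the same class.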

  • Original article: https://www.cnblogs.com/yaos/p/14014218.html