  • Using class_weight and sample_weight in Keras

    Reposted from: https://stackoverflow.com/questions/57610804/when-is-the-timing-to-use-sample-weights-in-keras

    import tensorflow as tf
    import numpy as np
    
    data_size = 100
    input_size = 3
    classes = 3
    
    # Random toy data: `input_size` features per sample, integer labels in [0, classes)
    x_train = np.random.rand(data_size, input_size)
    y_train = np.random.randint(0, classes, data_size)
    # sample_weight_train = np.random.rand(data_size)
    x_val = np.random.rand(data_size, input_size)
    y_val = np.random.randint(0, classes, data_size)
    # sample_weight_val = np.random.rand(data_size)
    
    inputs = tf.keras.layers.Input(shape=(input_size,))  # shape must be a tuple
    pred = tf.keras.layers.Dense(classes, activation='softmax')(inputs)
    
    model = tf.keras.models.Model(inputs=inputs, outputs=pred)
    
    # Make the model static so results can be compared between scenarios.
    # Changes to `trainable` only take effect when compile() is called,
    # so freeze the layers before compiling.
    for layer in model.layers:
        layer.trainable = False
    
    loss = tf.keras.losses.sparse_categorical_crossentropy
    metrics = tf.keras.metrics.sparse_categorical_accuracy
    
    model.compile(loss=loss, metrics=[metrics], optimizer='adam')
    
    # Baseline: all class weights equal to 1 (same result as passing no class_weight)
    # model.fit(x=x_train, y=y_train, validation_data=(x_val, y_val))
    class_weights = {0: 1., 1: 1., 2: 1.}
    model.fit(x=x_train, y=y_train, class_weight=class_weights, validation_data=(x_val, y_val))
    # Output (as reported in the original post):
    # > loss: 1.1882 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1965 - val_sparse_categorical_accuracy: 0.3100
    
    # Set all class weights to zero to see which loss and metric are affected
    class_weights = {0: 0., 1: 0., 2: 0.}
    model.fit(x=x_train, y=y_train, class_weight=class_weights, validation_data=(x_val, y_val))
    # Output:
    # > loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1945 - val_sparse_categorical_accuracy: 0.3100
    
    # Set all sample weights to zero (validation weights go in the third slot of validation_data)
    sample_weight_train = np.zeros(data_size)
    sample_weight_val = np.zeros(data_size)
    model.fit(x=x_train, y=y_train, sample_weight=sample_weight_train,
              validation_data=(x_val, y_val, sample_weight_val))
    # Output:
    # > loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1931 - val_sparse_categorical_accuracy: 0.3100
    
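A minimal NumPy sketch of why zero weights give `loss: 0.0000e+00` while accuracy stays at 0.3300. It assumes the default Keras reduction for a weighted loss (sum of weight times per-sample loss, divided by the number of samples); the helper functions are hypothetical illustrations, not Keras APIs.

```python
import numpy as np

def sparse_ce_per_sample(probs, labels):
    """Per-sample sparse categorical cross-entropy: -log p(true class)."""
    eps = 1e-7
    p_true = probs[np.arange(len(labels)), labels]
    return -np.log(np.clip(p_true, eps, 1.0))

def weighted_loss(probs, labels, sample_weight=None):
    """Weighted loss, assuming a sum-over-batch-size reduction:
    sum(w_i * l_i) / N, where N is the number of samples."""
    losses = sparse_ce_per_sample(probs, labels)
    if sample_weight is None:
        sample_weight = np.ones_like(losses)
    return float(np.sum(sample_weight * losses) / len(losses))

def accuracy(probs, labels):
    """Unweighted accuracy: fraction of argmax predictions equal to the label."""
    return float(np.mean(np.argmax(probs, axis=1) == labels))

# Hypothetical softmax outputs for 100 samples and 3 classes
rng = np.random.default_rng(0)
logits = rng.normal(size=(100, 3))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
labels = rng.integers(0, 3, size=100)

print(weighted_loss(probs, labels))                 # ordinary mean loss
print(weighted_loss(probs, labels, np.zeros(100)))  # 0.0: every term is zeroed
print(accuracy(probs, labels))                      # ignores weights entirely
```

Because the weights only multiply the per-sample loss terms, they change the loss but cannot change an unweighted metric such as accuracy.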

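    class_weight: a weight per class (output label). Each sample's loss term is multiplied by the weight of its label's class; it applies to the training loss only.
    sample_weight: a weight per individual data sample. For validation, pass the weights as the third element of validation_data.
    In all three runs sparse_categorical_accuracy is unchanged, because metrics passed via `metrics=` are computed unweighted (tf.keras applies weights only to metrics passed via `weighted_metrics=`).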

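Since class_weight is just a per-class lookup, it can be expressed as a sample_weight array built from the labels. The sketch below uses a hypothetical helper, `class_to_sample_weight`. When both are given it multiplies them, which is an assumption based on recent tf.keras 2.x behavior; older standalone Keras releases ignored class_weight in that case, so verify against your version.

```python
import numpy as np

def class_to_sample_weight(labels, class_weight, sample_weight=None):
    """Hypothetical helper: map each label to its class weight to get
    per-sample weights. If per-sample weights are also given, combine
    the two multiplicatively (assumed tf.keras 2.x behavior)."""
    labels = np.asarray(labels)
    per_sample = np.array([class_weight[int(c)] for c in labels], dtype=float)
    if sample_weight is not None:
        per_sample = per_sample * np.asarray(sample_weight, dtype=float)
    return per_sample

y = np.array([0, 2, 1, 2, 0])
cw = {0: 1.0, 1: 5.0, 2: 0.5}

print(class_to_sample_weight(y, cw).tolist())                   # [1.0, 0.5, 5.0, 0.5, 1.0]
print(class_to_sample_weight(y, cw, np.full(5, 2.0)).tolist())  # [2.0, 1.0, 10.0, 1.0, 2.0]
```

Under the same assumption, passing `sample_weight=class_to_sample_weight(y_train, class_weights)` to `model.fit` should give the same training loss as passing `class_weight=class_weights`.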
  • Original article: https://www.cnblogs.com/yaos/p/12069527.html