  • numpy.argmax: using it to compute a confusion matrix

    numpy.argmax

    numpy.argmax(a, axis=None, out=None)

    Returns the indices of the maximum values along an axis.

    Parameters:

    a : array_like

    Input array.

    axis : int, optional

    By default, the index is into the flattened array, otherwise along the specified axis.

    out : array, optional

    If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
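A minimal sketch of the `out` parameter. It assumes the buffer uses NumPy's index dtype `np.intp`, which avoids dtype-mismatch errors. argmax writes the row-wise indices straight into the preallocated array:

```python
import numpy as np

a = np.arange(6).reshape(2, 3)

# Preallocate the result buffer: one index per row, in NumPy's index dtype.
out = np.empty(2, dtype=np.intp)
result = np.argmax(a, axis=1, out=out)

# The indices of the row maxima are now stored in `out`.
print(out)  # [2 2]
```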

    Returns:

    index_array : ndarray of ints

    Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.
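To illustrate the shape rule above on a 3-D array of shape (2, 3, 4): the array is purely illustrative, and each reduction drops exactly the axis it runs along.

```python
import numpy as np

# A 3-D array with shape (2, 3, 4).
a = np.arange(24).reshape(2, 3, 4)

# Reducing along an axis removes that axis from the result's shape.
print(np.argmax(a, axis=0).shape)  # (3, 4)
print(np.argmax(a, axis=1).shape)  # (2, 4)
print(np.argmax(a, axis=2).shape)  # (2, 3)

# With axis=None the array is flattened, so a single index comes back.
print(np.argmax(a))                # 23
```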

    See also

    ndarray.argmax, argmin

    amax
    The maximum value along a given axis.
    unravel_index
    Convert a flat index into an index tuple.
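As a sketch of how `argmax` and `unravel_index` combine, the flat index from `np.argmax(a)` can be turned back into a (row, column) position. The array values here are made up for illustration:

```python
import numpy as np

a = np.array([[1, 7, 3],
              [9, 2, 5]])

flat = np.argmax(a)                          # index into the flattened array
row, col = np.unravel_index(flat, a.shape)   # convert it to a 2-D position

print(flat)         # 3
print(row, col)     # 1 0
print(a[row, col])  # 9
```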

    Notes

    In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.

    Examples

    >>> a = np.arange(6).reshape(2,3)
    >>> a
    array([[0, 1, 2],
           [3, 4, 5]])
    >>> np.argmax(a)
    5
    >>> np.argmax(a, axis=0)
    array([1, 1, 1])
    >>> np.argmax(a, axis=1)
    array([2, 2])
    
    >>> b = np.arange(6)
    >>> b[1] = 5
    >>> b
    array([0, 5, 2, 3, 4, 5])
    >>> np.argmax(b) # Only the first occurrence is returned.
    1

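    Here is how I use it when training a multi-class model. The class labels are numbered from 0: org_labels = [0,1,2,....max_label].
    import numpy as np
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report, confusion_matrix
    # load_data, get_model and EarlyStoppingCallback are the author's own
    # (tflearn-based) helpers, defined elsewhere in the original script.

    if __name__ == "__main__":
        width, height = 32, 32
        X, Y, org_labels = load_data(dirname="data", resize_pics=(width, height))
        trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.2, random_state=666)
        print("sample data:")
        print(trainX[0])
        print(trainY[0])
        print(testX[-1])
        print(testY[-1])

        model = get_model(width, height, classes=100)

        filename = 'cnn_handwrite-acc0.8.tflearn'
        # Uncomment to load a saved model and resume training:
        # try:
        #     model.load(filename)
        #     print("Model loaded OK. Resume training!")
        # except Exception:
        #     pass

        # Initialize our callback with the desired validation-accuracy threshold.
        early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.6)
        try:
            model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                      snapshot_epoch=True,  # Snapshot (save & evaluate) the model every epoch.
                      show_metric=True, batch_size=32, callbacks=early_stopping_cb, run_id='cnn_handwrite')
        except StopIteration:
            # The callback raises StopIteration once the threshold is reached.
            print("Accuracy threshold reached, training stopped early.")

        model.save(filename)

        # Predict on all data and compute the confusion matrix.
        model.load(filename)

        pro_arr = model.predict(X)
        # For each sample, argmax over axis=1 picks the column with the highest
        # probability. That column index is the class label because labels run from 0.
        predict_labels = np.argmax(pro_arr, axis=1)
        print(classification_report(org_labels, predict_labels))
        print(confusion_matrix(org_labels, predict_labels))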
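The script above depends on the author's tflearn helpers, so it cannot run on its own. The sketch below isolates only the last step. A made-up probability array stands in for `model.predict(X)`, and the labels are hypothetical. The hand-written `confusion` helper is an illustrative, dependency-free substitute for sklearn's `confusion_matrix`, using the same layout: rows are true labels, columns are predicted labels.

```python
import numpy as np

# Hypothetical stand-in for model.predict(X): one row of class probabilities per sample.
pro_arr = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],
])
# Hypothetical true labels, numbered from 0 like org_labels.
org_labels = np.array([0, 1, 2, 1])

# The column index of the highest probability is the predicted class label.
predict_labels = np.argmax(pro_arr, axis=1)   # [0 1 2 0]

def confusion(y_true, y_pred, n_classes):
    """cm[i, j] counts samples whose true label is i and predicted label is j."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same (i, j) pair repeats.
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

cm = confusion(org_labels, predict_labels, n_classes=pro_arr.shape[1])
print(cm)
# [[1 0 0]
#  [1 1 0]
#  [0 0 1]]
```

The diagonal holds the correct predictions. Here the fourth sample (true label 1) was predicted as 0, which shows up in `cm[1, 0]`.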
    
  • Original post: https://www.cnblogs.com/bonelee/p/8976380.html