zoukankan      html  css  js  c++  java
  • 读取多张MNIST图片与利用BaseEstimator基类创建分类器

    读取多张MNIST图片

    在读取多张MNIST图片之前,我们先来看下读取单张图片如何实现

    每张数字图片大小都为28 * 28的,需要将数据reshape成28 * 28的,采用最近邻插值,如下

    def plot_digit(data):
        img = data.reshape(28,28)
        plt.imshow(img,cmap=matplotlib.cm.binary,interpolation='nearest')
        plt.axis('off')
    import matplotlib.pyplot as plt
    import matplotlib
    some_digit = X[36000]
    plot_digit(some_digit)

     现在来读取多张MNIST图片

    需要确定每行显示多少张图片,根据照片数最多显示几行,最后一行有几个未填满,将每行进行连接起来

    def plot_digits(instances,images_per_row = 10,**options):
        size = 28
        images_per_row = min(len(instances),images_per_row)
        images = [instance.reshape(size,size) for instance in instances]
        n_rows = (len(instances) - 1) // images_per_row +1
        row_images = []
        n_empty = n_rows * images_per_row - len(instances)
        images.append(np.zeros((size,size*n_empty)))
        for row in range(n_rows):
            rimages = images[row * images_per_row:(row+1) * images_per_row]
            row_images.append(np.concatenate(rimages,axis=1))
        image = np.concatenate(row_images,axis=0)
        plt.imshow(image,cmap=matplotlib.cm.binary,**options)
        plt.axis('off')
    import numpy as np
    import os
    
    # to make this notebook's output stable across runs
    np.random.seed(42)
    
    # To plot pretty figures
    %matplotlib inline
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    mpl.rc('axes', labelsize=14)
    mpl.rc('xtick', labelsize=12)
    mpl.rc('ytick', labelsize=12)
    
    # Where to save the figures
    PROJECT_ROOT_DIR = "."
    #CHAPTER_ID = "classification"
    
    def save_fig(fig_id, tight_layout=True):
        path = os.path.join(PROJECT_ROOT_DIR, "images", fig_id + ".png")
        print("Saving figure", fig_id)
        if tight_layout:
            plt.tight_layout()
        plt.savefig(path, format='png', dpi=300)
    plt.figure(figsize=(9,9))
    example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]
    plot_digits(example_images, images_per_row=10)
    save_fig("more_digits_plot")
    plt.show()

    显示并将结果存入磁盘

    利用BaseEstimator基类创建分类器

    在做非5分类器的交叉验证时,需要写一个非5的分类器

    估计器(Estimator)很多时候可以直接理解成分类器,主要包括两个函数

    • fit():训练算法,设置内部参数,接受训练集和类别两个参数
    • predict():预测测试集类别,参数为测试集

    大多数sklearn估计器接受和输出的数据格式均为numpy数组或类似格式

    from sklearn.base import BaseEstimator
    class Never5Classifier(BaseEstimator):
        def fit(self,X,y = None):
            pass
        def predict(self,X):
            return np.zeros((len(X),1),dtype = bool)
    never_5_clf = Never5Classifier()
    cross_val_score(never_5_clf,X_train,y_train_5,cv = 3,scoring='accuracy')

    Never5Classifier分类器预测的结果都是0,而数字为5的标签应该都为1,非5的为0,这时候可以看出也有90%的可能性猜对某张图片不是5

    关于评估器以及转换器、流水线(Pipline)等更多参考:https://www.jianshu.com/p/516f009c0875

  • 相关阅读:
    黑马程序员__OC三大特性
    黑马程序员___OC类和对象
    黑马程序员___预处理指令
    黑马程序员___数据类型总结
    黑马程序员__指针
    黑马程序员__C语言__函数__static和extern
    黑马程序员__C语言__流程控制__选择结构
    黑马程序员__C语言__循环结构
    入园随笔
    Fiddler中抓取不到Jmeter模拟的请求包。
  • 原文地址:https://www.cnblogs.com/whiteBear/p/12341094.html
Copyright © 2011-2022 走看看