  • Reading multiple MNIST images and creating a classifier from the BaseEstimator base class

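    Reading multiple MNIST images

    Before reading multiple MNIST images, let's first see how to read and display a single one.

    Each digit image is 28 * 28 pixels, so the flat feature vector has to be reshaped to 28 * 28. It is then drawn with nearest-neighbor interpolation, as follows: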

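    import matplotlib
    import matplotlib.pyplot as plt

    def plot_digit(data):
        # reshape the flat 784-element vector into a 28 x 28 image
        img = data.reshape(28,28)
        plt.imshow(img,cmap=matplotlib.cm.binary,interpolation='nearest')
        plt.axis('off')

    # X is assumed to be the already-loaded MNIST feature array (one row of 784 pixels per image)
    some_digit = X[36000]
    plot_digit(some_digit)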
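
    As a quick check of the reshape step, the sketch below uses a synthetic 784-element vector instead of real MNIST data (`fake_digit` is a made-up name, not part of the original code). It suggests that reshaping only changes how the pixels are indexed, not their values.

```python
import numpy as np

# Synthetic stand-in for one MNIST row: 784 pixel values (assumption, not real data)
fake_digit = np.arange(784)

# Reshape the flat vector into a 28 x 28 grid; rows are filled in order
img = fake_digit.reshape(28, 28)

# Pixel (row, col) of the image is element row * 28 + col of the flat vector
row, col = 3, 5
same_pixel = img[row, col] == fake_digit[row * 28 + col]

# Flattening the image gives back the original vector
restored = img.reshape(-1)
```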

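    Now let's read multiple MNIST images

    We need to decide how many images to show per row, work out how many rows the given number of images requires and how many slots in the last row stay empty, and then concatenate the images of each row and stack the rows.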
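
    The row and padding arithmetic described above can be tried in isolation before any plotting. This is a minimal sketch; `grid_layout` is a hypothetical helper introduced only for illustration and assumes at least one image.

```python
def grid_layout(n_images, images_per_row=10):
    """Return (columns, rows, empty slots in the last row) for n_images >= 1."""
    # never use more columns than there are images
    columns = min(n_images, images_per_row)
    # ceiling division: rows needed to hold all images
    n_rows = (n_images - 1) // columns + 1
    # slots left unfilled in the last row
    n_empty = n_rows * columns - n_images
    return columns, n_rows, n_empty

layout_25 = grid_layout(25)    # 25 images, 10 per row
layout_100 = grid_layout(100)  # fills the grid exactly
layout_4 = grid_layout(4)      # fewer images than one full row
```

    For 25 images at 10 per row this gives 3 rows with 5 empty slots in the last one.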

    def plot_digits(instances,images_per_row = 10,**options):
        size = 28
        # never use more columns than there are images
        images_per_row = min(len(instances),images_per_row)
        images = [instance.reshape(size,size) for instance in instances]
        # ceiling division: number of rows needed
        n_rows = (len(instances) - 1) // images_per_row +1
        row_images = []
        # pad the last row with one blank block covering the empty slots
        n_empty = n_rows * images_per_row - len(instances)
        images.append(np.zeros((size,size*n_empty)))
        for row in range(n_rows):
            rimages = images[row * images_per_row:(row+1) * images_per_row]
            # join the images of one row side by side
            row_images.append(np.concatenate(rimages,axis=1))
        # stack the rows on top of each other
        image = np.concatenate(row_images,axis=0)
        plt.imshow(image,cmap=matplotlib.cm.binary,**options)
        plt.axis('off')
    import numpy as np
    import os
    
    # to make this notebook's output stable across runs
    np.random.seed(42)
    
    # To plot pretty figures (%matplotlib inline is a Jupyter magic; omit it in a plain script)
    %matplotlib inline
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    mpl.rc('axes', labelsize=14)
    mpl.rc('xtick', labelsize=12)
    mpl.rc('ytick', labelsize=12)
    
    # Where to save the figures
    PROJECT_ROOT_DIR = "."
    #CHAPTER_ID = "classification"
    
    def save_fig(fig_id, tight_layout=True):
        image_dir = os.path.join(PROJECT_ROOT_DIR, "images")
        os.makedirs(image_dir, exist_ok=True)  # savefig fails if the folder is missing
        path = os.path.join(image_dir, fig_id + ".png")
        print("Saving figure", fig_id)
        if tight_layout:
            plt.tight_layout()
        plt.savefig(path, format='png', dpi=300)

    plt.figure(figsize=(9,9))
    # pick images from three strided slices of X and stack them row-wise with np.r_
    example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]
    plot_digits(example_images, images_per_row=10)
    save_fig("more_digits_plot")
    plt.show()

    The figure is displayed and also saved to disk.
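
    To see how many images the three strided slices select, the sketch below applies the same slicing to a dummy array. It assumes X has at least 60,000 rows; `X_dummy` is a placeholder, not the real data.

```python
import numpy as np

# Placeholder with 60,000 rows and a single column (assumption about the size of X)
X_dummy = np.arange(60000).reshape(-1, 1)

part_a = X_dummy[:12000:600]        # every 600th row from 0 up to 12000
part_b = X_dummy[13000:30600:600]   # every 600th row from 13000 up to 30600
part_c = X_dummy[30600:60000:590]   # every 590th row from 30600 up to 60000

# np.r_ stacks the pieces along the first axis
sample = np.r_[part_a, part_b, part_c]
counts = (len(part_a), len(part_b), len(part_c), len(sample))
```

    Under that assumption the slices yield 20, 30, and 50 rows, so the 100 selected images fill a 10 x 10 grid with no empty slots.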

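    Creating a classifier from the BaseEstimator base class

    To cross-validate a "not-5" baseline, we first need to write a classifier that always predicts "not 5".

    An estimator can often be thought of simply as a classifier. It mainly provides two methods:

    • fit(): trains the algorithm and sets its internal parameters; it takes two arguments, the training set and the class labels
    • predict(): predicts the classes of a test set; it takes the test set as its argument

    Most sklearn estimators accept and return data as NumPy arrays or similar array-like formats.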
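
    The fit()/predict() contract can be illustrated with a small estimator that actually learns something. This is a sketch, not code from the original article: `MostFrequentClassifier`, `X_demo`, and `y_demo` are hypothetical names, and the data is synthetic.

```python
import numpy as np
from sklearn.base import BaseEstimator

class MostFrequentClassifier(BaseEstimator):
    """Hypothetical estimator: always predicts the most common training label."""

    def fit(self, X, y):
        # learn the single internal parameter: the most frequent label
        values, counts = np.unique(y, return_counts=True)
        self.most_frequent_ = values[np.argmax(counts)]
        return self  # returning self allows chaining fit(...).predict(...)

    def predict(self, X):
        # one prediction per input row, returned as a NumPy array
        return np.full(len(X), self.most_frequent_)

# Synthetic data: 6 samples with 2 features each
X_demo = np.arange(12).reshape(6, 2)
y_demo = np.array([0, 0, 1, 0, 1, 0])

clf = MostFrequentClassifier().fit(X_demo, y_demo)
predictions = clf.predict(X_demo)
```

    Inheriting from BaseEstimator also provides get_params() and set_params(), which sklearn utilities rely on when cloning an estimator.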

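    import numpy as np
    from sklearn.base import BaseEstimator
    from sklearn.model_selection import cross_val_score

    class Never5Classifier(BaseEstimator):
        def fit(self,X,y = None):
            # nothing to learn; return self, as sklearn estimators conventionally do
            return self
        def predict(self,X):
            # always predict False ("not 5"), one prediction per sample
            return np.zeros((len(X),1),dtype = bool)

    # X_train and y_train_5 (True where the digit is 5) are assumed to be defined earlier
    never_5_clf = Never5Classifier()
    cross_val_score(never_5_clf,X_train,y_train_5,cv = 3,scoring='accuracy')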

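    Never5Classifier predicts 0 (not 5) for every image, whereas images of the digit 5 should be labeled 1 and all others 0. Because only about 10% of the images are 5s, even this classifier is right about 90% of the time when it guesses that an image is not a 5.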
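
    The roughly 90% figure follows directly from the class proportions. The sketch below reproduces it with synthetic labels in which exactly 10% are positive; that exact share is an assumption for illustration, since the real proportion of 5s is only approximately 10%.

```python
import numpy as np

# Synthetic labels: 1,000 samples, exactly 10% of them positive ("is a 5")
y_true = np.zeros(1000, dtype=bool)
y_true[:100] = True

# A classifier that never predicts "5"
y_pred = np.zeros(len(y_true), dtype=bool)

# Accuracy is the fraction of predictions that match the labels
accuracy = np.mean(y_pred == y_true)
```

    Every negative sample is classified correctly and every positive one is missed, so the accuracy equals the share of negatives.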

    For more on estimators, transformers, and pipelines (Pipeline), see: https://www.jianshu.com/p/516f009c0875

  • Original article: https://www.cnblogs.com/whiteBear/p/12341094.html