zoukankan      html  css  js  c++  java
  • 【cs231n】图像分类-Nearest Neighbor Classifier(最近邻分类器)【python3实现】

    【学习自CS231n课程】

     转载请注明出处:http://www.cnblogs.com/GraceSkyer/p/8735908.html

    图像分类:

      一张图像的表示:长度、宽度、通道(3个颜色通道,分别是红R、绿G、蓝B)。

      对于计算机来说,图像是一个由数字组成的巨大的三维数组,数组元素是取值范围从0到255的整数,其中0表示全黑,255表示全白。

    图像分类的任务对于一个给定的图像,预测它属于的那个分类标签。

    如何写图像分类算法呢?

    数据驱动方法:

    收集足够代表性的样本(数据),运用数学找到一个或者一组模型的组合使得它和真实的情况非常接近。之所以被称为数据驱动方法是因为它是现有大量的数据,而不是预设模型,然后用很多简单的模型来契合数据。
    虽然通过这种方法找到的模型可能和真实模型存在一定的偏差,但是在误差允许的范围内,单从结果上和精确的模型是等效的。可以看出来数据驱动的方法目标就是近似替代,它甚至不是为了追求真实,仅仅是为了能够说明问题。
     

    图像分类流程:

      图像分类:输入一个元素为像素值的数组,然后给它分配一个分类标签。完整流程如下:

    • 输入:输入是包含N个图像的集合,每个图像的标签是K种分类标签中的一种。这个集合称为训练集。
    • 学习:这一步的任务是使用训练集来学习每个类到底长什么样。一般该步骤叫做训练分类器或者学习一个模型
    • 评价:让分类器来预测它未曾见过的图像的分类标签,并以此来评价分类器的质量。我们会把分类器预测的标签和图像真正的分类标签对比。毫无疑问,分类器预测的分类标签和图像真正的分类标签如果一致,那就是好事,这样的情况越多越好。

     

    Nearest Neighbor Classifier

    最简单的分类器......

    最近邻算法:在训练机器的过程中,我们什么也不做,我们只是单纯记录所有的训练数据,在图片预测的步骤,我们会拿一些新的图片去在训练数据中寻找与新图片最相似的,然后基于此,来给出一个标签。

    例:用最近邻算法用于这数据集中的图片,在训练集中找到最接近的样本:图像分类数据集:CIFAR-10

       这个数据集包含了60000张32X32的小图像。每张图像都有10种分类标签中的一种。这60000张图像被分为包含50000张图像的训练集和包含10000张图像的测试集。

      我们需要的是,输入一张测试图片,从训练集中找出与它L1距离最近的一张训练图,将其所对应的分类标签作为答案,也就是作为测试图片的分类标签。

      下面,让我们看看如何用代码来实现这个分类器。首先,我们将CIFAR-10的数据加载到内存中,并分成4个数组:训练数据和标签,测试数据和标签。在下面的代码中,Xtr(大小是50000x32x32x3)存有训练集中所有的图像,Ytr是对应的长度为50000的1维数组,存有图像对应的分类标签(从0到9):

    Xtr, Ytr, Xte, Yte = load_CIFAR10('data/cifar10/') # a magic function we provide
    # flatten out all images to be one-dimensional
    Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3) # Xtr_rows becomes 50000 x 3072
    Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3) # Xte_rows becomes 10000 x 3072

      现在我们得到所有的图像数据,并且把他们拉长成为行向量了。接下来展示如何训练并评价一个分类器:

    nn = NearestNeighbor() # create a Nearest Neighbor classifier class
    nn.train(Xtr_rows, Ytr) # train the classifier on the training images and labels
    Yte_predict = nn.predict(Xte_rows) # predict labels on the test images
    # and now print the classification accuracy, which is the average number
    # of examples that are correctly predicted (i.e. label matches)
    print 'accuracy: %f' % ( np.mean(Yte_predict == Yte) )

      作为评价标准,我们常常使用准确率,它描述了我们预测正确的得分。请注意以后我们实现的所有分类器都需要有这个API:train(X, y)函数。该函数使用训练集的数据和标签来进行训练。从其内部来看,类应该实现一些关于标签和标签如何被预测的模型。这里还有个predict(X)函数,它的作用是预测输入的新数据的分类标签。

    L1距离(即两尺寸相同图对应位置间差异的和):

    完整代码:

    Python3 实现:【使用L1距离的Nearest Neighbor分类器】

    【我是将cifar-10-batches-py数据放在代码同一目录下,代码中具体文件位置请根据情况自行设置】

     1 import numpy as np
     2 import pickle
     3 import os
     4 
     5 
     6 class NearestNeighbor(object):
     7     def __init__(self):
     8         pass
     9 
    10     def train(self, X, y):
    11         """ X is N x D where each row is an example. Y is 1-dimension of size N """
    12         # the nearest neighbor classifier simply remembers all the training data
    13         self.Xtr = X
    14         self.ytr = y
    15 
    16     def predict(self, X):
    17         """ X is N x D where each row is an example we wish to predict label for """
    18         num_test = X.shape[0]
    19         # lets make sure that the output type matches the input type
    20         Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
    21 
    22         # loop over all test rows
    23         for i in range(num_test):
    24             # find the nearest training image to the i'th test image
    25             # using the L1 distance (sum of absolute value differences)
    26             distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1)
    27             min_index = np.argmin(distances)  # get the index with smallest distance
    28             Ypred[i] = self.ytr[min_index]  # predict the label of the nearest example
    29 
    30         return Ypred
    31 
    32 
    33 def load_CIFAR_batch(file):
    34     """ load single batch of cifar """
    35     with open(file, 'rb') as f:
    36         datadict = pickle.load(f, encoding='latin1')
    37         X = datadict['data']
    38         Y = datadict['labels']
    39         X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    40         Y = np.array(Y)
    41     return X, Y
    42 
    43 
    44 def load_CIFAR10(ROOT):
    45     """ load all of cifar """
    46     xs = []
    47     ys = []
    48     for b in range(1,6):
    49         f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
    50         X, Y = load_CIFAR_batch(f)
    51         xs.append(X)
    52         ys.append(Y)
    53     Xtr = np.concatenate(xs)  # 使变成行向量
    54     Ytr = np.concatenate(ys)
    55     del X, Y
    56     Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    57     return Xtr, Ytr, Xte, Yte
    58 
    59 
    60 Xtr, Ytr, Xte, Yte = load_CIFAR10('data/cifar-10-batches-py/')  # a magic function we provide
    61 # flatten out all images to be one-dimensional
    62 Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3)  # Xtr_rows becomes 50000 x 3072
    63 Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3)  # Xte_rows becomes 10000 x 3072
    64 
    65 
    66 nn = NearestNeighbor()  # create a Nearest Neighbor classifier class
    67 nn.train(Xtr_rows, Ytr)  # train the classifier on the training images and labels
    68 Yte_predict = nn.predict(Xte_rows)  # predict labels on the test images
    69 # and now print the classification accuracy, which is the average number
    70 # of examples that are correctly predicted (i.e. label matches)
    71 print('accuracy: %f' % (np.mean(Yte_predict == Yte)))
    View Code

    然后我运行了很久......几个小时?

    运行结果:accuracy: 0.385900

     若用L2距离(欧式距离):

    修改上面代码1行就可以:

    distances = np.sqrt(np.sum(np.square(self.Xtr - X[i,:]), axis = 1))

    关于简单分类器的问题:

    • 如果我们有N个实例,训练和测试的过程可以有多快?
    • Train:O(1),predict:O(n)

      这是糟糕的。训练是连续的过程,因为我们并不需要做任何事情, 我们只需存储数据。但在测试时,我们需要停下来,将数据集中N个训练实例 与我们的测试图像进行比较,这是个很慢的过程。

      但是在实际使用中,我们希望训练过程比较慢,而测试过程快。因为,训练过程是在数据中心中完成的,它可以负担起非常大的运算量,从而训练出一个优秀的分类器。 然而,当你在测试过程部署分类器时,你希望它运行在手机上、浏览器、或其他低功耗设备,但你又希望分类器能够快速地运行,由此看来,最近邻算法有点落后了。。。

    参考:

    https://www.bilibili.com/video/av17204303/?from=search&seid=6625954842411789830

     https://zhuanlan.zhihu.com/p/20894041?refer=intelligentunit

    https://blog.csdn.net/dawningblue/article/details/75119639

    https://www.cnblogs.com/hans209/p/6919851.html

  • 相关阅读:
    smarty-2014-02-28
    PHP Functions
    Zabbix自定义监控网站服务是否能够正常响应
    Zabbix自定义监控网站服务是否能够正常响应
    shell技巧
    shell技巧
    ansible安装配置zabbix客户端
    ansible安装配置zabbix客户端
    shell命令getopts
    shell命令getopts
  • 原文地址:https://www.cnblogs.com/GraceSkyer/p/8735908.html
Copyright © 2011-2022 走看看