zoukankan      html  css  js  c++  java
  • Python图像识别(聚类)

     1 # -*- coding: utf-8 -*-
     2 """
     3 Created on Fri Sep 21 15:37:26 2018
     4 
     5 @author: zhen
     6 """
     7 from PIL import Image
     8 import numpy as np
     9 from sklearn.cluster import KMeans
    10 import matplotlib
    11 import matplotlib.pyplot as plt
    12 
    def restore_image(cb, cluster, shape):
        """Rebuild an image by painting every pixel with its cluster's colour.

        Parameters
        ----------
        cb : array-like
            Code book: ``cb[k]`` is the colour assigned to cluster ``k``.
        cluster : array-like of int
            Flat, row-major cluster labels; the pixel at ``(r, c)`` uses
            ``cluster[r * col + c]``.
        shape : tuple
            ``(row, col, channels)`` of the image to build.

        Returns
        -------
        numpy.ndarray
            Float array of shape ``(row, col, channels)``.
        """
        row, col, dummy = shape
        centers = np.asarray(cb)
        # Only the first row * col labels are ever addressed by r * col + c,
        # so any surplus labels are ignored.
        labels = np.asarray(cluster)[:row * col]
        # A single fancy-indexing lookup replaces the per-pixel Python loop.
        picked = centers[labels]
        # Keep a trailing axis so a scalar-per-cluster code book still
        # broadcasts over the channel axis, as element-wise assignment would.
        trailing = picked.shape[1:] or (1,)
        image = np.empty((row, col, dummy))
        image[...] = picked.reshape((row, col) + trailing)
        return image
    20 
    def show_scatter(a):
        """Plot how often each colour occurs in a set of pixels.

        Two figures are drawn and then displayed with ``plt.show()``:

        1. a 3-D scatter plot over an N x N x N grid of colour bins, with the
           marker area proportional to the relative count of each bin;
        2. the relative counts of the non-empty bins, sorted in descending
           order.

        Parameters
        ----------
        a : array-like, shape (n_samples, 3)
            Colour samples. Only the range [0, 1] on each axis is histogrammed.
            The axis labels assume the columns are R, G, B -- TODO confirm
            against the caller.
        """
        # Registers the '3d' projection on older matplotlib releases; newer
        # releases register it automatically and this import is a no-op.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        N = 10
        density, _ = np.histogramdd(a, bins=[N, N, N], range=[(0, 1), (0, 1), (0, 1)])
        # Normalise to [0, 1]; skip when nothing was counted to avoid 0 / 0.
        peak = density.max()
        if peak > 0:
            density /= peak
        x = y = z = np.arange(N)
        # indexing='ij' keeps grid axis k aligned with histogram axis k. The
        # default 'xy' swaps the first two axes, which would plot the counts of
        # the first colour channel along the second axis and vice versa.
        d = np.meshgrid(x, y, z, indexing='ij')

        fig = plt.figure(1, facecolor='w')
        ax = fig.add_subplot(111, projection='3d')

        # NOTE(review): no `c` argument is passed, so this colormap presumably
        # has no visible effect -- confirm whether colouring by density was
        # intended.
        cm = matplotlib.colors.ListedColormap(list('rgbm'))
        ax.scatter(d[0], d[1], d[2], s=100 * density, cmap=cm, marker='o', depthshade=True)
        # The x and z labels were empty strings; restored to match the
        # existing green label on y.
        ax.set_xlabel(u'红')
        ax.set_ylabel(u'绿')
        ax.set_zlabel(u'蓝')
        plt.title(u'图像颜色三维频数分布', fontsize=20)

        plt.figure(2, facecolor='w')
        # Non-empty bins only, largest first.
        den = density[density > 0]
        den = np.sort(den)[::-1]
        t = np.arange(len(den))
        plt.plot(t, den, 'r-', t, den, 'go', lw=2)
        plt.title(u'图像颜色频数分布', fontsize=18)
        plt.grid(True)

        plt.show()
    47       
    if __name__ == '__main__':
        # SimHei is selected so the Chinese titles can be rendered;
        # unicode_minus is disabled so minus signs are drawn with the ASCII
        # hyphen instead.
        matplotlib.rcParams['font.sans-serif'] = [u'SimHei']
        matplotlib.rcParams['axes.unicode_minus'] = False
        # Number of clusters (colours kept); 2, 6 and 30 were tried.
        num_vq = 2
        # The context manager closes the file handle. convert('RGB') drops an
        # alpha channel (as the former [:, :, :3] slice did) and also lets
        # greyscale or palette images through, which the slice could not.
        with Image.open('C:/Users/zhen/.spyder-py3/images/Lena.png') as im:
            # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
            image = np.array(im.convert('RGB')).astype(np.float64) / 255
        # One row per pixel, one column per colour channel.
        image_v = image.reshape((-1, 3))
        kmeans = KMeans(n_clusters=num_vq, init='k-means++')
        show_scatter(image_v)

        N = image_v.shape[0]  # total number of pixels
        # Fit the cluster centres on a random sample of pixel indices
        # (70% of N draws; randint can repeat an index).
        idx = np.random.randint(0, N, size=int(N * 0.7))
        image_sample = image_v[idx]
        kmeans.fit(image_sample)
        result = kmeans.predict(image_v)  # cluster label of every pixel
        # The '\n' escapes had been turned into literal line breaks, which
        # left the string literals unterminated.
        print('聚类结果:\n', result)
        print('聚类中心:\n', kmeans.cluster_centers_)

        plt.figure(figsize=(15, 8), facecolor='w')
        plt.subplot(211)
        plt.axis('off')
        plt.title(u'原始图片', fontsize=18)
        plt.imshow(image)
        # plt.savefig('原始图片.png')

        plt.subplot(212)
        vq_image = restore_image(kmeans.cluster_centers_, result, image.shape)
        plt.axis('off')
        plt.title(u'聚类个数:%d' % num_vq, fontsize=20)
        plt.imshow(vq_image)
        # plt.savefig('矢量化图片.png')

        # pad must be passed by keyword on current matplotlib releases.
        plt.tight_layout(pad=1.2)
        plt.show()

    结果:

          

      1.当k=2时:

      

           

      2.当k=6时:

        

            

      3.当k=30时:

        

           

    总结：当聚类个数较少时，算法运算速度快，但效果较差；当聚类个数较多时，运算速度慢，效果好，但容易过拟合。因此，选择恰当的 k 值对聚类效果的影响极其明显！

  • 相关阅读:
    python中type、object与class之间关系(一切皆对象)
    为什么在python中推荐使用多进程而不是多线程(转载)
    CPU密集型 VS IO密集型
    多CPU,多核,多进程,多线程
    Mac下brew安装与配置mysql
    mac安装navicat mysql破解版
    微信公众号-h5调用微信支付
    为什么js中0.1+0.2不等于0.3,怎样处理使之相等?(转载)
    gitlab安装和汉化
    PyPI使用国内源
  • 原文地址:https://www.cnblogs.com/yszd/p/9687897.html
Copyright © 2011-2022 走看看