zoukankan      html  css  js  c++  java
  • 【机器学习】k-means聚类

    参考博客:

    1.https://blog.csdn.net/Dhane/article/details/86661208

    2.https://www.cnblogs.com/txx120/p/11487674.html

    以及

    3.微信公众号:《科普:目标检测Anchor是什么?怎么科学设置?[附代码]》

     https://mp.weixin.qq.com/s?__biz=MzIwODI2NDkxNQ==&mid=2247487906&idx=3&sn=c3d656c9f34a95a22eae2b80564e071e&chksm=97049a1ea07313085df25b4a12763eedcf316a2d91891ddb7b74688bfc12407d9f325fcb7059&mpshare=1&scene=1&srcid=&sharer_sharetime=1584099361239&sharer_shareid=fee749290e0252f29857a970d866b498&key=f91b344e81f23c9a5ffce80c3a4ac39cf41c48b36e88e7c037b1a9c42577f9b0b16c9333dee6f7bdfb3ebaa6050762728d774362ef7f7bc25de0530fffd8c49871f9817cd12d29c88ce1a48374d7e78e&ascene=1&uin=MTgxMTE1NDQ4MQ%3D%3D&devicetype=Windows+7&version=62080085&lang=zh_CN&exportkey=AaTcIKbBi33fg2x0jeWOTCY%3D&pass_ticket=uJlkE2fEdlAcbaeQDt9nXACAFXx54NEfh8lvIQ8dmMlc3BeiFiY%2FW%2BSUcP4zINbm

     

    源码链接:使用K-means聚类合理设置anchor

    https://github.com/AIZOOTech/object-detection-anchors

    以及

    4.【白话机器学习】算法理论+实战之K-Means聚类算法

    https://mp.weixin.qq.com/s?__biz=MzIwODI2NDkxNQ==&mid=2247487944&idx=1&sn=0f768892decc3abfe86c4c2eeddc285f&chksm=97049a74a07313621aff02dac9a434f9465679df412da3949dd386f5a33944229b1ddf34950e&mpshare=1&scene=1&srcid=&sharer_sharetime=1584752492656&sharer_shareid=fee749290e0252f29857a970d866b498&key=feb1b2e52934a4d2813d56055eb27ca5fca68769c79ea411942d8ae10a381aa80295bee433d67968682219858db88a26771a7fab54c9dbe491aea4dc2eb2ffc0a86b03263853c016aa2f91d892571a31&ascene=1&uin=MTgxMTE1NDQ4MQ%3D%3D&devicetype=Windows+7&version=62080085&lang=zh_CN&exportkey=AdRA%2BiMkDLRG4RI%2F%2FtH4cj0%3D&pass_ticket=4eUa8Zb6iZa582C5psMmi3zCOn0fblLm6c01JYyFvRdF3wy6TlxzhjTmM4agUhWb 

    下面是为参考资料 3 的源码添加了一些注释之后的版本。首先是 kmeans.py:

      1 import numpy as np
      2 
      3 
      4 def iou(box, clusters):
      5     """
      6     Calculates the Intersection over Union (IoU) between a box and k clusters.
      7     :param box: tuple or array, shifted to the origin (i. e. width and height)
      8     :param clusters: numpy array of shape (k, 2) where k is the number of clusters
      9     :return: numpy array of shape (k, 0) where k is the number of clusters
     10     """
     11     x = np.minimum(clusters[:, 0], box[0])#取width最小值
     12     y = np.minimum(clusters[:, 1], box[1])#取height最小值
     13     if np.count_nonzero(x == 0) > 0 or np.count_nonzero(y == 0) > 0:
     14     #先判断x是否为0,是返回True,否则返回False,然后用np.count_nonzero()返回非零个数,如果非零个数>0,说明box里的宽或高有零值,则触发异常
     15         raise ValueError("Box has no area")
     16 
     17     intersection = x * y#最小的宽高相乘得到交集面积
     18     box_area = box[0] * box[1]#当前框面积
     19     cluster_area = clusters[:, 0] * clusters[:, 1]#随机抽取的25个框的面积
     20 
     21     iou_ = intersection / (box_area + cluster_area - intersection)
     22 
     23     return iou_
     24 
     25 
     26 def avg_iou(boxes, clusters):
     27     """
     28     Calculates the average Intersection over Union (IoU) between a numpy array of boxes and k clusters.
     29     :param boxes: numpy array of shape (r, 2), where r is the number of rows
     30     :param clusters: numpy array of shape (k, 2) where k is the number of clusters
     31     :return: average IoU as a single float
     32             返回:每个框与所有聚类中心点的iou取最大值,将这些最大值相加再取均值
     33     """
     34     return np.mean([np.max(iou(boxes[i], clusters)) for i in range(boxes.shape[0])])
     35 
     36 
     37 def translate_boxes(boxes):
     38     """
     39     Translates all the boxes to the origin.
     40     :param boxes: numpy array of shape (r, 4)
     41     :return: numpy array of shape (r, 2)
     42     """
     43     new_boxes = boxes.copy()
     44     for row in range(new_boxes.shape[0]):
     45         new_boxes[row][2] = np.abs(new_boxes[row][2] - new_boxes[row][0])
     46         new_boxes[row][3] = np.abs(new_boxes[row][3] - new_boxes[row][1])
     47     return np.delete(new_boxes, [0, 1], axis=1)
     48 
     49 
     50 def kmeans(boxes, k, dist=np.median):
     51     """
     52     Calculates k-means clustering with the Intersection over Union (IoU) metric.
     53     :param boxes: numpy array of shape (r, 2), where r is the number of rows
     54     :param k: number of clusters
     55     :param dist: distance function
     56     :return: numpy array of shape (k, 2)
     57     """
     58     rows = boxes.shape[0]
     59 
     60     distances = np.empty((rows, k))#返回(rows,k)形状的空数组
     61     last_clusters = np.zeros((rows,))#返回(rows,)的全零数组
     62 
     63     np.random.seed()#随机生成种子数
     64 
     65     # the Forgy method will fail if the whole array contains the same rows
     66     clusters = boxes[np.random.choice(rows, k, replace=False)]#在rows中随机抽取数字组成(k,)的一维数组,作为k个聚类中心,不能取重复数字
     67     print("clusters id {}".format(clusters))
     68 
     69     iter_num = 1
     70     while True:
     71         print("Iteration: %d" % iter_num)
     72         iter_num += 1
     73 
     74         for row in range(rows):
     75             distances[row] = 1 - iou(boxes[row], clusters)
     76             #计算第row个box与随机抽取的25个box的iou,用此公式计算第row个box与随机抽取的25个box之间的距离
     77         print('{}'.format(distances.shape))#(144027, 25)
     78 
     79         nearest_clusters = np.argmin(distances, axis=1)#按行取最小值索引,每一个框属于第几个聚类中心
     80         print('nearest_clusters',nearest_clusters)
     81         print('{}'.format(type(nearest_clusters)))#(144027,)
     82 
     83         if (last_clusters == nearest_clusters).all():#所有的返回值都为True才会执行,即当每个框属于某个聚类中心的索引不再更新时跳出循环
     84             break
     85 
     86         for cluster in range(k):
     87             print('len(boxes[nearest_clusters == cluster]):{}'.format(len(boxes[nearest_clusters == cluster])))#返回True的数量
     88             #print('boxes[nearest_clusters == cluster]:{}'.format(boxes[nearest_clusters == cluster]))
     89             #print('(nearest_clusters == cluster):{}'.format(nearest_clusters == cluster))
     90             #[False False False ...  True  True False]
     91             if len(boxes[nearest_clusters == cluster]) == 0:#
     92                 print("Cluster %d is zero size" % cluster)
     93                 # to avoid empty cluster
     94                 clusters[cluster] = boxes[np.random.choice(rows, 1, replace=False)]#此聚类中心size为0时重新为当前位置随机选择一个聚类中心
     95                 continue
     96 
     97             clusters[cluster] = dist(boxes[nearest_clusters == cluster], axis=0)#dist=np.median,在列的方向上求中位数
     98             #clusters[cluster] = np.median(boxes[nearest_clusters == cluster], axis=0)
     99             print('clusters[cluster]:{}'.format(clusters[cluster]))#[0.015625   0.02635432]
    100             #print('clusters[cluster]:{}'.format(clusters[cluster]))
    101 
    102         last_clusters = nearest_clusters
    103         #返回的是每一个聚类中心重新计算中位数,反复迭代计算后的新聚类中心点
    104 
    105     return clusters

    examples.py

      1 import glob
      2 import xml.etree.ElementTree as ET
      3 
      4 import numpy as np
      5 import matplotlib.pyplot as plt
      6 from kmeans import kmeans, avg_iou
      7 
# Directory that is searched for Pascal-VOC style XML annotation files.
# ANNOTATIONS_PATH = "./data/pascalvoc07-annotations"
ANNOTATIONS_PATH = "./data/widerface-annotations"
# Number of cluster centres (anchors) requested from kmeans().
CLUSTERS = 25
# Passed to load_dataset() as `normalized`: when True, box coordinates are
# divided by the image width/height; when False they stay in pixels.
BBOX_NORMALIZE = False
     12 
     13 def show_cluster(data, cluster, max_points=2000):
     14     '''
     15     Display bouding box's size distribution and anchor generated in scatter.散点图
     16     '''
     17     if len(data) > max_points:
     18         idx = np.random.choice(len(data), max_points)
     19         data = data[idx]#在所有的data中随机抽取max_points个数据
     20     plt.scatter(data[:,0], data[:,1], s=5, c='lavender')#输入数据是data的宽和高,s=5是点的大小,c是颜色
     21     plt.scatter(cluster[:,0], cluster[:, 1], c='red', s=100, marker="^")#‘^’是正三角形
     22     plt.xlabel("Width")
     23     plt.ylabel("Height")
     24     plt.title("Bounding and anchor distribution")
     25     plt.savefig("cluster.png")
     26     plt.show()
     27 
     28 def show_width_height(data, cluster, bins=50):
     29     '''
     30     Display bouding box distribution with histgram.直方图
     31     '''
     32     if data.dtype != np.float32:
     33         data = data.astype(np.float32)
     34     width = data[:, 0]
     35     print('width_in show_width_height)',len(width))
     36     height = data[:, 1]
     37     print('height in show_width_height',height)
     38     ratio = height / width
     39 
     40     plt.figure(1,figsize=(20, 6))#num:图像编号或名称,数字为编号 ,字符串为名称;figsize:指定figure的宽和高,单位为英寸;
     41     plt.subplot(131)
     42     #subplot可以规划figure划分为n个子图,但每条subplot命令只会创建一个子图,131表示整个figure分成1行3列,共3个子图,这里子图在第一行第一列
     43     plt.hist(width, bins=bins, color='green')
     44     #width指定每个bin(箱子)分布的数据,对应x轴;bins这个参数指定bin(箱子)的个数,也就是总共有几条条状图;color指定条状图的颜色;默认y轴是个数
     45     plt.xlabel('width')
     46     plt.ylabel('number')
     47     plt.title('Distribution of Width')
     48 
     49     plt.subplot(132)
     50     plt.hist(height,bins=bins, color='blue')
     51     plt.xlabel('Height')
     52     plt.ylabel('Number')
     53     plt.title('Distribution of Height')
     54 
     55     plt.subplot(133)
     56     plt.hist(ratio, bins=bins,  color='magenta')
     57     plt.xlabel('Height / Width')
     58     plt.ylabel('number')
     59     plt.title('Distribution of aspect ratio(Height / Width)')
     60     plt.savefig("shape-distribution.png")
     61     plt.show()
     62     
     63 
     64 def sort_cluster(cluster):
     65     '''
     66     Sort the cluster to with area small to big.
     67     '''
     68     if cluster.dtype != np.float32:
     69         cluster = cluster.astype(np.float32)
     70     print('cluster',cluster)
     71     area = cluster[:, 0] * cluster[:, 1]#计算每一个聚类中心点横纵坐标的乘积
     72     cluster = cluster[area.argsort()]#argsort函数返回的是数组值从小到大的索引值,此处将cluster按从小到大进行排序
     73     print('sorted cluster',cluster)
     74     ratio = cluster[:,1:2] / cluster[:, 0:1]
     75     print('ratio',ratio)
     76     return np.concatenate([cluster, ratio], axis=-1)  # 按轴axis连接array组成一个新的array,-1表示在最后一维进行合并,也就是行的方向合并
     77 
     78 
     79 def load_dataset(path, normalized=True):
     80     '''
     81     load dataset from pasvoc formatl xml files
     82     '''
     83     dataset = []
     84     for xml_file in glob.glob("{}/*xml".format(path)):#获取path路径下所有的xml文件并返回一个list
     85         tree = ET.parse(xml_file)#调用parse()方法,返回解析树
     86 
     87         height = int(tree.findtext("./size/height"))
     88         width = int(tree.findtext("./size/width"))
     89 
     90         for obj in tree.iter("object"):
     91             if normalized:
     92                 xmin = int(obj.findtext("bndbox/xmin")) / float(width)
     93                 ymin = int(obj.findtext("bndbox/ymin")) / float(height)
     94                 xmax = int(obj.findtext("bndbox/xmax")) / float(width)
     95                 ymax = int(obj.findtext("bndbox/ymax")) / float(height)
     96             else:
     97                 xmin = int(obj.findtext("bndbox/xmin")) 
     98                 ymin = int(obj.findtext("bndbox/ymin")) 
     99                 xmax = int(obj.findtext("bndbox/xmax")) 
    100                 ymax = int(obj.findtext("bndbox/ymax"))
    101             if (xmax - xmin) == 0 or (ymax - ymin) == 0:
    102                 continue # to avoid divded by zero error.
    103             dataset.append([xmax - xmin, ymax - ymin])
    104 
    105     return np.array(dataset)
    106 
    107 print("Start to load data annotations on: %s" % ANNOTATIONS_PATH)
    108 data = load_dataset(ANNOTATIONS_PATH, normalized=BBOX_NORMALIZE)
    109 print('{}'.format(type(data)))#<class 'numpy.ndarray'>,(144027, 2)
    110 print("Start to do kmeans, please wait for a moment.")
    111 out = kmeans(data, k=CLUSTERS)#out为由kmeans找到的聚类中心点
    112 
    113 out_sorted = sort_cluster(out)
    114 print('out_sorted',out_sorted)
    115 print("Accuracy: {:.2f}%".format(avg_iou(data, out) * 100))#每个框与聚类中心点的最大IOU的平均值,可以用来表示所有框与聚类中心点的平均相似度
    116 
    117 show_cluster(data, out, max_points=2000)
    118 
    119 if out.dtype != np.float32:
    120     out = out.astype(np.float32)
    121 
    122 print("Recommanded aspect ratios(width/height)")
    123 print("Width    Height   Height/Width")
    124 for i in range(len(out_sorted)):
    125     print("%.3f      %.3f     %.1f" % (out_sorted[i,0], out_sorted[i,1], out_sorted[i,2]))
    126 show_width_height(data, out, bins=50)
  • 相关阅读:
    MySQL根据出生日期计算年龄的五种方法比较
    用于测试API并生文档的开发人员工具
    【实例】使用Eolinker工具进行接口测试时传递集合参数的方法
    如何克服API测试中的挑战
    关于API网关(一)性能
    为什么需要监控API
    比Swagger2更好用的自动生成文档工具?对比流程说话!
    如何通过3个步骤执行基本的API测试
    【学习】API接口测试用例编写规则
    微信小程序之蓝牙 BLE 踩坑记录
  • 原文地址:https://www.cnblogs.com/DJames23/p/12494164.html
Copyright © 2011-2022 走看看