zoukankan      html  css  js  c++  java
  • YOLOv3中K-Means聚类出新数据集的Anchor尺寸

    参考博客:

    聚类kmeans算法在yolov3中的应用 https://www.cnblogs.com/sdu20112013/p/10937717.html

    这篇博客写得非常详细,也贴出了github代码:https://github.com/AlexeyAB/darknet/blob/master/scripts/gen_anchors.py

    整体代码如下:

      1 '''
      2 Created on Feb 20, 2017
      3 @author: jumabek
      4 '''
      5 from os import listdir
      6 from os.path import isfile, join
      7 import argparse
      8 # import cv2
      9 import numpy as np
     10 import sys
     11 import os
     12 import shutil
     13 import random
     14 import math
     15 width_in_cfg_file = 416.
     16 height_in_cfg_file = 416.
     17 
     18 
     19 def IOU(x, centroids):
     20     '''
     21     :param x: 当前gt的w和h
     22     :param centroids: 质心
     23     :return:当前gt与每个质心的相似度,np.array形式
     24     '''
     25     similarities = []
     26     k = len(centroids)
     27     for centroid in centroids:
     28         c_w, c_h = centroid
     29         w, h = x
     30         if c_w >= w and c_h >= h:
     31             similarity = w * h / (c_w * c_h)
     32         elif c_w >= w and c_h <= h:
     33             similarity = w * c_h / (w * h + (c_w - w) * c_h)  # 交叉面积/总面积
     34         elif c_w <= w and c_h >= h:
     35             similarity = c_w * h / (w * h + c_w * (c_h - h))
     36         else:  # means both w,h are bigger than c_w and c_h respectively
     37             similarity = (c_w * c_h) / (w * h)
     38         similarities.append(similarity)  # will become (k,) shape
     39     return np.array(similarities)
     40 
     41 
     42 def avg_IOU(X, centroids):
     43     n, d = X.shape
     44     sum = 0.
     45     for i in range(X.shape[0]):
     46         # note IOU() will return array which contains IoU for each centroid and X[i] // slightly ineffective, but I am too lazy
     47         sum += max(IOU(X[i], centroids))
     48     return sum / n
     49 
     50 
     51 def write_anchors_to_file(centroids, X, anchor_file):
     52     f = open(anchor_file, 'w')
     53     anchors = centroids.copy()
     54     print(anchors.shape)
     55     for i in range(anchors.shape[0]):
     56         anchors[i][0] *= width_in_cfg_file  # / 32. YOLOv3不用除以32 # 归一化后的宽高乘以预设的图片宽高
     57         anchors[i][1] *= height_in_cfg_file  # / 32.
     58     widths = anchors[:, 0]
     59     sorted_indices = np.argsort(widths)
     60     print('Anchors = ', anchors[sorted_indices])
     61     for i in sorted_indices[:-1]:
     62         # 将前n-1个anchor写入txt
     63         f.write('%0.2f,%0.2f, ' % (anchors[i, 0], anchors[i, 1]))
     64     # there should not be comma after last anchor, that's why
     65     # 最后一个anchor写完以后需要换行,所以单独填写
     66     f.write('%0.2f,%0.2f
    ' % (anchors[sorted_indices[-1:], 0], anchors[sorted_indices[-1:], 1]))
     67     f.write('%f
    ' % (avg_IOU(X, centroids)))
     68     print()
     69 
     70 
     71 def kmeans(X, centroids, eps, anchor_file):
     72     '''
     73 
     74     :param X: annotation_dims,所有的标注信息中的宽和高
     75     :param centroids: 随机生成的质心
     76     :param eps:
     77     :param anchor_file: 保存结果的文件
     78     :return:
     79     '''
     80     N = X.shape[0]
     81     iterations = 0
     82     k, dim = centroids.shape
     83     prev_assignments = np.ones(N) * (-1)
     84     iter = 0
     85     old_D = np.zeros((N, k))
     86     while True:
     87         D = []
     88         iter += 1
     89         for i in range(N):
     90             # 计算gt框与质心之间的距离,相似度越大,说明当前gt越接近于质心,此距离就应该越小
     91             d = 1 - IOU(X[i], centroids)
     92             D.append(d)
     93         D = np.array(D)  # D.shape = (N,k)
     94         print("iter {}: dists = {}".format(iter, np.sum(np.abs(old_D - D))))
     95         # assign samples to centroids
     96         assignments = np.argmin(D, axis=1)  # 返回每一行的最小值的下标.即当前样本应该归为k个质心中的哪一个质心.
     97         if (assignments == prev_assignments).all():  # 质心已经不再变化
     98             print("Centroids = ", centroids)
     99             write_anchors_to_file(centroids, X, anchor_file)
    100             return
    101         # calculate new centroids,更新质心
    102         centroid_sums = np.zeros((k, dim), np.float)
    103         for i in range(N):
    104             centroid_sums[assignments[i]] += X[i]
    105         for j in range(k):
    106             centroids[j] = centroid_sums[j] / (np.sum(assignments == j))
    107         prev_assignments = assignments.copy()
    108         old_D = D.copy()
    109 
    110 
    111 def main(argv):
    112     parser = argparse.ArgumentParser()
    113     parser.add_argument('-filelist', default='F://BaiduNetdiskDownload//trainall_name.txt',
    114                         help='path to filelist
    ')
    115     parser.add_argument('-output_dir', default='F://BaiduNetdiskDownload//generated_anchors//anchors//', type=str,
    116                         help='Output anchor directory
    ')
    117     parser.add_argument('-num_clusters', default=6, type=int,
    118                         help='number of clusters
    ')
    119     args = parser.parse_args()
    120     if not os.path.exists(args.output_dir):
    121         os.mkdir(args.output_dir)
    122     f = open(args.filelist)
    123     lines = [line.rstrip('
    ') for line in f.readlines()]
    124     annotation_dims = []
    125     size = np.zeros((1, 1, 3))
    126     for line in lines:
    127         # 注意路径问题,通过替换图片路径中的Images为labels来找到标签信息
    128         line = line.replace('Images','labels')
    129         # line = line.replace('img1','labels')
    130         # line = line.replace('JPEGImages', 'labels')
    131         line = line.replace('.jpg', '.txt')
    132         line = line.replace('.png', '.txt')
    133         print(line)
    134 
    135         f2 = open(line)
    136         for line in f2.readlines():
    137             line = line.rstrip('
    ')
    138             w, h = line.split(' ')[3:]  # 得到标注文件的宽和高[0 0.83984 0.40700 0.17188 0.47218]
    139             # print(w,h)
    140             annotation_dims.append(tuple(map(float, (w, h))))
    141     annotation_dims = np.array(annotation_dims)
    142     eps = 0.005
    143     if args.num_clusters == 0:
    144         for num_clusters in range(1, 11):  # we make 1 through 10 clusters
    145             anchor_file = join(args.output_dir, 'anchors%d.txt' % (num_clusters))
    146             indices = [random.randrange(annotation_dims.shape[0]) for i in range(num_clusters)]
    147             centroids = annotation_dims[indices]
    148             kmeans(annotation_dims, centroids, eps, anchor_file)
    149             print('centroids.shape', centroids.shape)
    150     else:
    151         anchor_file = join(args.output_dir, 'anchors%d.txt' % (args.num_clusters))  # 保存结果的文件
    152         # 在所有labels数量范围内随机生成质心的索引数,生成num_clusters个
    153         indices = [random.randrange(annotation_dims.shape[0]) for i in range(args.num_clusters)]
    154         # 生成质心
    155         centroids = annotation_dims[indices]
    156         # 调用kmeans
    157         kmeans(annotation_dims, centroids, eps, anchor_file)
    158         print('centroids.shape', centroids.shape)
    159 
    160 
    161 if __name__ == "__main__":
    162     main(sys.argv)

    使用生成YOLOv3 anchor时需要注意

    anchors[i][0] *= width_in_cfg_file  # / 32. YOLOv3不用除以32 # 归一化后的宽高乘以预设的图片宽高

    最后生成的结果,6个anchors:

    7.90,21.48, 18.72,61.61, 34.67,138.55, 65.49,251.30, 104.70,64.11, 144.33,434.60
    0.582349

    可以看出宽高比都为1:3左右,结合我使用的是行人检测的数据集,这个比例还算正常。但第5组数据(104.70,64.11)不符合这个宽高比

  • 相关阅读:
    web前端之jQuery
    java之awt编程
    java连接数据库的基本操作
    实习生应聘经历2018/3/1
    javaweb学习之建立简单网站
    mysql之视图
    71. Simplify Path
    347. Top K Frequent Elements
    7. Reverse Integer
    26. Remove Duplicates from Sorted Array
  • 原文地址:https://www.cnblogs.com/DJames23/p/13416722.html
Copyright © 2011-2022 走看看