zoukankan      html  css  js  c++  java
  • k-means算法求解anchors (针对YOLO3)

    文字内容以后再补充:

    import numpy as np

    # 定义Box类,描述bounding box的坐标
    class Box():
    def __init__(self, x, y, w, h):
    self.x = x
    self.y = y
    self.w = w
    self.h = h

    def box_iou(a, b):
    '''
    # a和b都是Box类型实例
    # 返回值area是box a 和box b 的交集面积
    '''
    a_x1 = a.x-a.w/2
    a_y1 = a.y - a.h / 2
    a_x2 = a.x+a.w/2
    a_y2 = a.y + a.h / 2
    b_x1 = b.x-b.w/2
    b_y1 = b.y - b.h / 2
    b_x2 = b.x+b.w/2
    b_y2 = b.y + b.h / 2
    box_x1 = max(a_x1,b_x1)
    box_y1 = max(a_y1, b_y1)
    box_x2 = min(a_x2,b_x2)
    box_y2 = min(a_y2, b_y2)
    box_w = box_x2-box_x1
    box_h = box_y2 - box_y1
    if box_w < 0 or box_h < 0:
    area = 0
    else:
    area = box_w * box_h
    box_intersection=area
    box_union = a.w * a.h + b.w * b.h-box_intersection
    iou = box_intersection/box_union
    return iou

    # 使用k-means ++ 初始化 centroids,减少随机初始化的centroids对最终结果的影响
    def init_centroids(boxes, n_anchors):
    '''
    随机选择一个box作为
    :param boxes: 是所有bounding boxes的Box对象列表
    :param n_anchors: n_anchors是k-means的k值
    :return: 返回值centroids 是初始化的n_anchors个centroid
    '''
    centroids = []
    boxes_num = len(boxes)
    centroid_index = np.random.choice(boxes_num, 1) # 在boxes_num=55 中产生一个数23
    centroids.append(boxes[centroid_index])
    print(centroids[0].w, centroids[0].h)

    for centroid_index in range(0, n_anchors-1):
    sum_distance = 0
    distance_list = []
    cur_sum = 0
    for box in boxes:
    min_distance = 1
    for centroid_i, centroid in enumerate(centroids):
    distance = (1 - box_iou(box, centroid))
    if distance < min_distance:
    min_distance = distance
    sum_distance += min_distance
    distance_list.append(min_distance)
    distance_thresh = sum_distance*np.random.random()

    for i in range(0, boxes_num):
    cur_sum += distance_list[i]
    if cur_sum > distance_thresh:
    centroids.append(boxes[i])
    print(boxes[i].w, boxes[i].h)
    break
    return centroids

    # 进行 k-means 计算新的centroids
    def do_kmeans(n_anchors, boxes, centroids):
    '''
    :param n_anchors: 是k-means的k值
    :param boxes: 是所有bounding boxes的Box对象列表
    :param centroids: 是所有簇的中心
    :return: # 返回值new_centroids 是计算出的新簇中心
    # 返回值groups是n_anchors个簇包含的boxes的列表
    # 返回值loss是所有box距离所属的最近的centroid的距离的和
    '''
    loss = 0
    groups = []
    new_centroids = []
    for i in range(n_anchors):
    groups.append([]) # [[], [], [], []]
    new_centroids.append(Box(0, 0, 0, 0))
    # 以上代码建立初始化
    for box in boxes:
    min_distance = 1
    group_index = 0
    for centroid_index, centroid in enumerate(centroids):
    # 这个循环实际是在找box与哪个centroidsiou最小,最接近
    distance = (1 - box_iou(box, centroid))
    if distance < min_distance:
    min_distance = distance
    group_index = centroid_index
    groups[group_index].append(box) # 将其保留对应的族中
    loss += min_distance
    new_centroids[group_index].w += box.w # 累加对应的族中的w
    new_centroids[group_index].h += box.h

    for i in range(n_anchors): # 得到新的族中的w与h
    new_centroids[i].w /= len(groups[i])
    new_centroids[i].h /= len(groups[i])

    return new_centroids, groups, loss

    def init_all_value(use_init_centroids=1, n_anchors=9):
    # 构建初始化族中心
    if use_init_centroids:
    centroids = init_centroids(boxes, n_anchors)
    else:
    centroid_indices = np.random.choice(len(boxes), n_anchors)
    centroids = []
    for centroid_index in centroid_indices:
    centroids.append(boxes[centroid_index])
    # 构建初始化 groups 保存对应族的box类
    centroids, groups, old_loss = do_kmeans(n_anchors, boxes, centroids)

    return centroids, groups, old_loss

    if __name__=='__main__':
    # 构建boxes

    boxes=[]
    boxes.append(Box(4,5,6,7)) # 根据实际情况自己填写
    num_anchor = 9 # 产生族中心点是多少

    # 构建停止条件
    num_iterations=2000
    loss_stop = 1e-6

    centroids, groups, old_loss = init_all_value(1, num_anchor)

    # 循环找到族中最好的w与h
    iterations = 1
    while (True):
    centroids, groups, loss = do_kmeans(num_anchor, boxes, centroids)
    iterations = iterations + 1
    print("loss = %f" % loss)
    if abs(old_loss - loss) < loss_stop or iterations > num_iterations:
    break
    old_loss = loss

    # 打印最终结果

    for centroid in centroids:
    print("k-means result: ")
    print(centroid.w, centroid.h)
  • 相关阅读:
    Proguard打包混淆报错:can't find superclass or interface
    proguard returned with error code 1.异常的解决方法
    android 混淆配置
    解决android混淆编译出现Proguard returned with error code 1和文件名、目录名或卷标语法不正确错误
    Eclipse提示No java virtual machine
    [mysql]数据库查询实例
    [算法]高效求素数
    [笔试]程序员面试宝典
    [linux]进程间通信IPC
    [linux]信号的捕获和处理
  • 原文地址:https://www.cnblogs.com/tangjunjun/p/12393018.html
Copyright © 2011-2022 走看看