zoukankan      html  css  js  c++  java
  • 优先队列与TopK

    一、简介

      前文介绍了《最大堆》的实现,本章节在最大堆的基础上实现一个简单的优先队列。优先队列的实现本身没什么难度,所以本文我们从优先队列的场景出发介绍topK问题。

      后面会持续更新数据结构相关的博文。

      数据结构专栏:https://www.cnblogs.com/hello-shf/category/1519192.html

      git传送门:https://github.com/hello-shf/data-structure.git

    二、优先队列

      普通的队列是一种先进先出的数据结构,元素在队列尾追加,而从队列头删除。在优先队列中,元素被赋予优先级。当访问元素时,具有最高优先级的元素最先删除。优先队列具有最高级先出 (first in, largest out)的行为特征。通常采用堆数据结构来实现。

      上面是百度百科给出的优先队列的解释。解释的还是很到位的。具体的优先队列的实现可以采用最小堆或者最大堆。因为在我们前文《最大堆》的实现中,该堆存储的元素是要求实现Comparable接口的。所以优先级是掌握在用户手中的,所以最小堆和最大堆都可以作为优先队列的底层数据结构。

      普通的队列Queue,我们都知道是先进先出(FIFO)的,所以元素的出队顺序和入队顺序是保持一致的。但是对于我们的优先队列,出队操作,将不保证先进先出的队列特性,而是根据元素的优先级(或者说权重)决定出队的顺序。

      如果最小元素拥有最高的优先级,那么这种优先队列叫作升序优先队列,即总是优先删除最小的元素。同理,如果最大元素拥有最高的优先级,那么这种优先队列叫作降序优先队列,即总是先删除最大的元素。

      优先队列的使用场景:

     算法场景:
         最短路径算法:Dijkstra算法
         最小生成树算法:Prim算法
         事件驱动仿真:顾客排队算法
         选择问题:查找第k个最小元素
     实现场景:
         游戏中优先攻击最近单位,优先攻击血量最低等
         售票窗口的老幼病残孕和军人优先购票等

    三、优先队列的实现

    3.1、队列接口定义

      同普通的队列,我们先定义队列的接口如下

     1 /**
     2  * 描述:队列
     3  *
     4  * @Author shf
     5  * @Date 2019/7/18 15:30
     6  * @Version V1.0
     7  **/
     8 public interface Queue<E> {
     9     /**
    10      * 获取当前队列的元素数
    11      * @return
    12      */
    13     int getSize();
    14 
    15     /**
    16      * 判断当前队列是否为空
    17      * @return
    18      */
    19     boolean isEmpty();
    20 
    21     /**
    22      * 入队操作
    23      * @param e
    24      */
    25     void enqueue(E e);
    26 
    27     /**
    28      * 出队操作
    29      * @return
    30      */
    31     E dequeue();
    32 
    33     /**
    34      * 获取队列头元素
    35      * @return
    36      */
    37     E getFront();
    38 }

      

    3.2、最大堆实现的优先队列

      我们使用前文实现的《最大堆》来实现一个优先队列

     1 /**
     2  * 描述:优先队列
     3  *
     4  * @Author shf
     5  * @Date 2019/7/18 17:31
     6  * @Version V1.0
     7  **/
     8 public class PriorityQueue<E extends Comparable<E>> implements Queue<E> {
     9 
    10     private MaxHeap<E> maxHeap;
    11 
    12     public PriorityQueue(){
    13         maxHeap = new MaxHeap<>();
    14     }
    15 
    16     @Override
    17     public int getSize(){
    18         return maxHeap.size();
    19     }
    20 
    21     @Override
    22     public boolean isEmpty(){
    23         return maxHeap.isEmpty();
    24     }
    25 
    26     @Override
    27     public E getFront(){
    28         // 获取队列的头元素,在最大堆中就是获取堆顶元素
    29         return maxHeap.findMax();
    30     }
    31 
    32     @Override
    33     public void enqueue(E e){
    34         // 压栈 直接向最大堆中添加,让最大堆的add方法维护 元素的优先级
    35         maxHeap.add(e);
    36     }
    37 
    38     @Override
    39     public E dequeue(){
    40         // 出栈 将最大堆的堆顶元素取出
    41         return maxHeap.extractMax();
    42     }
    43 }

      需要解释的都在代码注释中了。

      到这里优先队列就实现完了,是不是很简单。

      在java中也有一个类PriorityQueue,其底层是采用的最小堆实现的优先队列。在java PriorityQueue中关于优先级的定义,优先级队列的元素按照其自然顺序进行排序,或者根据构造队列时提供的 Comparator 进行排序,具体取决于所使用的构造方法。底层数据结构最大堆或者最小堆是没有什么区别的。关键在于我们如何定义优先级。

    四、topK问题

      关于topK问题,leetcode上面有一道典型的题目

      题目最终需要返回的是前 k 个频率最大的元素,可以想到借助堆这种数据结构,对于 k 频率之后的元素不用再去处理,进一步优化时间复杂度。

      具体操作为:

    借助 哈希表 来建立数字和其出现次数的映射,遍历一遍数组统计元素的频率
    维护一个元素数目为 k 的最小堆
    每次都将新的元素与堆顶元素(堆中频率最小的元素)进行比较
    如果新的元素的频率比堆顶端的元素大,则弹出堆顶端的元素,将新的元素添加进堆中
    最终,堆中的 k 个元素即为前 k 个高频元素

       具体实现

     1 class Solution {
     2     public List<Integer> topKFrequent(int[] nums, int k) {
     3         // 使用字典,统计每个元素出现的次数,元素为键,元素出现的次数为值
     4         HashMap<Integer,Integer> map = new HashMap();
     5         for(int num : nums){
     6             if (map.containsKey(num)) {
     7                map.put(num, map.get(num) + 1);
     8              } else {
     9                 map.put(num, 1);
    10              }
    11         }
    12         // 遍历map,用最小堆保存频率最大的k个元素
    13         PriorityQueue<Integer> pq = new PriorityQueue<>(new Comparator<Integer>() {
    14             @Override
    15             public int compare(Integer a, Integer b) {
    16                 return map.get(a) - map.get(b);
    17             }
    18         });
    19         for (Integer key : map.keySet()) {
    20             if (pq.size() < k) {
    21                 pq.add(key);
    22             } else if (map.get(key) > map.get(pq.peek())) {
    23                 pq.remove();
    24                 pq.add(key);
    25             }
    26         }
    27         // 取出最小堆中的元素
    28         List<Integer> res = new ArrayList<>();
    29         while (!pq.isEmpty()) {
    30             res.add(pq.remove());
    31         }
    32         return res;
    33     }
    34 }

       以上是使用java原生的优先队列实现的。接下来我们用我们自己实现的PriorityQueue试验一下。

       首先因为我们没有提供接收一个Comparator的构造器,所以我们通过定义一个类来完成这个过程比较。

      因为自己定义的优先队列底层使用的是我们自己实现的最大堆,以及最大堆底层数组也是使用自己定义的,所以我们在leetcode提交验证的时候,需要将这些自定义的类以内部类的方式提交上去。整体代码如下

      1 /// 347. Top K Frequent Elements
      2 /// https://leetcode.com/problems/top-k-frequent-elements/description/
      3 
      4 import java.util.LinkedList;
      5 import java.util.List;
      6 import java.util.TreeMap;
      7 
      8 class Solution {
      9 
     10     private class Array<E> {
     11 
     12         private E[] data;
     13         private int size;
     14 
     15         // 构造函数,传入数组的容量capacity构造Array
     16         public Array(int capacity){
     17             data = (E[])new Object[capacity];
     18             size = 0;
     19         }
     20 
     21         // 无参数的构造函数,默认数组的容量capacity=10
     22         public Array(){
     23             this(10);
     24         }
     25 
     26         public Array(E[] arr){
     27             data = (E[])new Object[arr.length];
     28             for(int i = 0 ; i < arr.length ; i ++)
     29                 data[i] = arr[i];
     30             size = arr.length;
     31         }
     32 
     33         // 获取数组的容量
     34         public int getCapacity(){
     35             return data.length;
     36         }
     37 
     38         // 获取数组中的元素个数
     39         public int getSize(){
     40             return size;
     41         }
     42 
     43         // 返回数组是否为空
     44         public boolean isEmpty(){
     45             return size == 0;
     46         }
     47 
     48         // 在index索引的位置插入一个新元素e
     49         public void add(int index, E e){
     50 
     51             if(index < 0 || index > size)
     52                 throw new IllegalArgumentException("Add failed. Require index >= 0 and index <= size.");
     53 
     54             if(size == data.length)
     55                 resize(2 * data.length);
     56 
     57             for(int i = size - 1; i >= index ; i --)
     58                 data[i + 1] = data[i];
     59 
     60             data[index] = e;
     61 
     62             size ++;
     63         }
     64 
     65         // 向所有元素后添加一个新元素
     66         public void addLast(E e){
     67             add(size, e);
     68         }
     69 
     70         // 在所有元素前添加一个新元素
     71         public void addFirst(E e){
     72             add(0, e);
     73         }
     74 
     75         // 获取index索引位置的元素
     76         public E get(int index){
     77             if(index < 0 || index >= size)
     78                 throw new IllegalArgumentException("Get failed. Index is illegal.");
     79             return data[index];
     80         }
     81 
     82         // 修改index索引位置的元素为e
     83         public void set(int index, E e){
     84             if(index < 0 || index >= size)
     85                 throw new IllegalArgumentException("Set failed. Index is illegal.");
     86             data[index] = e;
     87         }
     88 
     89         // 查找数组中是否有元素e
     90         public boolean contains(E e){
     91             for(int i = 0 ; i < size ; i ++){
     92                 if(data[i].equals(e))
     93                     return true;
     94             }
     95             return false;
     96         }
     97 
     98         // 查找数组中元素e所在的索引,如果不存在元素e,则返回-1
     99         public int find(E e){
    100             for(int i = 0 ; i < size ; i ++){
    101                 if(data[i].equals(e))
    102                     return i;
    103             }
    104             return -1;
    105         }
    106 
    107         // 从数组中删除index位置的元素, 返回删除的元素
    108         public E remove(int index){
    109             if(index < 0 || index >= size)
    110                 throw new IllegalArgumentException("Remove failed. Index is illegal.");
    111 
    112             E ret = data[index];
    113             for(int i = index + 1 ; i < size ; i ++)
    114                 data[i - 1] = data[i];
    115             size --;
    116             data[size] = null; // loitering objects != memory leak
    117 
    118             if(size == data.length / 4 && data.length / 2 != 0)
    119                 resize(data.length / 2);
    120             return ret;
    121         }
    122 
    123         // 从数组中删除第一个元素, 返回删除的元素
    124         public E removeFirst(){
    125             return remove(0);
    126         }
    127 
    128         // 从数组中删除最后一个元素, 返回删除的元素
    129         public E removeLast(){
    130             return remove(size - 1);
    131         }
    132 
    133         // 从数组中删除元素e
    134         public void removeElement(E e){
    135             int index = find(e);
    136             if(index != -1)
    137                 remove(index);
    138         }
    139 
    140         public void swap(int i, int j){
    141 
    142             if(i < 0 || i >= size || j < 0 || j >= size)
    143                 throw new IllegalArgumentException("Index is illegal.");
    144 
    145             E t = data[i];
    146             data[i] = data[j];
    147             data[j] = t;
    148         }
    149 
    150         @Override
    151         public String toString(){
    152 
    153             StringBuilder res = new StringBuilder();
    154             res.append(String.format("Array: size = %d , capacity = %d
    ", size, data.length));
    155             res.append('[');
    156             for(int i = 0 ; i < size ; i ++){
    157                 res.append(data[i]);
    158                 if(i != size - 1)
    159                     res.append(", ");
    160             }
    161             res.append(']');
    162             return res.toString();
    163         }
    164 
    165         // 将数组空间的容量变成newCapacity大小
    166         private void resize(int newCapacity){
    167 
    168             E[] newData = (E[])new Object[newCapacity];
    169             for(int i = 0 ; i < size ; i ++)
    170                 newData[i] = data[i];
    171             data = newData;
    172         }
    173     }
    174 
    175     private class MaxHeap<E extends Comparable<E>> {
    176 
    177         private Array<E> data;
    178 
    179         public MaxHeap(int capacity){
    180             data = new Array<>(capacity);
    181         }
    182 
    183         public MaxHeap(){
    184             data = new Array<>();
    185         }
    186 
    187         public MaxHeap(E[] arr){
    188             data = new Array<>(arr);
    189             for(int i = parent(arr.length - 1) ; i >= 0 ; i --)
    190                 siftDown(i);
    191         }
    192 
    193         // 返回堆中的元素个数
    194         public int size(){
    195             return data.getSize();
    196         }
    197 
    198         // 返回一个布尔值, 表示堆中是否为空
    199         public boolean isEmpty(){
    200             return data.isEmpty();
    201         }
    202 
    203         // 返回完全二叉树的数组表示中,一个索引所表示的元素的父亲节点的索引
    204         private int parent(int index){
    205             if(index == 0)
    206                 throw new IllegalArgumentException("index-0 doesn't have parent.");
    207             return (index - 1) / 2;
    208         }
    209 
    210         // 返回完全二叉树的数组表示中,一个索引所表示的元素的左孩子节点的索引
    211         private int leftChild(int index){
    212             return index * 2 + 1;
    213         }
    214 
    215         // 返回完全二叉树的数组表示中,一个索引所表示的元素的右孩子节点的索引
    216         private int rightChild(int index){
    217             return index * 2 + 2;
    218         }
    219 
    220         // 向堆中添加元素
    221         public void add(E e){
    222             data.addLast(e);
    223             siftUp(data.getSize() - 1);
    224         }
    225 
    226         private void siftUp(int k){
    227 
    228             while(k > 0 && data.get(parent(k)).compareTo(data.get(k)) < 0 ){
    229                 data.swap(k, parent(k));
    230                 k = parent(k);
    231             }
    232         }
    233 
    234         // 看堆中的最大元素
    235         public E findMax(){
    236             if(data.getSize() == 0)
    237                 throw new IllegalArgumentException("Can not findMax when heap is empty.");
    238             return data.get(0);
    239         }
    240 
    241         // 取出堆中最大元素
    242         public E extractMax(){
    243 
    244             E ret = findMax();
    245 
    246             data.swap(0, data.getSize() - 1);
    247             data.removeLast();
    248             siftDown(0);
    249 
    250             return ret;
    251         }
    252 
    253         private void siftDown(int k){
    254 
    255             while(leftChild(k) < data.getSize()){
    256                 int j = leftChild(k); // 在此轮循环中,data[k]和data[j]交换位置
    257                 if( j + 1 < data.getSize() &&
    258                         data.get(j + 1).compareTo(data.get(j)) > 0 )
    259                     j ++;
    260                 // data[j] 是 leftChild 和 rightChild 中的最大值
    261 
    262                 if(data.get(k).compareTo(data.get(j)) >= 0 )
    263                     break;
    264 
    265                 data.swap(k, j);
    266                 k = j;
    267             }
    268         }
    269 
    270         // 取出堆中的最大元素,并且替换成元素e
    271         public E replace(E e){
    272 
    273             E ret = findMax();
    274             data.set(0, e);
    275             siftDown(0);
    276             return ret;
    277         }
    278     }
    279 
    280     private interface Queue<E> {
    281 
    282         int getSize();
    283         boolean isEmpty();
    284         void enqueue(E e);
    285         E dequeue();
    286         E getFront();
    287     }
    288 
    289     private class PriorityQueue<E extends Comparable<E>> implements Queue<E> {
    290 
    291         private MaxHeap<E> maxHeap;
    292 
    293         public PriorityQueue(){
    294             maxHeap = new MaxHeap<>();
    295         }
    296 
    297         @Override
    298         public int getSize(){
    299             return maxHeap.size();
    300         }
    301 
    302         @Override
    303         public boolean isEmpty(){
    304             return maxHeap.isEmpty();
    305         }
    306 
    307         @Override
    308         public E getFront(){
    309             return maxHeap.findMax();
    310         }
    311 
    312         @Override
    313         public void enqueue(E e){
    314             maxHeap.add(e);
    315         }
    316 
    317         @Override
    318         public E dequeue(){
    319             return maxHeap.extractMax();
    320         }
    321     }
    322 
    323     private class Freq implements Comparable<Freq>{
    324 
    325         public int e, freq;
    326 
    327         public Freq(int e, int freq){
    328             this.e = e;
    329             this.freq = freq;
    330         }
    331 
    332         @Override
    333         public int compareTo(Freq another){
    334             if(this.freq < another.freq)
    335                 return 1;
    336             else if(this.freq > another.freq)
    337                 return -1;
    338             else
    339                 return 0;
    340         }
    341     }
    342 
    343     public List<Integer> topKFrequent(int[] nums, int k) {
    344 
    345         TreeMap<Integer, Integer> map = new TreeMap<>();
    346         for(int num: nums){
    347             if(map.containsKey(num))
    348                 map.put(num, map.get(num) + 1);
    349             else
    350                 map.put(num, 1);
    351         }
    352 
    353         PriorityQueue<Freq> pq = new PriorityQueue<>();
    354         for(int key: map.keySet()){
    355             if(pq.getSize() < k)
    356                 pq.enqueue(new Freq(key, map.get(key)));
    357             else if(map.get(key) > pq.getFront().freq){
    358                 pq.dequeue();
    359                 pq.enqueue(new Freq(key, map.get(key)));
    360             }
    361         }
    362 
    363         LinkedList<Integer> res = new LinkedList<>();
    364         while(!pq.isEmpty())
    365             res.add(pq.dequeue().e);
    366         return res;
    367     }
    368 
    369     private static void printList(List<Integer> nums){
    370         for(Integer num: nums)
    371             System.out.print(num + " ");
    372         System.out.println();
    373     }
    374 
    375     public static void main(String[] args) {
    376 
    377         int[] nums = {1, 1, 1, 2, 2, 3};
    378         int k = 2;
    379         printList((new Solution()).topKFrequent(nums, k));
    380     }
    381 }
    View Code

       在以上代码中我们需要关心的是如下部分

     1   private class Freq implements Comparable<Freq>{
     2 
     3         public int e, freq;
     4 
     5         public Freq(int e, int freq){
     6             this.e = e;
     7             this.freq = freq;
     8         }
     9 
    10         @Override
    11         public int compareTo(Freq another){
    12             if(this.freq < another.freq)
    13                 return 1;
    14             else if(this.freq > another.freq)
    15                 return -1;
    16             else
    17                 return 0;
    18         }
    19     }
    20 
    21     public List<Integer> topKFrequent(int[] nums, int k) {
    22 
    23         TreeMap<Integer, Integer> map = new TreeMap<>();
    24         for(int num: nums){
    25             if(map.containsKey(num))
    26                 map.put(num, map.get(num) + 1);
    27             else
    28                 map.put(num, 1);
    29         }
    30 
    31         PriorityQueue<Freq> pq = new PriorityQueue<>();
    32         for(int key: map.keySet()){
    33             if(pq.getSize() < k)
    34                 pq.enqueue(new Freq(key, map.get(key)));
    35             else if(map.get(key) > pq.getFront().freq){
    36                 pq.dequeue();
    37                 pq.enqueue(new Freq(key, map.get(key)));
    38             }
    39         }
    40 
    41         LinkedList<Integer> res = new LinkedList<>();
    42         while(!pq.isEmpty())
    43             res.add(pq.dequeue().e);
    44         return res;
    45     }

       我们将完整代码提交到leetcode

      得到如下结果,表示我们验证自己实现的优先队列成功了。

      这盛世,如您所愿。

      如有错误的地方还请留言指正。

      原创不易,转载请注明原文地址:https://www.cnblogs.com/hello-shf/p/11397386.html 

  • 相关阅读:
    腾讯视频插入网页的代码;
    FW: 软件持续交付的诉求;
    TOGAF
    Windows WSL2 htop打开黑屏的问题解决
    requests.exceptions.ConnectionError: HTTPSConnectionPool(host='appts.xxx.com%20', port=443):
    sqlalchemy实现模糊查询
    jenkins过滤版本,可选择版本
    QML 布局之一:锚布局详解(各种例子)
    Qt Quick 常用控件:Button(按钮)用法及自定义
    The Common Order Operations of Dis Operation System (DOSS)
  • 原文地址:https://www.cnblogs.com/hello-shf/p/11397386.html
Copyright © 2011-2022 走看看