zoukankan      html  css  js  c++  java
  • Programming Assignment 5: Kd-Trees

    用2d-tree数据结构实现在2维矩形区域内的高效的range search 和 nearest neighbor search。2d-tree有许多的应用,在天体分类、计算机动画、神经网络加速、数据挖掘、图像检索。

    range search: 返回所有在query rectangle里的所有点

    nearest neighbor search: 返回query point的最近点

    下图显示这两种search操作

    从

    Geometric Primitives. 在assignment给定了几何图元应该如何表示,如下图

    从

    其中关于Point和Rectangle的表示已经定义在了Point2D.java和RectHV.java中,API都已经提供了,这个都不用自己实现。

    Point2D的API主要是点的坐标、平方距离、欧几里得距离、点的比较、绘制等,

    RectHV主要是一个2维的包围盒,记录矩形的左下角和右上角点的信息,主要API是contains(Point2D)判断点是否在矩形内,intersects(RectHV)是否与另一个矩形相交,以及矩形到点的平方距离和距离,绘制等。源码都可以找到来进行分析。

    下面就是要完成的两个任务:Brute-force 实现 和 2d-tree 实现。

    需要实现的API是一样的,这里以PointSET为例子,2d-tree也一样:

    public class PointSET {
           public PointSET()                               // construct an empty set of points
           public boolean isEmpty()                        // is the set empty?
           public int size()                               // number of points in the set
           public void insert(Point2D p)                   // add the point p to the set (if it is not already in the set)
           public boolean contains(Point2D p)              // does the set contain the point p?
           public void draw()                              // draw all of the points to standard draw
           public Iterable<Point2D> range(RectHV rect)     // all points in the set that are inside the rectangle
           public Point2D nearest(Point2D p)               // a nearest neighbor in the set to p; null if set is empty
    }

     

    Brute-force:暴力的实现需要insert()和contains()是在O(logn)的复杂度,nearest()和range()是O(N)的复杂度。

    这里用algs4.jar的SET来实现,代码很简单。

    range(RectHV): 遍历SET中所有Point与当前RectHV进行包含关系判断

    nearest(Point2D):遍历SET中的所有Point与当前Point进行距离判断,不断更新最小距离和最小距离的点,在进行距离判断的时候,用平方距离,开方会影响计算速度。

    public class PointSET {
        private SET<Point2D> set;
        
        // construct an empty set of points
        public PointSET() {                               
            set = new SET<Point2D>();
        }
        
        // is the set empty?
        public boolean isEmpty() {
            return set.isEmpty();
        }                      
       
        // number of points in the set
        public int size() {
            return set.size();
        }
        
        // add the point p to the set (if it is not already in the set)
        public void insert(Point2D p) {
            set.add(p);
        }
        
        // does the set contain the point p?
        public boolean contains(Point2D p) {
            return set.contains(p);
        }
        
        // draw all of the points to standard draw
        public void draw() {
            for (Point2D p : set) {
                StdDraw.point(p.x(), p.y());
            }
        }                              
        
        // all points in the set that are inside the rectangle
        public Iterable<Point2D> range(RectHV rect) {
            Queue<Point2D> q = new Queue<Point2D>();
            for (Point2D p : set) {
                if (rect.contains(p))
                    q.enqueue(p);
            }
            return q;
        }
        
        // a nearest neighbor in the set to p; null if set is empty
        public Point2D nearest(Point2D p) {
            double mindis = Double.MAX_VALUE;
            Point2D ret = null;
            for (Point2D s : set) {
                double dis = s.distanceSquaredTo(p);
                if (dis < mindis) {
                    mindis = dis;
                    ret = s;
                }
            }
            return ret;
        }
    }

    2d-tree:这里是使用BST为结构对节点进行组织,每个节点记录下面的相关属性,这个在Possible Progress Step中有提示。通过assignment中的描述和图可以对它有很清晰的认识。

    从

    Node节点定义如下:

    p记录当前点,rect记录当前点的“包围盒”(轴平行矩阵),lb记录左边或者下边的区域节点,rt记录右边或者上边的区域节点。

    private static class Node {
        private Point2D p; // the point
        // the axis-aligned rectangle corresponding to this node
        // the max rectangle include this node, aabb
        private RectHV rect; 
        private Node lb; // the left/bottom subtree
        private Node rt; // the right/top subtree
        public Node(Point2D p, RectHV rect) {
            this.p = p;
            this.rect = rect;
            lb = null;
            rt = null;
        }
    }

    2d-Tree的具体实现只要参考BST的写法就很好实现,insert的时候原本写的是new RectHV,不断进行递归进行构造,但是new的太多,fail test了。后面在insert中直接把RectHV的4个坐标作为参数在Insert中进行递归。

    还有一个比较重要的问题是,在insert,get,draw中,要把方向orientation作为参数,用来标示当前应该是左右分还是上下分,draw,insert和get都参照BST的写法,递归实现是十分简洁的。

    range()和nearest()都采用BFS广度搜索的方法,遍历这个2d-tree,进行相交和包含的判断,维护有效的节点信息。nearest()也记得使用平方距离,开方影响运行时间。

    代码实现如下:

    public class KdTree {
        
        private Node root;
        private int N;
        private static class Node {
            private Point2D p; // the point
            // the axis-aligned rectangle corresponding to this node
            // the max rectangle include this node, aabb
            private RectHV rect; 
            private Node lb; // the left/bottom subtree
            private Node rt; // the right/top subtree
            public Node(Point2D p, RectHV rect) {
                this.p = p;
                this.rect = rect;
                lb = null;
                rt = null;
            }
        }
        
        private final RectHV CANVAS = new RectHV(0, 0, 1, 1);
        
        // construct an empty set of points
        public KdTree() {
            root = null;
            N = 0;
        }
        
        // is the set empty?
        public boolean isEmpty() {
            return N == 0;
        }                      
       
        // number of points in the set
        public int size() {
            return N;
        }
        
        /**************************************
          * less
          * compare two Point2D with orientation
          *************************************/
        private int compareTo(Point2D v, Point2D w, int ori) {
            if (v.equals(w)) return 0; // same point
            else {
                if (ori == 0) { 
                    // vertical line
                    if (v.x() < w.x()) return -1;
                    else return 1;
                } else {
                    // horizontal line
                    if (v.y() < w.y()) return -1;
                    else return 1;
                }
            }
        }
        
        /***********************************************
         * Insert
         **********************************************/
        
        private Node insert(Node x, Point2D p, 
                            double xmin, double ymin, double xmax, double ymax, 
                            int ori) {
            if (x == null) {
                N++;
                return new Node(p, new RectHV(xmin, ymin, xmax, ymax));
            }
            int cmp = compareTo(p, x.p, ori);
            double x0 = xmin, y0 = ymin, x1 = xmax, y1 = ymax;
            if (cmp < 0) {
                if (ori == 0) x1 = x.p.x();
                else y1 = x.p.y();
                x.lb = insert(x.lb, p, x0, y0, x1, y1, 1-ori);
            }
            else if (cmp > 0) {
                if (ori == 0) x0 = x.p.x();
                else y0 = x.p.y();
                x.rt = insert(x.rt, p, x0, y0, x1, y1, 1-ori);
            }
            return x;
        }
        
        // add the point p to the set (if it is not already in the set)
        public void insert(Point2D p) {
            // 0 for vertical, 1 for horizontal
            root = insert(root, p, 
                          CANVAS.xmin(), CANVAS.ymin(),
                          CANVAS.xmax(), CANVAS.ymax(), 0);
        }
        
        /*******************************************
          * contains
          *****************************************/
        private boolean get(Node x, Point2D p, int ori) {
            if (x == null) return false;
            int cmp = compareTo(p, x.p, ori);
            if (cmp < 0) return get(x.lb, p, 1-ori);
            else if (cmp > 0) return get(x.rt, p, 1-ori);
            return true;
        }
        
        // does the set contain the point p?
        public boolean contains(Point2D p) {
            // 0 for vertical, 1 for horizontal
            return get(root, p, 0);
        }
        
        /***************************************
          * Draw()
          *************************************/
        private void draw(Node x, int ori) {
            if (x == null) return;
            // draw point
            StdDraw.setPenColor(StdDraw.BLACK);
            StdDraw.setPenRadius(.01);
            StdDraw.point(x.p.x(), x.p.y());
            // draw line
            if (ori == 0) {
                // vertical
                 StdDraw.setPenColor(StdDraw.RED);
                 StdDraw.setPenRadius();
                 StdDraw.line(x.p.x(), x.rect.ymin(), x.p.x(), x.rect.ymax());
            } else {
                // horizontal
                StdDraw.setPenColor(StdDraw.BLUE);
                StdDraw.setPenRadius();
                StdDraw.line(x.rect.xmin(), x.p.y(), x.rect.xmax(), x.p.y());
            }
            draw(x.lb, 1-ori);
            draw(x.rt, 1-ori);
        }
        
        // draw all of the points to standard draw
        public void draw() {
            StdDraw.setScale(0, 1);  
            StdDraw.setPenColor(StdDraw.BLACK);
            StdDraw.setPenRadius();
            CANVAS.draw();
            draw(root, 0);
        }                              
        
        // all points in the set that are inside the rectangle
        public Iterable<Point2D> range(RectHV rect) {
            Queue<Point2D> points = new Queue<Point2D>();
            Queue<Node> queue = new Queue<Node>();
            if (root == null) return points;
            queue.enqueue(root);
            while (!queue.isEmpty()) {
                Node x = queue.dequeue();
                if (x == null) continue;
                if (rect.contains(x.p)) points.enqueue(x.p);
                if (x.lb != null && rect.intersects(x.lb.rect)) queue.enqueue(x.lb);
                if (x.rt != null && rect.intersects(x.rt.rect)) queue.enqueue(x.rt);
            }
            return points;
        }
        
        // a nearest neighbor in the set to p; null if set is empty
        public Point2D nearest(Point2D p) {
            if (root == null) return null;
            Point2D retp = null;
            double mindis = Double.MAX_VALUE;
            Queue<Node> queue = new Queue<Node>();
            queue.enqueue(root);
            while (!queue.isEmpty()) {
                Node x = queue.dequeue();
                double dis = p.distanceSquaredTo(x.p);
                if (dis < mindis) {
                    retp = x.p;
                    mindis = dis; 
                }
                if (x.lb != null && x.lb.rect.distanceSquaredTo(p) < mindis) 
                    queue.enqueue(x.lb);
                if (x.rt != null && x.rt.rect.distanceSquaredTo(p) < mindis) 
                    queue.enqueue(x.rt);
            }
            return retp;
        }
       
    }

    总结:Last words, 这应该是第一门坚持上完的公开课吧,原来Andrew Ng的ML上了一半后,由于事情太多就把课给荒废了(现在又重新开始新一轮了,fighting!吐舌笑脸)。

    可能这几个Assignment写的都不咋地,但记录回顾一下,还是觉得很有收获。特别感谢Prof.Sedgewick和Coursera平台,给予了一段精彩的旅程。后面的Part II到时候继续跟上。

    不得不感叹,国外的MOOC平台做的相当的完美,提供了这么多好的资源,国内估计也有类似的吧,没去具体了解过。一定程度上真是把大学搬进了家里,不过感觉仅凭MOOC上几周课程来对领域或者部分的知识,作为一个较为(较为深入?)了解比较恰当,如果要熟练运用和掌握,还需要很长的路要走,Study hungry! Study foolish!

  • 相关阅读:
    13---Net基础加强
    12---Net基础加强
    11---Net基础加强
    10---Net基础加强
    09---Net基础加强
    08---Net基础加强
    07---Net基础加强
    06---Net基础加强
    05---Net基础加强
    04---Net基础加强
  • 原文地址:https://www.cnblogs.com/tiny656/p/3873510.html
Copyright © 2011-2022 走看看