zoukankan      html  css  js  c++  java
  • spark 之knn算法

    好长时间忙的没写博客了。看到有人问spark的knn,想着做推荐入门总用的knn算法,顺便写篇博客。

    作者:R星月  http://www.cnblogs.com/rxingyue/p/6182526.html

    knn算法的大致如下:
        1)算距离:给定测试对象,计算它与训练集中的每个对象的距离
        2)找邻居:圈定距离最近的k个训练对象,作为测试对象的近邻
        3)做分类:根据这k个近邻归属的主要类别,来对测试对象分类

    这次用spark实现knn算法。

    首先要加载数据:

    实验就简单点直接模拟:

    List<Node<Integer>> data = new ArrayList<Node<Integer>>();
            for (int i = 0; i < 100; i++) {
                data.add(new Node(String.valueOf(i), i));
            }

    JavaRDD<Node<Integer>> nodes = sc.parallelize(data);
     

    再设计距离的度量,做一个简单的实验如下:

    new SimilarityInterface<Integer>() {
    
                public double similarity(Integer value1, Integer value2) {
                    return 1.0 / (1.0 + Math.abs((Integer) value1 - (Integer) value2));
                }
            };

    距离度量为一个接口可以实现你自己想要的距离计算方法,如cos,欧几里德等等。

    再这要设置你要构建的关联图和设置搜索的近邻k值:

     NNDescent nndes = new NNDescent<Integer>();
            nndes.setK(30);
            nndes.setMaxIterations(4);
            nndes.setSimilarity(similarity);
            // 构建图
            JavaPairRDD<Node, NeighborList> graph = nndes.computeGraph(nodes);

    // 保存文件中
    graph.saveAsTextFile("out/out.txt");

    结果如下: 编号最近的30个值。

    以上就算把knn算法在spark下完成了,剩下要做的就是根据一个数据点进行搜索最相近的k个值。

    搜索:

    final Node<Integer> query = new Node(String.valueOf(111), 50);
    final NeighborList neighborlist_exhaustive
    = exhaustive_search.search(query, 5);

    这段代码是搜索 结点id为111,数值为50最近的5个值。

    结果如下:

    代码很简单:

    /**
     * Created by lsy 983068303@qq.com
     * on 2016/12/15.
     */
    public class TestKnn {
        public static void main(String[] args) throws Exception {
            SparkConf conf = new SparkConf();
            conf.setMaster("local[4]");
            conf.setAppName("knn");
    //        conf.set("spark.executor.memory","1G");
    //        conf.set("spark.storage.memoryFraction","1G");
            JavaSparkContext sc = new JavaSparkContext(conf);
    
            List<Node<Integer>> data = new ArrayList<Node<Integer>>();
            for (int i = 0; i < 100; i++) {
                data.add(new Node(String.valueOf(i), i));
            }
            final SimilarityInterface<Integer> similarity =new SimilarityInterface<Integer>() {
                public double similarity(Integer value1, Integer value2) {
                    return 1.0 / (1.0 + Math.abs((Integer) value1 - (Integer) value2));
                }
            };
            JavaRDD<Node<Integer>> nodes = sc.parallelize(data);
            NNDescent nndes = new NNDescent<Integer>();
            nndes.setK(30);
            nndes.setMaxIterations(4);
            nndes.setSimilarity(similarity);
            JavaPairRDD<Node, NeighborList> graph = nndes.computeGraph(nodes);
    
            graph.saveAsTextFile("out");
            ExhaustiveSearch exhaustive_search
                    = new ExhaustiveSearch(graph, similarity);
            graph.cache();
            final Node<Integer> query = new Node(String.valueOf(111), 50);
            final NeighborList neighborlist_exhaustive
                    = exhaustive_search.search(query, 5);
             for(Neighbor n:neighborlist_exhaustive){
                System.out.print("id编号:"+n.node.id+"==============") ;
                System.out.println("对应的数值:"+n.node.id) ;
             }
            sc.stop();
        }
  • 相关阅读:
    【三中校内训练】怎样更有力气
    【四校联考】立方体
    【四校联考】点
    第11章 卷积神经网络(CNNs)
    第10章神经网络基础
    在jupyter中配置python3
    第9章 优化方法和归一化
    第8章 参数化学习(parameterized learning)
    第7章 你的第一个分类器
    第6章 配置开发环境
  • 原文地址:https://www.cnblogs.com/rxingyue/p/6182526.html
Copyright © 2011-2022 走看看