zoukankan      html  css  js  c++  java
  • (转)Mahout Kmeans Clustering 学习

    一、Mahout命令使用

    合成控制的数据集 synthetic_control.data 可以从 此处下载,总共由600行X60列double型的数据组成, 意思是有600个元组,每个元组是一个时间序列。

    1. 把数据拷到集群上,放到kmeans/目录下

      
    hadoop fs -mv synthetic_control.data kmeans/synthetic_control.data
     

     

    2. 输入如下mahout命令进行KMeans聚类分析

     

    mahout org.apache.mahout.clustering.syntheticcontrol.kmeans.Job --input kmeans/synthetic_control.data  --numClusters 3 -t1 3 -t2 6 --maxIter 3 --output kmeans/output

     当命令中有这个--numClusters( 代表聚类结果中簇的个数)参数的话,它会采用Kmeans聚类。如果没有配置这个参数的话,它会先采用Canopy聚类,-t1和-t2是用于Canopy聚类的配置参数。

     

    二、源码学习

    从Mahout源码可以分析出:进行KMeans聚类时,会产生四个步骤。

    1. 数据预处理,整理规范化数据
    2. 从上述数据中随机选择若干个数据当作Cluster的中心
    3. 迭代计算,调整形心
    4. 把数据分给各个Cluster

    其中 前俩步就是 KMeans聚类算法的准备工作。

    主要流程可以从org.apache.mahout.clustering.syntheticcontrol.kmeans.Job#run()方法里看出一些端倪。

      

      public static void run(Configuration conf, Path input, Path output, DistanceMeasure measure, int k,
          double convergenceDelta, int maxIterations) throws Exception {
        //1. synthetic_control.data存储的文本格式,转换成Key/Value格式,存入到output/data目录。Key为保存一个Integer的Text类型, Value为VectorWritable类型。
        Path directoryContainingConvertedInput = new Path(output, DIRECTORY_CONTAINING_CONVERTED_INPUT);
        log.info("Preparing Input");
        InputDriver.runJob(input, directoryContainingConvertedInput, "org.apache.mahout.math.RandomAccessSparseVector");
        //2. 随机产生几个cluster,存入到output/clusters-0/part-randomSeed文件里。Key为Text, Value为ClusterWritable类型。
        log.info("Running random seed to get initial clusters");
        Path clusters = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
        clusters = RandomSeedGenerator.buildRandom(conf, directoryContainingConvertedInput, clusters, k, measure);
        //3. 进行聚类迭代运算,为每一个簇重新选出cluster centroid中心
        log.info("Running KMeans");
        KMeansDriver.run(conf, directoryContainingConvertedInput, clusters, output, measure, convergenceDelta,
            maxIterations, true, 0.0, false);
        //4. 根据上面选出的中心,把output/data里面的记录,都分配给各个cluster。输出运算结果,把sequencefile格式转化成textfile格式展示出来
        // run ClusterDumper
        ClusterDumper clusterDumper = new ClusterDumper(new Path(output, "clusters-*-final"), new Path(output,
            "clusteredPoints"));
        clusterDumper.printClusters(null);
      }
    1. RandomAccessSparseVector是一个Vector实现,里面有一个 OpenIntDoubleMap属性,该OpenIntDoubleMap不是继承自HashMap,而是自己实现了一套类似的hashMap,数据是通过一个Int数组和Long数组维护着,因此无法通过Iterator为遍历。
    2. RandomSeedGenerator#buildRandom()是在上面的Vector里面随机抽样k个序列簇Kluster,采用的是一种蓄水池抽样(Reservoir Sampling)的方法:即先把前k个数放入蓄水池,对第k+1,我们以k/(k+1)概率决定是否要把它换入蓄水池,最终每个数都是以相同的概率k/n进入蓄水池。它通过强大的MersenneTwister伪随机生成器来随机产生,它产生的随机数长度可达2^19937 - 1,维度可高达623维,同时数值还可以精确到32位的均匀分布。

    1. 迭代计算准备工作

    真正在做KMeans聚类的代码是:
     
      public static Path buildClusters(Configuration conf, Path input, Path clustersIn, Path output,
          DistanceMeasure measure, int maxIterations, String delta, boolean runSequential) throws IOException,
          InterruptedException, ClassNotFoundException {
        
        double convergenceDelta = Double.parseDouble(delta);
        //从output/clusters-0/part-randomSeed文件里读出Cluster数据,放入到clusters变量中。
        List<Cluster> clusters = Lists.newArrayList();
        KMeansUtil.configureWithClusterInfo(conf, clustersIn, clusters);
        
        if (clusters.isEmpty()) {
          throw new IllegalStateException("No input clusters found in " + clustersIn + ". Check your -c argument.");
        }
        //把聚类策略(控制收敛程度)写进output/clusters-0/_policy文件中
        //同时,每个簇cluster在output/clusters-0/下对应生成part-000xx文件
        Path priorClustersPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
        ClusteringPolicy policy = new KMeansClusteringPolicy(convergenceDelta);
        ClusterClassifier prior = new ClusterClassifier(clusters, policy);
        prior.writeToSeqFiles(priorClustersPath);
        //开始迭代maxIterations次执行Map/Reduce
        if (runSequential) {
          ClusterIterator.iterateSeq(conf, input, priorClustersPath, output, maxIterations);
        } else {
          ClusterIterator.iterateMR(conf, input, priorClustersPath, output, maxIterations);
        }
        return output;
      }
      

    2. 迭代计算

    调整cluster中心的Job的代码如下:

     
      public static void iterateMR(Configuration conf, Path inPath, Path priorPath, Path outPath, int numIterations)
        throws IOException, InterruptedException, ClassNotFoundException {
        ClusteringPolicy policy = ClusterClassifier.readPolicy(priorPath);
        Path clustersOut = null;
        int iteration = 1;
        while (iteration <= numIterations) {
          conf.set(PRIOR_PATH_KEY, priorPath.toString());
          
          String jobName = "Cluster Iterator running iteration " + iteration + " over priorPath: " + priorPath;
          Job job = new Job(conf, jobName);
          job.setMapOutputKeyClass(IntWritable.class);
          job.setMapOutputValueClass(ClusterWritable.class);
          job.setOutputKeyClass(IntWritable.class);
          job.setOutputValueClass(ClusterWritable.class);
          
          job.setInputFormatClass(SequenceFileInputFormat.class);
          job.setOutputFormatClass(SequenceFileOutputFormat.class);
          //核心算法就在这个CIMapper和CIReducer里面
          job.setMapperClass(CIMapper.class);
          job.setReducerClass(CIReducer.class);
          
          FileInputFormat.addInputPath(job, inPath);
          clustersOut = new Path(outPath, Cluster.CLUSTERS_DIR + iteration);
          priorPath = clustersOut;
          FileOutputFormat.setOutputPath(job, clustersOut);
          
          job.setJarByClass(ClusterIterator.class);
          if (!job.waitForCompletion(true)) {
            throw new InterruptedException("Cluster Iteration " + iteration + " failed processing " + priorPath);
          }
          ClusterClassifier.writePolicy(policy, clustersOut);
          FileSystem fs = FileSystem.get(outPath.toUri(), conf);
          iteration++;
          if (isConverged(clustersOut, conf, fs)) {
            break;
          }
        }
        //把最后一次迭代的结果目录重命名,加一个final
        Path finalClustersIn = new Path(outPath, Cluster.CLUSTERS_DIR + (iteration - 1) + Cluster.FINAL_ITERATION_SUFFIX);
        FileSystem.get(clustersOut.toUri(), conf).rename(clustersOut, finalClustersIn);
      }

      

    2.1. Map阶段

    CIMapper代码如下:

     

     
     @Override
      protected void map(WritableComparable<?> key, VectorWritable value, Context context) throws IOException,
          InterruptedException {
        Vector probabilities = classifier.classify(value.get());
        Vector selections = policy.select(probabilities);
        for (Iterator<Element> it = selections.iterateNonZero(); it.hasNext();) {
          Element el = it.next();
          classifier.train(el.index(), value.get(), el.get());
        }
      }
     

     

    在这里面需要厘清

    org.apache.mahout.clustering.iterator.KMeansClusteringPolicy

    org.apache.mahout.clustering.classify.ClusterClassifier

    这两个类。

    前者是聚类的策略,可以说它提供聚类的核心算法。

    后者是聚类的分类器,它的功能是基于聚类策略把数据进行分类。

    2.1.1. ClusterClassifier 求点到Cluster形心的距离

     ClusterClassifier.classify()求得某点到所有cluster中心的距离,得到的是一个数组。

     
    @Override
      public Vector classify(Vector data, ClusterClassifier prior) {
        List<Cluster> models = prior.getModels();
        int i = 0;
        Vector pdfs = new DenseVector(models.size());
        for (Cluster model : models) {
          pdfs.set(i++, model.pdf(new VectorWritable(data)));
        }
        return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
      }
     

    上述代码中的org.apache.mahout.clustering.iterator.DistanceMeasureCluster.pdf(VectorWritable)求该点到Cluster形心的距离,其算法代码如下:

    Java代码 复制代码 收藏代码 
    @Override
      public double pdf(VectorWritable vw) {
        return 1 / (1 + measure.distance(vw.get(), getCenter()));
      }
     
    每一次迭代后,就会重新计算一次centroid,通过AbstractCluster.computeParameters来计算的。
     

    pdfs.zSum()是pdfs double数组的和。然后再对pdfs进行归一化处理。

    因此最后select()用于选出相似度最大的cluster的下标,并且对其赋予权重1.0。如下所示:

     
    @Override
      public Vector select(Vector probabilities) {
        int maxValueIndex = probabilities.maxValueIndex();
        Vector weights = new SequentialAccessSparseVector(probabilities.size());
        weights.set(maxValueIndex, 1.0);
        return weights;
      }
     

    2.1.2. ClusterClassifier 为求Cluster新形心做准备

     接下来,为了重新得到新的中心,通过org.apache.mahout.clustering.classify.ClusterClassifier.train(int, Vector, double)为训练数据,即最后在AbstractCluster里面准备数据。

     
    public void observe(Vector x, double weight) {
        if (weight == 1.0) {
          observe(x);
        } else {
          setS0(getS0() + weight);
          Vector weightedX = x.times(weight);
          if (getS1() == null) {
            setS1(weightedX);
          } else {
            getS1().assign(weightedX, Functions.PLUS);
          }
          Vector x2 = x.times(x).times(weight);
          if (getS2() == null) {
            setS2(x2);
          } else {
            getS2().assign(x2, Functions.PLUS);
          }
        }
      }

    2.2. Reduce阶段

    在CIReducer里面,对属于同一个Cluster里面的数据进行合并,并且求出centroid形心。

     
    @Override
      protected void reduce(IntWritable key, Iterable<ClusterWritable> values, Context context) throws IOException,
          InterruptedException {
        Iterator<ClusterWritable> iter = values.iterator();
        Cluster first = iter.next().getValue(); // there must always be at least one
        while (iter.hasNext()) {
          Cluster cluster = iter.next().getValue();
          first.observe(cluster);
        }
        List<Cluster> models = Lists.newArrayList();
        models.add(first);
        classifier = new ClusterClassifier(models, policy);
        classifier.close();
        context.write(key, new ClusterWritable(first));
      }

    2.2.1. Reduce中求centroid形心的算法

     求centroid算法代码如下:

    @Override
      public void computeParameters() {
        if (getS0() == 0) {
          return;
        }
        setNumObservations((long) getS0());
        setTotalObservations(getTotalObservations() + getNumObservations());
        setCenter(getS1().divide(getS0()));
        // compute the component stds
        if (getS0() > 1) {
          setRadius(getS2().times(getS0()).minus(getS1().times(getS1())).assign(new SquareRootFunction()).divide(getS0()));
        }
        setS0(0);
        setS1(center.like());
        setS2(center.like());
      }

    3. 聚类数据

    真正对output/data记录分配给各个簇的代码是:

     
     private static void classifyClusterMR(Configuration conf, Path input, Path clustersIn, Path output,
          Double clusterClassificationThreshold, boolean emitMostLikely) throws IOException, InterruptedException,
          ClassNotFoundException {
        
        conf.setFloat(ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD,
                      clusterClassificationThreshold.floatValue());
        conf.setBoolean(ClusterClassificationConfigKeys.EMIT_MOST_LIKELY, emitMostLikely);
        conf.set(ClusterClassificationConfigKeys.CLUSTERS_IN, clustersIn.toUri().toString());
        
        Job job = new Job(conf, "Cluster Classification Driver running over input: " + input);
        job.setJarByClass(ClusterClassificationDriver.class);
        
        job.setInputFormatClass(SequenceFileInputFormat.class);
        job.setOutputFormatClass(SequenceFileOutputFormat.class);
        //进行记录分配
        job.setMapperClass(ClusterClassificationMapper.class);
        job.setNumReduceTasks(0);
        
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(WeightedVectorWritable.class);
        
        FileInputFormat.addInputPath(job, input);
        FileOutputFormat.setOutputPath(job, output);
        if (!job.waitForCompletion(true)) {
          throw new InterruptedException("Cluster Classification Driver Job failed processing " + input);
        }
      }
      

     摘录地址:http://zcdeng.iteye.com/blog/1859711

  • 相关阅读:
    JAVA共通関数入力パラメータをダブルクォートで囲む
    JAVA共通関数 システム時刻を取得する(2)
    JAVA共通関数ーー日付変換(YYYYMMDD → YYYY年MM月DD日)を行う
    JAVA共通関数文字変換 & " をHTML用に変換する(改行はタグで置き換え)
    JAVA共通関数日付変換(YYYYMMDD → YYYY/MM/DD)を行う
    JAVA共通関数--カンマ削除(数値からカンマを取り除く)を行う
    JAVASCRIPT共通関数デジタル時計を表示する
    JAVASCRIPT共通関数フォームの飛び先変更
    JAVA共通関数ーー数値フォーマット(数値をカンマ付きに編集)を行う
    VSCode linux下正确支持查找引用 "Find all reference"
  • 原文地址:https://www.cnblogs.com/anny-1980/p/3673434.html
Copyright © 2011-2022 走看看