  • MapReduce programming: auto complete

    1 The n-gram model and auto complete

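      The n-gram model assumes that the probability of a word appearing in a text depends only on the N-1 words before it. Auto complete takes the words the user has typed and displays the phrases most likely to follow them, so we can use an n-gram model to predict what the user will type next.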
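To make the assumption concrete for N = 2, the sketch below estimates P(next | prev) by maximum likelihood: the count of the pair "prev next" divided by the number of times prev is followed by any word. The class and method names (`BigramEstimate`, `probability`) are hypothetical, and tokenisation is a plain whitespace split. It illustrates only the counting idea, not the MapReduce pipeline.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper: maximum-likelihood bigram estimate
// P(next | prev) = count(prev next) / count(prev followed by any word).
class BigramEstimate {

    static double probability(String[] words, String prev, String next) {
        Map<String, Integer> prefixCounts = new HashMap<>();
        Map<String, Integer> pairCounts = new HashMap<>();
        // Count every word that has a successor, and every adjacent pair.
        for (int i = 0; i + 1 < words.length; i++) {
            prefixCounts.merge(words[i], 1, Integer::sum);
            pairCounts.merge(words[i] + " " + words[i + 1], 1, Integer::sum);
        }
        int prefix = prefixCounts.getOrDefault(prev, 0);
        if (prefix == 0) {
            return 0.0; // prev was never followed by anything
        }
        return (double) pairCounts.getOrDefault(prev + " " + next, 0) / prefix;
    }

    public static void main(String[] args) {
        String[] words = "i love big data i love big apple".split("\\s+");
        System.out.println(probability(words, "big", "data")); // 0.5
        System.out.println(probability(words, "love", "big")); // 1.0
    }
}
```

In this toy corpus "big" is followed once by "data" and once by "apple", so each follower gets probability 0.5.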

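      Our implementation has two parts. First, MapReduce builds an n-gram model offline from the corpus and stores it in a database. Then, as the user types, the front end queries the database for the words most likely to follow the input, and a PHP script refreshes them on the interface in real time.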

    2 MapReduce workflow

      

    2.1 MR1

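      The mapper reads the corpus one sentence at a time, splits each sentence into 2-grams through N-grams (1-grams are not needed here), and sends them to the reducer.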

      The reducer counts how many times each N-gram occurs. (This is simply a wordcount.)
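The two MR1 steps can be simulated locally without Hadoop. In this sketch (hypothetical class `NGramCountSketch`), the inner loop plays the mapper by emitting every 2..`noGram`-gram of a sentence, and `Map.merge` plays the reducer by summing the emitted 1s per n-gram. The normalisation roughly mirrors the mapper in section 3.

```java
import java.util.Map;
import java.util.TreeMap;

// Local, Hadoop-free simulation of MR1 (hypothetical class name).
class NGramCountSketch {

    static Map<String, Integer> countNGrams(String[] sentences, int noGram) {
        Map<String, Integer> counts = new TreeMap<>(); // sorted for readable output
        for (String sentence : sentences) {
            // Lower-case, keep letters only, trim before splitting.
            String line = sentence.toLowerCase().replaceAll("[^a-z]", " ").trim();
            if (line.isEmpty()) {
                continue;
            }
            String[] words = line.split("\\s+");
            for (int i = 0; i < words.length - 1; i++) {
                StringBuilder sb = new StringBuilder(words[i]);
                for (int j = 1; i + j < words.length && j < noGram; j++) {
                    sb.append(' ').append(words[i + j]);
                    // "map" emits (ngram, 1); "reduce" sums the 1s per key.
                    counts.merge(sb.toString(), 1, Integer::sum);
                }
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        String[] corpus = {"I love big data", "I love big apples"};
        countNGrams(corpus, 3).forEach((k, v) -> System.out.println(k + "\t" + v));
    }
}
```

With these two sentences and noGram = 3, the sketch yields seven distinct n-grams. "i love", "i love big" and "love big" are each counted twice.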

    2.2 MR2

      The mapper reads the N-grams and counts produced by MR1, splits off the last word, and sends it to the reducer keyed by the preceding N-1 words.
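The MR2 mapper's line handling can be checked in isolation. The sketch below is a hypothetical helper (`SplitLastWord.toKeyValue`). It takes one MR1 output line of the form "phrase<TAB>count" and returns the key (all but the last word) and the value ("lastWord=count"). It returns null for malformed lines, single words, and n-grams below a threshold, as the LanguageModel mapper does.

```java
import java.util.Arrays;

// Hypothetical helper mirroring the MR2 mapper's parsing step.
class SplitLastWord {

    static String[] toKeyValue(String line, int threshold) {
        String[] wordsPlusCount = line.trim().split("\t");
        if (wordsPlusCount.length < 2) {
            return null; // not "phrase<TAB>count"
        }
        String[] words = wordsPlusCount[0].trim().split("\\s+");
        int count = Integer.parseInt(wordsPlusCount[1].trim());
        if (words.length < 2 || count < threshold) {
            return null; // need at least two words; drop rare n-grams
        }
        // Key: the first N-1 words; value: "lastWord=count".
        String key = String.join(" ", Arrays.copyOfRange(words, 0, words.length - 1));
        String value = words[words.length - 1] + "=" + count;
        return new String[] {key, value};
    }

    public static void main(String[] args) {
        String[] kv = toKeyValue("this is cool\t20", 20);
        System.out.println(kv[0] + " -> " + kv[1]); // this is -> cool=20
    }
}
```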

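      The reducer receives the N-gram probability model. For a phrase made of the first N-1 words, it sees every possible last word together with its probability. We do not actually need to compute probabilities, because raw counts give the same result. Auto complete only cares about the relative order of the probabilities, not their values, and all followers of one phrase share the same denominator. We store only the top k most frequent words in the database, which can be done with a TreeMap or a PriorityQueue.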

        (Note: starting_phrase here has 1 to n-1 words, while following_word is always exactly one word. That is what the N-gram probability model means.)
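The reducer in section 3 uses a TreeMap. Below is a sketch of the PriorityQueue alternative mentioned above, with hypothetical names (`TopKFollowers`, `topK`). It keeps a min-heap of at most k (word, count) entries, so the least frequent of the current top k is evicted first, costing O(m log k) for m followers. The alphabetical tie-break is an added assumption, used only to make the output deterministic.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

// Hypothetical top-k selection over the followers of one starting phrase.
class TopKFollowers {

    static List<String> topK(Map<String, Integer> followerCounts, int k) {
        // Min-heap by count; on equal counts the alphabetically later word is "smaller".
        PriorityQueue<Map.Entry<String, Integer>> heap = new PriorityQueue<>(
                (a, b) -> a.getValue().equals(b.getValue())
                        ? b.getKey().compareTo(a.getKey())
                        : Integer.compare(a.getValue(), b.getValue()));
        for (Map.Entry<String, Integer> e : followerCounts.entrySet()) {
            heap.offer(e);
            if (heap.size() > k) {
                heap.poll(); // evict the currently least frequent follower
            }
        }
        List<String> result = new ArrayList<>();
        while (!heap.isEmpty()) {
            result.add(heap.poll().getKey()); // ascending order
        }
        Collections.reverse(result); // most frequent first
        return result;
    }

    public static void main(String[] args) {
        // Followers of "this is" with their MR1 counts.
        Map<String, Integer> followers = Map.of("girl", 50, "boy", 60, "bird", 50, "cat", 10);
        System.out.println(topK(followers, 2)); // [boy, bird]
    }
}
```

A heap needs only O(k) memory per phrase. A TreeMap keyed by count keeps every distinct count, but it naturally groups ties.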

    2.3 How to predict the next n words

      In the database, the n-gram model is stored as rows of (starting_phrase, following_word, count).

      As described above, the n-gram model by itself predicts only the next single word. To make the suggestions more varied, how can we predict the next n words?

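      This can be done in SQL by selecting every row whose starting_phrase matches "input%". Starting phrases of 1 to n-1 words are all stored, so longer phrases that begin with the input also match. Concatenating starting_phrase and following_word then yields completions several words long. For example, the input "i love" matches the row ("i love big", "data"), which gives "i love big data".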
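To show what the prefix query returns, the sketch below runs it in memory instead of against MySQL. The row type, method names and sample rows are hypothetical. It filters rows whose starting_phrase begins with the input (the LIKE 'input%' condition), orders them by count descending, keeps k, and turns each row into a suggestion.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

// In-memory stand-in (hypothetical names) for:
//   SELECT starting_phrase, following_word FROM output
//   WHERE starting_phrase LIKE 'input%' ORDER BY count DESC LIMIT k;
class PrefixSuggest {

    static final class Row {
        final String startingPhrase;
        final String followingWord;
        final int count;

        Row(String startingPhrase, String followingWord, int count) {
            this.startingPhrase = startingPhrase;
            this.followingWord = followingWord;
            this.count = count;
        }
    }

    static List<String> suggest(List<Row> table, String input, int k) {
        return table.stream()
                .filter(r -> r.startingPhrase.startsWith(input)) // LIKE 'input%'
                .sorted(Comparator.comparingInt((Row r) -> r.count).reversed())
                .limit(k)
                .map(r -> r.startingPhrase + " " + r.followingWord)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Row> table = new ArrayList<>();
        table.add(new Row("i love", "big", 40));
        table.add(new Row("i love big", "data", 30));
        table.add(new Row("i love big data", "analytics", 20));
        table.add(new Row("you love", "it", 50));
        System.out.println(suggest(table, "i love", 3));
        // [i love big, i love big data, i love big data analytics]
    }
}
```

A plain prefix match also accepts rows such as "i lovely". If only whole words should match, one option is `starting_phrase = 'input' OR starting_phrase LIKE 'input %'`.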

    3 Code

     NGramLibraryBuilder.java

    import java.io.IOException;

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.io.IntWritable;
    import org.apache.hadoop.io.LongWritable;
    import org.apache.hadoop.io.Text;
    import org.apache.hadoop.mapreduce.Mapper;
    import org.apache.hadoop.mapreduce.Reducer;

    public class NGramLibraryBuilder {

        // MR1 mapper: emits every 2..noGram-gram of a sentence with count 1.
        public static class NGramMapper extends Mapper<LongWritable, Text, Text, IntWritable> {

            int noGram;

            @Override
            public void setup(Context context) {
                Configuration conf = context.getConfiguration();
                noGram = conf.getInt("noGram", 5); // maximum n-gram length, default 5
            }

            @Override
            public void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
                // Each value is one sentence when the record delimiter is set to "." (see Driver).
                String line = value.toString().toLowerCase();
                // Keep letters only, then trim so a leading non-letter does not produce an empty first token.
                line = line.replaceAll("[^a-z]", " ").trim();

                String[] words = line.split("\\s+"); // split on any run of spaces, tabs, etc.

                if (words.length < 2) {
                    return; // no n-gram of length >= 2
                }

                // "i love big data" -> "i love", "i love big", ..., "big data"
                for (int i = 0; i < words.length - 1; i++) {
                    StringBuilder sb = new StringBuilder(words[i]);
                    for (int j = 1; i + j < words.length && j < noGram; j++) {
                        sb.append(" ").append(words[i + j]);
                        context.write(new Text(sb.toString()), new IntWritable(1));
                    }
                }
            }
        }

        // MR1 reducer: sums the 1s per n-gram (a plain wordcount).
        public static class NGramReducer extends Reducer<Text, IntWritable, Text, IntWritable> {
            @Override
            public void reduce(Text key, Iterable<IntWritable> values, Context context)
                    throws IOException, InterruptedException {
                int sum = 0;
                for (IntWritable value : values) {
                    sum += value.get();
                }
                context.write(key, new IntWritable(sum));
            }
        }
    }

    LanguageModel.java

    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;
    import java.util.TreeMap;

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.io.LongWritable;
    import org.apache.hadoop.io.NullWritable;
    import org.apache.hadoop.io.Text;
    import org.apache.hadoop.mapreduce.Mapper;
    import org.apache.hadoop.mapreduce.Reducer;

    public class LanguageModel {

        // MR2 mapper: "this is cool\t20" -> key "this is", value "cool=20"
        public static class Map extends Mapper<LongWritable, Text, Text, Text> {

            int threshold;

            @Override
            public void setup(Context context) {
                Configuration conf = context.getConfiguration();
                // The configuration key keeps the spelling "threashold" used by the driver.
                threshold = conf.getInt("threashold", 20);
            }

            @Override
            public void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
                String line = value.toString().trim();
                if (line.isEmpty()) {
                    return;
                }
                // MR1 output line, e.g. "this is cool\t20"
                String[] wordsPlusCount = line.split("\t");
                if (wordsPlusCount.length < 2) {
                    return;
                }

                String[] words = wordsPlusCount[0].trim().split("\\s+");
                int count = Integer.parseInt(wordsPlusCount[1].trim());

                // Drop rare n-grams and anything shorter than two words.
                if (count < threshold || words.length < 2) {
                    return;
                }

                // "this is" --> "cool=20"
                StringBuilder sb = new StringBuilder();
                for (int i = 0; i < words.length - 1; i++) {
                    sb.append(words[i]).append(" ");
                }
                String outputKey = sb.toString().trim();
                String outputValue = words[words.length - 1];
                context.write(new Text(outputKey), new Text(outputValue + "=" + count));
            }
        }

        // MR2 reducer: keeps the n most frequent following words per starting phrase.
        public static class Reduce extends Reducer<Text, Text, DBOutputWritable, NullWritable> {

            int n;

            @Override
            public void setup(Context context) {
                Configuration conf = context.getConfiguration();
                n = conf.getInt("n", 5); // number of following words to keep
            }

            @Override
            public void reduce(Text key, Iterable<Text> values, Context context) throws IOException, InterruptedException {
                // "this is" -> <girl=50, boy=60, bird=50>; group words by count, highest count first
                TreeMap<Integer, List<String>> tm = new TreeMap<Integer, List<String>>(Collections.reverseOrder());
                for (Text val : values) {
                    String[] wordAndCount = val.toString().trim().split("=");
                    String word = wordAndCount[0].trim();
                    int count = Integer.parseInt(wordAndCount[1].trim());
                    tm.computeIfAbsent(count, c -> new ArrayList<String>()).add(word);
                }
                // <60, [boy]> <50, [girl, bird]>: write at most n words in total
                int written = 0;
                for (int count : tm.keySet()) {
                    for (String curWord : tm.get(count)) {
                        if (written >= n) {
                            return;
                        }
                        context.write(new DBOutputWritable(key.toString(), curWord, count), NullWritable.get());
                        written++;
                    }
                }
            }
        }
    }

    DBOutputWritable.java

    import java.sql.PreparedStatement;
    import java.sql.ResultSet;
    import java.sql.SQLException;

    import org.apache.hadoop.mapreduce.lib.db.DBWritable;

    // One row of the output table: (starting_phrase, following_word, count)
    public class DBOutputWritable implements DBWritable {

        private String starting_phrase;
        private String following_word;
        private int count;

        public DBOutputWritable(String starting_phrase, String following_word, int count) {
            this.starting_phrase = starting_phrase;
            this.following_word = following_word;
            this.count = count;
        }

        // Reads a row back from a JDBC result set (columns in table order).
        public void readFields(ResultSet resultSet) throws SQLException {
            this.starting_phrase = resultSet.getString(1);
            this.following_word = resultSet.getString(2);
            this.count = resultSet.getInt(3);
        }

        // Binds the fields to the INSERT statement; the order must match
        // the column list passed to DBOutputFormat.setOutput in the driver.
        public void write(PreparedStatement statement) throws SQLException {
            statement.setString(1, starting_phrase);
            statement.setString(2, following_word);
            statement.setInt(3, count);
        }
    }

    Driver.java

    import java.io.IOException;

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.hadoop.io.IntWritable;
    import org.apache.hadoop.io.NullWritable;
    import org.apache.hadoop.io.Text;
    import org.apache.hadoop.mapreduce.Job;
    import org.apache.hadoop.mapreduce.lib.db.DBConfiguration;
    import org.apache.hadoop.mapreduce.lib.db.DBOutputFormat;
    import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
    import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;

    // Arguments: <corpus dir> <n-gram output dir> <noGram> <threshold> <n>
    public class Driver {

        public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {
            // Job 1: build the n-gram library (n-gram -> count)
            Configuration conf1 = new Configuration();
            conf1.set("textinputformat.record.delimiter", "."); // one record per sentence
            conf1.set("noGram", args[2]);

            // Pass conf1, otherwise the delimiter and noGram settings never reach the job
            Job job1 = Job.getInstance(conf1);
            job1.setJobName("NGram");
            job1.setJarByClass(Driver.class);

            job1.setMapperClass(NGramLibraryBuilder.NGramMapper.class);
            job1.setReducerClass(NGramLibraryBuilder.NGramReducer.class);

            job1.setOutputKeyClass(Text.class);
            job1.setOutputValueClass(IntWritable.class);

            job1.setInputFormatClass(TextInputFormat.class);
            job1.setOutputFormatClass(TextOutputFormat.class);

            TextInputFormat.setInputPaths(job1, new Path(args[0]));
            TextOutputFormat.setOutputPath(job1, new Path(args[1]));
            if (!job1.waitForCompletion(true)) {
                System.exit(1); // do not build the model from a failed first job
            }

            // Job 2: the jobs are chained by using job 1's output directory as job 2's input
            Configuration conf2 = new Configuration();
            conf2.set("threashold", args[3]); // key spelling must match LanguageModel.Map
            conf2.set("n", args[4]);

            DBConfiguration.configureDB(conf2,
                    "com.mysql.jdbc.Driver",
                    "jdbc:mysql://ip_address:port/test",
                    "root",
                    "password");

            Job job2 = Job.getInstance(conf2);
            job2.setJobName("Model");
            job2.setJarByClass(Driver.class);

            // Make the MySQL JDBC connector jar available to the tasks
            job2.addArchiveToClassPath(new Path("path_to_ur_connector"));
            job2.setMapOutputKeyClass(Text.class);
            job2.setMapOutputValueClass(Text.class);
            job2.setOutputKeyClass(DBOutputWritable.class);
            job2.setOutputValueClass(NullWritable.class);

            job2.setMapperClass(LanguageModel.Map.class);
            job2.setReducerClass(LanguageModel.Reduce.class);

            job2.setInputFormatClass(TextInputFormat.class);
            job2.setOutputFormatClass(DBOutputFormat.class);

            // Rows are inserted into table "output" as (starting_phrase, following_word, count)
            DBOutputFormat.setOutput(job2, "output",
                    new String[] {"starting_phrase", "following_word", "count"});

            TextInputFormat.setInputPaths(job2, args[1]);
            System.exit(job2.waitForCompletion(true) ? 0 : 1);
        }
    }
  • Original article: https://www.cnblogs.com/coldyan/p/6081978.html