  • Text Classification with Naive Bayes

    A previous article already covered the theory behind the Naive Bayes algorithm; here we use Naive Bayes to run a classification test on the newsgroup corpus.

    The code in this post is adapted from: http://blog.csdn.net/jiangliqing1234/article/details/39642757

    The main contents are as follows:

    1. The newsgroup dataset

    Download link: http://download.csdn.net/detail/hjy321686/8057761

      The corpus covers 20 different newsgroups; apart from a small number of documents that belong to more than one newsgroup, every document belongs to exactly one newsgroup.

    2. Preprocessing the newsgroup data

    Before the documents can be classified they must be preprocessed. The main steps are:

    step 1: English tokenization — strip digits, hyphens, punctuation and special characters, and convert all upper-case letters to lower case. A regular expression does the job: String res[] = line.split("[^a-zA-Z]"); (a short example follows this list)

    step 2: remove stop words, i.e. filter out words that carry no value for classification

    step 3: stemming, based on the Porter algorithm
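
    For example (the input line below is made up), Java's split() keeps internal empty strings while dropping trailing ones, which is why the code below checks isEmpty() before stemming:

    	String line = "Re: NASA's budget in 1993!";
    	String res[] = line.split("[^a-zA-Z]");
    	//res is ["Re", "", "NASA", "s", "budget", "in"]: trailing empty strings are dropped,
    	//internal ones remain, so empty entries must be skipped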

    The preprocessing class is as follows:

    package com.datamine.NaiveBayes;
    
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.FileWriter;
    import java.util.ArrayList;
    
    /**
     * Preprocessing of the newsgroup documents
     * step 1: English tokenization — strip digits, hyphens, punctuation and special characters, convert upper case to lower case; regex: String res[] = line.split("[^a-zA-Z]");
     * step 2: remove stop words, i.e. words that carry no value for classification
     * step 3: stemming, based on the Porter algorithm
     * @author Administrator
     *
     */
    public class DataPreProcess {
    
    	private static ArrayList<String> stopWordsArray = new ArrayList<String>();
    	
    	/**
    	 * Process all files under the given input path
    	 * @param srcDir absolute path of the source directory
    	 * @throws Exception
    	 */
    	public void doProcess(String srcDir) throws Exception{
    		
    		File fileDir = new File(srcDir);
    		if(!fileDir.exists()){
    			System.out.println("Directory does not exist!");
    			return ;
    		}
    		
    		String subStrDir = srcDir.substring(srcDir.lastIndexOf('/'));
    		String dirTarget = srcDir+"/../../processedSample"+subStrDir;
    		File fileTarget = new File(dirTarget);
    		
    		if(!fileTarget.exists()){
    			//note: the processedSample directory must be created beforehand, otherwise this fails because the parent directory does not exist
    			boolean mkdir = fileTarget.mkdir();
    		}
    		
    		File[] srcFiles = fileDir.listFiles();
    		
    		for(int i =0 ;i <srcFiles.length;i++){
    			
    			String fileFullName = srcFiles[i].getCanonicalPath(); //getCanonicalPath() not only returns the full path, it also resolves symbols such as .. and .

    			String fileShortName = srcFiles[i].getName(); //file name
    			if(!new File(fileFullName).isDirectory()){ //make sure the child is not a directory; if it is, recurse into it
    				System.out.println("Start preprocessing: "+fileFullName);
    				StringBuilder stringBuilder = new StringBuilder();
    				stringBuilder.append(dirTarget+"/"+fileShortName);
    				createProcessFile(fileFullName,stringBuilder.toString());
    			}else{
    				fileFullName = fileFullName.replace("\\", "/");
    				doProcess(fileFullName);
    			}
    		}
    	}
    	
    	/**
    	 * Preprocess one source file and write the result to the target file
    	 * @param srcDir absolute path of the source file
    	 * @param targetDir absolute path of the target file
    	 * @throws Exception
    	 */
    	private void createProcessFile(String srcDir, String targetDir) throws Exception {
    		
    		FileReader srcFileReader = new FileReader(srcDir);
    		FileWriter targetFileWriter = new FileWriter(targetDir);
    		BufferedReader srcFileBR = new BufferedReader(srcFileReader);
    		
    		String line,resLine;
    		while((line = srcFileBR.readLine()) != null){
    			resLine = lineProcess(line);
    			if(!resLine.isEmpty()){
    				//write line by line, one word per line
    				String[] tempStr = resLine.split(" ");
    				for(int i =0; i<tempStr.length ;i++){
    					if(!tempStr[i].isEmpty())
    						targetFileWriter.append(tempStr[i]+"\n");
    				}
    			}
    		}
    		targetFileWriter.flush();
    		targetFileWriter.close();
    		srcFileReader.close();
    		srcFileBR.close();
    	}
    	
    	/**
    	 * Process one line of text: tokenization, stop-word removal and stemming
    	 * @param line the line to process
    	 * @return String the processed line, rebuilt from the processed words, separated by spaces
    	 */
    	private String lineProcess(String line) {
    		
    		/*
    		 * step1
    		 * English tokenization: strip digits, hyphens, punctuation and special characters,
    		 * convert all upper-case characters to lower case (a regular expression does the job)
    		 */
    		String res[] = line.split("[^a-zA-Z]");
    		
    		//step2 remove stop words, convert upper case to lower case
    		//step3 Stemmer.run()
    		String resString = new String();
    		for(int i=0;i<res.length;i++){
    			if(!res[i].isEmpty() && !stopWordsArray.contains(res[i].toLowerCase()))
    				resString += " " + Stemmer.run(res[i].toLowerCase()) + " ";
    		}
    		return resString;
    	}
    	
    	/**
    	 * Load the stop words into the stopWordsArray ArrayList
    	 * @param stopwordsPath
    	 * @throws Exception
    	 */
    	private static void stopWordsToArray(String stopwordsPath) throws Exception {
    		
    		FileReader stopWordsReader = new FileReader(stopwordsPath);
    		BufferedReader stopWordsBR = new BufferedReader(stopWordsReader);
    		String stopWordsLine = null;
    		
    		//build the ArrayList of stop words
    		while((stopWordsLine = stopWordsBR.readLine()) != null){
    			if(!stopWordsLine.isEmpty())
    				stopWordsArray.add(stopWordsLine);
    		}
    		stopWordsReader.close();
    		stopWordsBR.close();
    	}
    	
    	public static void main(String[] args) throws Exception{
    		
    		DataPreProcess dataPrePro = new DataPreProcess();
    		String srcDir = "E:/DataMiningSample/orginSample";
    		String stopwordsPath = "E:/DataMiningSample/stopwords.txt";
    		
    		stopWordsToArray(stopwordsPath);
    		dataPrePro.doProcess(srcDir);
    	}
    }


    The Porter stemmer used in step 3 can be downloaded from the web; on top of it I added a run() method:

    	/**
    	 * Entry point into the Stemmer: reduce the given word to its stem
    	 * @param word the input word
    	 * @return result the stemmed word
    	 */
    	public static String run(String word){
    		
    		Stemmer s = new Stemmer();
    		
    		char[] ch = word.toCharArray();
    		
    		for (int c = 0; c < ch.length; c++)
    			s.add(ch[c]);
    		
    		s.stem();
    		
    		String result = s.toString();
    		//System.out.print(result);
    		return result;
    		
    	}
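
    A quick sanity check for run() (the expected stems in the comments are illustrative; the exact output depends on the Porter implementation you download):

    	public static void main(String[] args) {
    		System.out.println(Stemmer.run("running"));        //run
    		System.out.println(Stemmer.run("classification")); //classif
    		System.out.println(Stemmer.run("documents"));      //document
    	}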

    3. Feature selection

    Method 1: keep every word as a feature word

    Method 2: keep only words whose frequency exceeds some threshold (3, or another value) as feature words

    Method 3: compute a tf*idf weight for every word and select feature words by weight

    This post uses method 2; a rough sketch of method 3 follows for reference.
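
    Method 3 is not implemented in this post. Purely as an illustrative sketch (not the author's code; wordTfIdfMap, a hypothetical word-to-aggregated-tf*idf map, and the java.util imports used elsewhere in this post are assumed), threshold-based tf*idf selection could look like this:

    	/**
    	 * Illustrative sketch only: keep the words whose aggregated tf*idf weight exceeds a threshold.
    	 * wordTfIdfMap is a hypothetical map from word to its aggregated tf*idf weight.
    	 */
    	public SortedMap<String, Double> selectByTfIdf(Map<String, Double> wordTfIdfMap, double threshold) {
    		SortedMap<String, Double> featureMap = new TreeMap<String, Double>();
    		for (Map.Entry<String, Double> me : wordTfIdfMap.entrySet()) {
    			if (me.getValue() > threshold)
    				featureMap.put(me.getKey(), me.getValue());
    		}
    		return featureMap;
    	}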

    4. Vectorizing the documents

    Because feature selection in this post uses method 2, the documents do not strictly need to be vectorized. However, the method that counts feature-word occurrences lives in the ComputeWordsVector class, so the vectorization code is listed here anyway for the program to run; the same class is also needed later for the KNN algorithm.
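
    As a quick worked example of the weight computed below (the numbers are made up except for the corpus size): if a preprocessed document contains 200 feature words and the word "space" occurs 10 times in it, then tf = 10/200 = 0.05; if "space" appears in 1,000 of the 18,828 documents, then idf = ln(18828/1000) ≈ 2.93, giving a weight of roughly 0.05 × 2.93 ≈ 0.15. Note that, to avoid the very slow IDF pass, the code below temporarily fixes every IDF to 1, so the stored weight reduces to the tf value.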

    package com.datamine.NaiveBayes;
    
    import java.io.*;
    import java.util.*;
    
    /**
     * Compute the feature vector of each document and vectorize the whole document set
     * @author Administrator
     */
    public class ComputeWordsVector {
    
    	/**
    	 * Compute the TF vector of each document (TFPerDocMap), then TF*IDF
    	 * @param strDir absolute path of the preprocessed newsgroup directory
    	 * @param trainSamplePercent fraction of each category used as training samples
    	 * @param indexOfSample index of the first test slice; note: this parameter splits the documents into a training part and a test part
    	 * @param iDFPerWordMap IDF weight of every word
    	 * @param wordMap the feature dictionary map
    	 * @throws IOException 
    	 */
    	public void computeTFMultiIDF(String strDir,double trainSamplePercent,int indexOfSample,
    			Map<String, Double> iDFPerWordMap,Map<String,Double> wordMap) throws IOException{
    		
    		File fileDir = new File(strDir);
    		String word;
    		SortedMap<String,Double> TFPerDocMap = new TreeMap<String, Double>();
    		//note: two writers are used, one dedicated to the test samples and one to the training samples
    		String trainFileDir = "E:/DataMiningSample/docVector/wordTFIDFMapTrainSample"+indexOfSample;
    		String testFileDir = "E:/DataMiningSample/docVector/wordTFIDFMapTestSample"+indexOfSample;
    		
    		FileWriter tsTrainWriter = new FileWriter(new File(trainFileDir)); //writer for the training samples
    		FileWriter tsTestWriter = new FileWriter(new File(testFileDir)); //writer for the test samples
    		
    		FileWriter tsWriter = null;
    		File[] sampleDir = fileDir.listFiles();
    		
    		for(int i = 0;i<sampleDir.length;i++){
    			
    			String cateShortName = sampleDir[i].getName();
    			System.out.println("Start computing: " + cateShortName);
    			
    			File[] sample = sampleDir[i].listFiles();
    			//index of the first test sample
    			double testBeginIndex = indexOfSample*(sample.length*(1-trainSamplePercent));
    			//index of the last test sample
    			double testEndIndex = (indexOfSample+1)*(sample.length*(1-trainSamplePercent));
    			System.out.println("directory_fileCount: " + sampleDir[i].getCanonicalPath()+"_"+sample.length);
    			System.out.println("training files: "+sample.length*trainSamplePercent
    					+ " test start index: "+ testBeginIndex+" test end index: "+testEndIndex);
    			
    			for(int j =0;j<sample.length; j++){
    				
    				//compute TF, i.e. how often each word occurs in this file
    				TFPerDocMap.clear();
    				FileReader samReader = new FileReader(sample[j]);
    				BufferedReader samBR = new BufferedReader(samReader);
    				String fileShortName = sample[j].getName();
    				Double wordSumPerDoc = 0.0; //total number of feature words in this document
    				while((word = samBR.readLine()) != null){
    					
    					if(!word.isEmpty() && wordMap.containsKey(word)){
    						wordSumPerDoc++;
    						if(TFPerDocMap.containsKey(word))
    							TFPerDocMap.put(word, TFPerDocMap.get(word)+1);
    						else
    							TFPerDocMap.put(word, 1.0);
    					}
    				}
    				samBR.close();
    				
    				/*
    				 * Iterate over TFPerDocMap and divide by the document's word count wordSumPerDoc to get TF;
    				 * TF*IDF then gives the final feature weight, which is written to file.
    				 * Note: test samples and training samples are written to different files.
    				 */
    				if(j >= testBeginIndex && j <= testEndIndex)
    					tsWriter = tsTestWriter;
    				else
    					tsWriter = tsTrainWriter;
    				
    				Double wordWeight;
    				Set<Map.Entry<String, Double>> tempTF = TFPerDocMap.entrySet();
    				for(Iterator<Map.Entry<String, Double>> mt = tempTF.iterator();mt.hasNext();){
    					
    					Map.Entry<String, Double> me = mt.next();
    					
    					//computing IDF is very slow (roughly 25 hours for a dictionary of 30,000+ words), so for now assume IDF = 1 for every word
    					//wordWeight = (me.getValue() / wordSumPerDoc) * iDFPerWordMap.get(me.getKey());
    					wordWeight = (me.getValue() / wordSumPerDoc) * 1.0;
    					TFPerDocMap.put(me.getKey(), wordWeight);
    				}
    				
    				tsWriter.append(cateShortName + " ");
    				tsWriter.append(fileShortName + " ");
    				Set<Map.Entry<String, Double>> tempTF2 = TFPerDocMap.entrySet();
    				for(Iterator<Map.Entry<String, Double>> mt = tempTF2.iterator();mt.hasNext();){
    					Map.Entry<String, Double> me = mt.next();
    					tsWriter.append(me.getKey() + " " + me.getValue()+" ");
    				}
    				tsWriter.append("\n");
    				tsWriter.flush();
    				
    			}
    		}
    		tsTrainWriter.close();
    		tsTestWriter.close();
    		tsWriter.close();
    	}
    	
    	/**
    	 * Count the total occurrences of every word and return the words that occur at least 3 times as the final feature dictionary
    	 * @param strDir absolute path of the preprocessed newsgroup directory
    	 * @param wordMap map recording every word that occurs
    	 * @return newWordMap the final feature dictionary, containing only the words that occur at least 3 times
    	 * @throws IOException
    	 */
    	public SortedMap<String, Double> countWords(String strDir,
    			Map<String, Double> wordMap) throws IOException {
    		
    		File sampleFile = new File(strDir);
    		File[] sample = sampleFile.listFiles();
    		String word;
    		
    		for(int i =0 ;i < sample.length;i++){
    			
    			if(!sample[i].isDirectory()){
    				FileReader samReader = new FileReader(sample[i]);
    				BufferedReader samBR = new BufferedReader(samReader);
    				while((word = samBR.readLine()) != null){
    					if(word.isEmpty())
    						continue; //skip blank lines so they do not end up in the dictionary
    					if(wordMap.containsKey(word))
    						wordMap.put(word, wordMap.get(word)+1);
    					else
    						wordMap.put(word, 1.0);
    				}
    				samBR.close();
    			}else{
    				countWords(sample[i].getCanonicalPath(),wordMap);
    			}
    		}
    		
    		/*
    		 * Return only the words that occur at least 3 times.
    		 * For simplicity this is done inline; it should really be a separate function to avoid repeated work.
    		 */
    		SortedMap<String,Double> newWordMap = new TreeMap<String, Double>();
    		Set<Map.Entry<String, Double>> allWords = wordMap.entrySet();
    		for(Iterator<Map.Entry<String, Double>> it = allWords.iterator();it.hasNext();){
    			Map.Entry<String, Double> me = it.next();
    			if(me.getValue() > 2)
    				newWordMap.put(me.getKey(), me.getValue());
    		}
    		
    		System.out.println("newWordMap "+ newWordMap.size());
    		
    		return newWordMap;
    	}
    	
    	/**
    	 * Write the feature dictionary to allDicWordCountMap.txt
    	 * @param wordMap the feature dictionary
    	 * @throws IOException 
    	 */
    	public void printWordMap(Map<String, Double> wordMap) throws IOException{
    		
    		System.out.println("printWordMap:");
    		int countLine = 0;
    		File outPutFile = new File("E:/DataMiningSample/docVector/allDicWordCountMap.txt");
    		FileWriter outPutFileWriter = new FileWriter(outPutFile);
    		
    		Set<Map.Entry<String, Double>> allWords = wordMap.entrySet();
    		for(Iterator<Map.Entry<String, Double>> it = allWords.iterator();it.hasNext();){
    			Map.Entry<String, Double> me = it.next();
    			outPutFileWriter.write(me.getKey()+" "+me.getValue()+"\n");
    			countLine++;
    		}
    		outPutFileWriter.close();
    		System.out.println("WordMap size : " + countLine);
    	}
    	
    	/**
    	 * The inverse document frequency idf of word w over the whole document set is the log of the ratio
    	 * between the total number of documents n and the number of documents containing w, docs(w, D): idf = log(n / docs(w, D))
    	 * Compute IDF, i.e. in how many documents each word of the feature dictionary appears
    	 * @param strDir absolute path of the preprocessed newsgroup directory
    	 * @param wordMap the feature dictionary
    	 * @return the IDF map of the words
    	 * @throws IOException 
    	 */
    	public SortedMap<String,Double> computeIDF(String strDir,Map<String, Double> wordMap) throws IOException{
    		
    		File fileDir = new File(strDir);
    		String word;
    		SortedMap<String,Double> IDFPerWordMap = new TreeMap<String, Double>();
    		Set<Map.Entry<String, Double>> wordMapSet = wordMap.entrySet();
    		
    		for(Iterator<Map.Entry<String, Double>> it = wordMapSet.iterator();it.hasNext();){
    			Map.Entry<String, Double> pe = it.next();
    			Double countDoc = 0.0; //number of documents containing this dictionary word
    			Double sumDoc = 0.0; //total number of documents
    			String dicWord = pe.getKey();
    			File[] sampleDir = fileDir.listFiles();
    			
    			for(int i =0;i<sampleDir.length;i++){
    				
    				File[] sample = sampleDir[i].listFiles();
    				for(int j = 0;j<sample.length;j++){
    					
    					sumDoc++; //count the documents
    					
    					FileReader samReader = new FileReader(sample[j]);
    					BufferedReader samBR = new BufferedReader(samReader);
    					boolean isExist = false;
    					while((word = samBR.readLine()) != null){
    						if(!word.isEmpty() && word.equals(dicWord)){
    							isExist = true;
    							break;
    						}
    					}
    					if(isExist)
    						countDoc++;
    					
    					samBR.close();
    				}
    			}
    			//compute the word's IDF
    			//double IDF = Math.log(sumDoc / countDoc) / Math.log(10);
    			double IDF = Math.log(sumDoc / countDoc);
    			IDFPerWordMap.put(dicWord, IDF);
    		}
    		return IDFPerWordMap;
    	}
    	
    	
    	
    	public static void main(String[] args) throws IOException {
    		
    		ComputeWordsVector wordsVector = new ComputeWordsVector();
    		
    		String strDir = "E:/DataMiningSample/processedSample";
    		Map<String, Double> wordMap = new TreeMap<String, Double>();
    		
    		//feature dictionary
    		Map<String, Double> newWordMap = new TreeMap<String, Double>();
    		
    		newWordMap = wordsVector.countWords(strDir,wordMap);
    		
    		//wordsVector.printWordMap(newWordMap);
    		//wordsVector.computeIDF(strDir, newWordMap);
    		
    		double trainSamplePercent = 0.8;
    		int indexOfSample = 1;
    		Map<String, Double> iDFPerWordMap = null;
    		
    		wordsVector.computeTFMultiIDF(strDir, trainSamplePercent, indexOfSample, iDFPerWordMap, newWordMap);
    		
    		//test();
    	}
    	
    	public static void test(){
    		
    		double sumDoc  = 18828.0;
    		double countDoc = 229.0;
    		
    		double IDF1 = Math.log(sumDoc / countDoc) / Math.log(10);
    		double IDF2 = Math.log(sumDoc / countDoc) ;
    		
    		System.out.println(IDF1);
    		System.out.println(IDF2);
    		
    		System.out.println(Math.log(10));
    	}
    	
    }
    

    5. Splitting the documents into a training set and a test set

    The whole document set is split into a training set and a test set according to a given ratio (0.9 or 0.8).
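
    For example, with trainSamplePercent = 0.8 and indexOfSample = 1, a category containing 100 files gives testBeginIndex = 1 × (100 × 0.2) = 20 and testEndIndex = 2 × (100 × 0.2) = 40, so roughly files 20 to 40 of that category become test samples and the rest become training samples; changing indexOfSample shifts which slice of each category is held out.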

    package com.datamine.NaiveBayes;
    
    import java.io.*;
    import java.util.*;
    
    
    public class CreateTrainAndTestSample {
    
    	
    	void filterSpecialWords() throws IOException{
    		
    		String word;
    		ComputeWordsVector cwv = new ComputeWordsVector();
    		String fileDir = "E:/DataMiningSample/processedSample";
    		SortedMap<String, Double> wordMap = new TreeMap<String, Double>();
    		
    		wordMap = cwv.countWords(fileDir, wordMap);
    		cwv.printWordMap(wordMap); //write wordMap to file
    		
    		File[] sampleDir = new File(fileDir).listFiles();
    		for(int i = 0;i<sampleDir.length;i++){
    			
    			File[] sample = sampleDir[i].listFiles();
    			String targetDir = "E:/DataMiningSample/processedSampleOnlySpecial/"+sampleDir[i].getName();
    			File targetDirFile = new File(targetDir);
    			if(!targetDirFile.exists()){
    				targetDirFile.mkdir();
    			}
    			
    			for(int j = 0; j<sample.length;j++){
    				
    				String fileShortName = sample[j].getName();
    				targetDir = "E:/DataMiningSample/processedSampleOnlySpecial/"+sampleDir[i].getName()+"/"+fileShortName;
    				FileWriter tgWriter = new FileWriter(targetDir);
    				FileReader samReader = new FileReader(sample[j]);
    				BufferedReader samBR = new BufferedReader(samReader);
    				while((word = samBR.readLine()) != null){
    					if(wordMap.containsKey(word))
    						tgWriter.append(word+"\n");
    				}
    				tgWriter.flush();
    				tgWriter.close();
    				samBR.close();
    			}
    		}
    	}
    	
    	/**
    	 * Create the training set and the test set
    	 * @param fileDir path of the preprocessed files, e.g. E:/DataMiningSample/processedSampleOnlySpecial
    	 * @param trainSamplePercent fraction used for training, e.g. 0.8
    	 * @param indexOfSample index that selects which slice becomes the test set, e.g. 1
    	 * @param classifyResultFile file recording the correct category of each test sample
    	 * @throws IOException
    	 */
    	void createTestSample(String fileDir,double trainSamplePercent,int indexOfSample,String classifyResultFile) throws IOException{
    		
    		String word,targetDir;
    		FileWriter crWriter = new FileWriter(classifyResultFile);//records the correct category of each test sample
    		File[] sampleDir = new File(fileDir).listFiles();
    		
    		for(int i =0;i<sampleDir.length;i++){
    			
    			File[] sample = sampleDir[i].listFiles();
    			double testBeginIndex = indexOfSample*(sample.length*(1-trainSamplePercent));
    			double testEndIndex = (indexOfSample + 1)*(sample.length*(1-trainSamplePercent));
    			
    			for(int j = 0;j<sample.length;j++){
    				
    				FileReader samReader = new FileReader(sample[j]);
    				BufferedReader samBR = new BufferedReader(samReader);
    				String fileShortName = sample[j].getName();
    				String subFileName = fileShortName;
    				
    				if(j > testBeginIndex && j < testEndIndex){
    					targetDir = "E:/DataMiningSample/TestSample"+indexOfSample+"/"+sampleDir[i].getName(); 
    					crWriter.append(subFileName + " "+sampleDir[i].getName()+"\n");
    				}else{
    					targetDir = "E:/DataMiningSample/TrainSample"+indexOfSample+"/"+sampleDir[i].getName();
    				}
    					
    				targetDir = targetDir.replace("\\", "/");
    				File trainSamFile = new File(targetDir);
    				if(!trainSamFile.exists()){
    					trainSamFile.mkdir();
    				}
    				
    				targetDir += "/" + subFileName;
    				FileWriter tsWriter = new FileWriter(new File(targetDir));
    				while((word = samBR.readLine()) != null)
    					tsWriter.append(word+"\n");
    				tsWriter.flush();
    				
    				tsWriter.close();
    				samBR.close();
    			}
    			
    		}
    		crWriter.close();
    	}
    	
    	
    	public static void main(String[] args) throws IOException {
    		
    		CreateTrainAndTestSample test = new CreateTrainAndTestSample();
    		
    		String fileDir = "E:/DataMiningSample/processedSampleOnlySpecial";
    		double trainSamplePercent=0.8;
    		int indexOfSample=1;
    		String classifyResultFile="E:/DataMiningSample/classifyResult";
    		
    		test.createTestSample(fileDir, trainSamplePercent, indexOfSample, classifyResultFile);
    		//test.filterSpecialWords();
    	}
    	
    	
    }
    

    6. Description and implementation of the Naive Bayes algorithm

    According to the Naive Bayes formula, the probability that a test sample belongs to a category = the product of the class-conditional probabilities P(tk|c) of all feature words contained in the test sample * the prior probability P(c).
    When computing the class-conditional and prior probabilities, a Naive Bayes classifier can use one of two models:
    (1) Multinomial model — word granularity: repeated occurrences of a word within a document are counted. The multinomial distribution generalizes the binomial one, so under this model a word is no longer a binary variable (present/absent) but the number of times it appears in the document.
    Class-conditional probability P(tk|c) = (occurrences of word tk in all documents of class c + 1) / (total number of words in class c + number of distinct feature words in the training set)
    Prior probability P(c) = total number of words in class c / total number of words in the whole training set
    (2) Bernoulli model — document granularity, i.e. a binomial model: a Bernoulli experiment is N independent repeated trials in which only occurrence/non-occurrence matters, so each word is represented by a boolean variable.
    Class-conditional probability P(tk|c) = (number of documents of class c containing word tk + 1) / (number of documents of class c + 2)
    Prior probability P(c) = number of documents of class c / total number of documents in the training set
    This classifier uses the multinomial model; according to "Introduction to Information Retrieval", the multinomial model gives higher accuracy.
    Points worth noting in the implementation (a log-space sketch follows this list):
           (1) BigDecimal is used so that the probabilities can be computed with arbitrary precision
           (2) Ten classification experiments are run in a cross-validation style and the accuracies are averaged
           (3) A confusion matrix is computed and printed from the file of correct categories and the file of classification results
           (4) Map<String,Double> cateWordsProb uses "category_word" as key and the number of occurrences of that word in that category as value, to avoid recomputation
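
    As an aside, an alternative to BigDecimal is to work with log-probabilities, which avoids underflow with plain doubles. The sketch below is not part of the original classifier; it assumes it would live inside NaiveBayesianClassifier and mirrors the maps and the 0.0001 smoothing constant of computeCateProb() further down, purely to illustrate the idea:

    	//Illustrative sketch only: log-space version of the per-category score used below.
    	//The category with the largest returned score wins.
    	private double computeCateLogProb(File trainFile, Vector<String> testFileWords,
    			Map<String, Double> cateWordsNum, double totalWordsNum, Map<String, Double> cateWordsProb) {
    		
    		double wordNumInCate = cateWordsNum.get(trainFile.getName());
    		double logProb = Math.log(wordNumInCate / totalWordsNum); //log prior P(c)
    		
    		for (String word : testFileWords) {
    			String key = trainFile.getName() + "_" + word;
    			double count = cateWordsProb.containsKey(key) ? cateWordsProb.get(key) : 0.0;
    			logProb += Math.log((count + 0.0001) / (wordNumInCate + totalWordsNum));
    		}
    		return logProb;
    	}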


    package com.datamine.NaiveBayes;
    
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileNotFoundException;
    import java.io.FileReader;
    import java.io.FileWriter;
    import java.io.IOException;
    import java.math.BigDecimal;
    import java.util.Iterator;
    import java.util.Map;
    import java.util.Set;
    import java.util.SortedSet;
    import java.util.TreeMap;
    import java.util.TreeSet;
    import java.util.Vector;
    
    
    /**
     * Classify the newsgroup document set with the Naive Bayes algorithm, averaging over ten cross-validation runs
     * Uses the multinomial model
     * Class-conditional probability P(tk|c) = (occurrences of word tk in all documents of class c + 1) / (total words in class c + total words in the training set)
     * @author Administrator
     */
    public class NaiveBayesianClassifier {
    
    	/**
    	 * Classify the test document set with Naive Bayes
    	 * @param trainDir directory of the training documents
    	 * @param testDir directory of the test documents
    	 * @param classifyResultFileNew path of the classification result file
    	 * @throws Exception 
    	 */
    	private void doProcess(String trainDir,String testDir,
    			String classifyResultFileNew) throws Exception{
    		
    		//total number of words of every category in the training set: <category directory name, word count>
    		Map<String,Double> cateWordsNum = new TreeMap<String, Double>();
    		//occurrences of every feature word within every category of the training set: <category_word, count>
    		Map<String,Double> cateWordsProb = new TreeMap<String, Double>();
    		
    		cateWordsNum = getCateWordsNum(trainDir);
    		cateWordsProb = getCateWordsProb(trainDir);
    		
    		double totalWordsNum = 0.0;//total number of words in the whole training set
    		Set<Map.Entry<String, Double>> cateWordsNumSet = cateWordsNum.entrySet();
    		for(Iterator<Map.Entry<String, Double>> it = cateWordsNumSet.iterator();it.hasNext();){
    			Map.Entry<String, Double> me = it.next();
    			totalWordsNum += me.getValue();
    		}
    		
    		//now read the test samples and classify them
    		Vector<String> testFileWords = new Vector<String>(); //container for all words of one test sample
    		String word;
    		File[] testDirFiles = new File(testDir).listFiles();
    		FileWriter crWriter = new FileWriter(classifyResultFileNew);
    		for(int i =0;i<testDirFiles.length;i++){
    			
    			File[] testSample = testDirFiles[i].listFiles();
    			
    			for(int j =0;j<testSample.length;j++){
    				
    				testFileWords.clear();
    				FileReader spReader = new FileReader(testSample[j]);
    				BufferedReader spBR = new BufferedReader(spReader);
    				while((word = spBR.readLine()) != null){
    					testFileWords.add(word);
    				}
    				spBR.close();
    				//compute the probability of this test sample under each of the twenty categories
    				File[] trainDirFiles = new File(trainDir).listFiles();
    				BigDecimal maxP = new BigDecimal(0);
    				String bestCate = null;
    				
    				for(int k =0; k < trainDirFiles.length; k++){
    					
    					BigDecimal p = computeCateProb(trainDirFiles[k],testFileWords,cateWordsNum,totalWordsNum,cateWordsProb);
    					
    					if( k == 0){
    						maxP = p;
    						bestCate = trainDirFiles[k].getName();
    						continue;
    					}
    					if(p.compareTo(maxP) == 1){
    						maxP = p;
    						bestCate = trainDirFiles[k].getName();
    					}
    				}
    				crWriter.append(testSample[j].getName() + " " + bestCate + "\n");
    				crWriter.flush();
    			}
    		}
    		crWriter.close();
    		
    	}
    
    	/**
    	 * Class-conditional probability P(tk|c) = (occurrences of word tk in all documents of class c + 1) / (total words in class c + total words in the training set)
    	 * Compute the probability of one test sample under one category, using the multinomial model
    	 * @param trainFile directory containing all training samples of this category
    	 * @param testFileWords container holding all words of the test sample
    	 * @param cateWordsNum total word count per category directory
    	 * @param totalWordsNum total word count of all training samples
    	 * @param cateWordsProb per-category word occurrence counts
    	 * @return the probability of the test sample under this category
    	 */
    	private BigDecimal computeCateProb(File trainFile, Vector<String> testFileWords,
    			Map<String, Double> cateWordsNum, double totalWordsNum, Map<String, Double> cateWordsProb) {
    		
    		BigDecimal probability = new BigDecimal(1);
    		double wordNumInCate = cateWordsNum.get(trainFile.getName());
    		BigDecimal wordNumInCateBD = new BigDecimal(wordNumInCate);
    		BigDecimal totalWordsNumBD = new BigDecimal(totalWordsNum);
    		
    		for(Iterator<String> it = testFileWords.iterator();it.hasNext();){
    			
    			String me = it.next();
    			String key = trainFile.getName()+"_"+me;
    			double testFileWordNumInCate;
    			if(cateWordsProb.containsKey(key))
    				testFileWordNumInCate = cateWordsProb.get(key);
    			else
    				testFileWordNumInCate = 0.0;
    			BigDecimal testFileWordNumInCateBD = new BigDecimal(testFileWordNumInCate);
    			
    			BigDecimal xcProb = (testFileWordNumInCateBD.add(new BigDecimal(0.0001)))
    					.divide(wordNumInCateBD.add(totalWordsNumBD), 10, BigDecimal.ROUND_CEILING);
    			probability = probability.multiply(xcProb);
    		}
    		// P =  P(tk|c)*P(C)
    		BigDecimal result = probability.multiply(wordNumInCateBD.divide(totalWordsNumBD,10, BigDecimal.ROUND_CEILING));
    		
    		return result;
    	}
    
    	/**
    	 * Count the occurrences of every word in the training samples of each category
    	 * @param trainDir directory of the training sample set
    	 * @return cateWordsProb map indexed by "category_word"; the value is how often that word occurs in that category
    	 * @throws Exception 
    	 */
    	private Map<String, Double> getCateWordsProb(String trainDir) throws Exception {
    
    		Map<String,Double> cateWordsProb = new TreeMap<String, Double>();
    		File sampleFile = new File(trainDir);
    		File[] sampleDir = sampleFile.listFiles();
    		String word;
    		
    		for(int i =0;i < sampleDir.length;i++){
    			
    			File[] sample = sampleDir[i].listFiles();
    			
    			for(int j =0;j<sample.length;j++){
    				
    				FileReader samReader = new FileReader(sample[j]);
    				BufferedReader samBR = new BufferedReader(samReader);
    				while((word = samBR.readLine()) != null){
    					String key = sampleDir[i].getName()+"_"+word;
    					if(cateWordsProb.containsKey(key))
    						cateWordsProb.put(key, cateWordsProb.get(key)+1);
    					else
    						cateWordsProb.put(key, 1.0);
    				}
    				samBR.close();
    			}
    		}
    		
    		return cateWordsProb;
    	}
    
    	/**
    	 * Get the total word count of each category
    	 * @param trainDir directory of the training documents
    	 * @return cateWordsNum map of <category directory name, total word count>
    	 * @throws IOException 
    	 */
    	private Map<String, Double> getCateWordsNum(String trainDir) throws IOException {
    
    		Map<String, Double> cateWordsNum = new TreeMap<String, Double>();
    		File[] sampleDir = new File(trainDir).listFiles();
    		
    		for(int i =0;i<sampleDir.length;i++){
    			
    			double count = 0;
    			File[] sample = sampleDir[i].listFiles();
    			
    			for(int j =0;j<sample.length;j++){
    				
    				FileReader spReader = new FileReader(sample[j]);
    				BufferedReader spBR = new BufferedReader(spReader);
    				while(spBR.readLine() != null){
    					count++;
    				}
    				spBR.close();
    			}
    			cateWordsNum.put(sampleDir[i].getName(), count);
    		}
    		
    		return cateWordsNum;
    	}
    	
    	/**
    	 * Compute the accuracy from the file of correct categories and the classification result file
    	 * @param classifyRightCate file of correct categories: <file name, category directory name>
    	 * @param classifyResultFileNew classification result file: <file name, category directory name>
    	 * @return the classification accuracy
    	 * @throws Exception 
    	 */
    	public double computeAccuracy(String classifyRightCate,
    			String classifyResultFileNew) throws Exception {
    		
    		Map<String,String> rightCate = new TreeMap<String, String>();
    		Map<String,String> resultCate = new TreeMap<String,String>();
    		rightCate = getMapFromResultFile(classifyRightCate);
    		resultCate = getMapFromResultFile(classifyResultFileNew);
    		
    		Set<Map.Entry<String, String>> resCateSet = resultCate.entrySet();
    		double rightCount = 0.0;
    		for(Iterator<Map.Entry<String, String>> it = resCateSet.iterator();it.hasNext();){
    			
    			Map.Entry<String, String> me = it.next();
    			if(me.getValue().equals(rightCate.get(me.getKey())))
    				rightCount++;
    		}
    		
    		computerConfusionMatrix(rightCate,resultCate);
    		
    		return rightCount / resultCate.size();
    	}
    	
    	/**
    	 * Compute and print the confusion matrix from the correct categories and the classification results
    	 * @param rightCate map of correct categories
    	 * @param resultCate map of predicted categories
    	 */
    	private void computerConfusionMatrix(Map<String, String> rightCate,
    			Map<String, String> resultCate) {
    		
    		int[][] confusionMatrix = new int[20][20];
    		
    		//first determine the array index of every category
    		SortedSet<String> cateNames = new TreeSet<String>();
    		Set<Map.Entry<String, String>> rightCateSet = rightCate.entrySet();
    		for(Iterator<Map.Entry<String, String>> it = rightCateSet.iterator();it.hasNext();){
    			
    			Map.Entry<String, String> me = it.next();
    			cateNames.add(me.getValue());
    		}
    		cateNames.add("rec.sport.baseball");//guard against one category being missing from the test results
    		String[] cateNamesArray = cateNames.toArray(new String[0]);
    		Map<String,Integer> cateNamesToIndex = new TreeMap<String, Integer>();
    		
    		for(int i =0;i<cateNamesArray.length;i++){
    			cateNamesToIndex.put(cateNamesArray[i], i);
    		}
    		
    		for(Iterator<Map.Entry<String, String>> it = rightCateSet.iterator();it.hasNext();){
    			
    			Map.Entry<String, String> me = it.next();
    			confusionMatrix[cateNamesToIndex.get(me.getValue())][cateNamesToIndex.get(resultCate.get(me.getKey()))]++;
    		}
    		
    		//print the confusion matrix
    		double[] hangSum = new double[20];
    		System.out.print("      ");
    		
    		for(int i=0;i<20;i++){
    			System.out.printf("%-6d",i);
    		}
    		
    		System.out.println("accuracy");
    		
    		for(int i =0;i<20;i++){
    			System.out.printf("%-6d",i);
    			for(int j = 0;j<20;j++){
    				System.out.printf("%-6d",confusionMatrix[i][j]);
    				hangSum[i] += confusionMatrix[i][j];
    			}
    			System.out.printf("%-6f\n",confusionMatrix[i][i]/hangSum[i]);
    		}
    		System.out.println();
    		
    	}
    
    	/**
    	 * Read a map from a result file
    	 * @param file category file
    	 * @return Map<String,String> map of <file name, category name>
    	 * @throws Exception 
    	 */
    	private Map<String, String> getMapFromResultFile(String file) throws Exception {
    
    		File crFile = new File(file);
    		FileReader crReader = new FileReader(crFile);
    		BufferedReader crBR = new BufferedReader(crReader);
    		Map<String,String> res = new TreeMap<String, String>();
    		String[] s;
    		String line;
    		while((line = crBR.readLine()) != null){
    			s = line.split(" ");
    			res.put(s[0], s[1]);
    		}
    		return res;
    	}
    
    	public static void main(String[] args) throws Exception {
    		
    		CreateTrainAndTestSample ctt = new CreateTrainAndTestSample();
    		NaiveBayesianClassifier nbClassifier = new NaiveBayesianClassifier();
    		
    		//from the document set that still contains non-feature words, generate a document set containing only feature words under processedSampleOnlySpecial
    		ctt.filterSpecialWords();
    		
    		double[] accuracyOfEveryExp  = new double[10];
    		double accuracyAvg,sum = 0;
    		
    
    		for(int i =0;i<10;i++){//run ten classification experiments in a cross-validation style and average the accuracy	
    			
    			String TrainDir = "E:/DataMiningSample/TrainSample"+i;
    			String TestDir = "E:/DataMiningSample/TestSample"+i;
    			String classifyRightCate = "E:/DataMiningSample/classifyRightCate"+i+".txt";
    			String classifyResultFileNew = "E:/DataMiningSample/classifyResultNew"+i+".txt";
    		
    			ctt.createTestSample("E:/DataMiningSample/processedSampleOnlySpecial", 0.8, i, classifyRightCate);
    			
    			nbClassifier.doProcess(TrainDir, TestDir, classifyResultFileNew);
    		
    			accuracyOfEveryExp[i] = nbClassifier.computeAccuracy(classifyRightCate,classifyResultFileNew);
    			
    			System.out.println("The accuracy for Naive Bayesian Classifier in "+i+"th Exp is :" + accuracyOfEveryExp[i]);
    			
    		}
    		
    		for(int i =0;i<10;i++)
    			sum += accuracyOfEveryExp[i];
    		accuracyAvg = sum/10;
    		
    		System.out.println("The average accuracy for Naive Bayesian Classifier in all Exps is :" + accuracyAvg);
    		
    	}
    
    	
    	
    }
    

    7. Experimental results and discussion

    Only the result of the first run is listed here:


    The multinomial model used here is a modified version of the standard calculation: the class-conditional probability is changed to P(tk|c) = (occurrences of word tk in all documents of class c + a small constant, 0.0001 in the code above) / (total words in class c + number of distinct feature words in the training set). When tk does not occur, only the small constant is added to the numerator instead of 1, which describes the statistical distribution of the words more accurately.
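
    In code, the only difference from standard add-one (Laplace) smoothing is the constant added to the numerator. The sketch below is illustrative only: the parameter names are placeholders rather than identifiers from the classes above, and the denominator follows what computeCateProb() actually uses:

    	//Illustrative only: the numerator constant is the sole change from Laplace smoothing.
    	static double smoothedProb(double countOfTkInC, double wordsInC, double totalTrainingWords) {
    		//Laplace smoothing would be: (countOfTkInC + 1.0) / (wordsInC + distinctFeatureWords)
    		return (countOfTkInC + 0.0001) / (wordsInC + totalTrainingWords);
    	}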

    8. Possible improvements

    To further improve the classification accuracy of the Naive Bayes algorithm, the following improvements are possible:

    1. Use a better feature selection method, such as method 3 above, or something stronger

    2. Modify the class-conditional probability formula of the multinomial model (already implemented above)


  • Original article: https://www.cnblogs.com/cxchanpin/p/7140468.html