zoukankan      html  css  js  c++  java
  • 文本分类——NaiveBayes

    前面的文章已经介绍了朴素贝叶斯算法的原理,这里基于 NaiveBayes 算法对 newsgroup 文本进行分类测试。

    文中代码參考:http://blog.csdn.net/jiangliqing1234/article/details/39642757

    主要内容例如以下:

    1、newsgroup数据集介绍

    数据下载地址:http://download.csdn.net/detail/hjy321686/8057761。

      文本中包括20个不同的新闻组,除当中少数文本属于多个新闻组以外,其余的文档都仅仅属于一个新闻组。

    2、newsgroup数据预处理

    要对文本进行分类,首先要对其进行预处理,预处理主要步骤例如以下:

    step1:英文词法分析,去除数字、连字符、标点符号、特殊字符,并将所有大写字母转换成小写。可用正则表达式按非字母字符切分:String res[] = line.split("[^a-zA-Z]");

    step2:去停用词,过滤对分类无价值的词

    step3:词根还原stemmer,基于Porter算法

    预处理类例如以下:

    package com.datamine.NaiveBayes;
    
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.FileWriter;
    import java.util.ArrayList;
    
    /**
     * Newsgroup文档预处理
     * step1:英文词法分析,取出数字、连字符、标点符号、特殊字符,全部大写字母转换成小写。可用正則表達式:String res[] = line.split("[^a-zA-Z]");
     * step2:去停用词,过滤对分类无价值的词
     * step3:词根还原stemmer。基于Porter算法
     * @author Administrator
     *
     */
    public class DataPreProcess {
    
    	private static ArrayList<String> stopWordsArray = new ArrayList<String>();
    	
    	/**
    	 * 输入文件的路径。处理数据
    	 * @param srcDir 文件文件夹的绝对路径
    	 * @param desDir 清洗后的文件路径
    	 * @throws Exception
    	 */
    	public void doProcess(String srcDir) throws Exception{
    		
    		File fileDir = new File(srcDir);
    		if(!fileDir.exists()){
    			System.out.println("文件不存在!");
    			return ;
    		}
    		
    		String subStrDir = srcDir.substring(srcDir.lastIndexOf('/'));
    		String dirTarget = srcDir+"/../../processedSample"+subStrDir;
    		File fileTarget = new File(dirTarget);
    		
    		if(!fileTarget.exists()){
    			//注意processedSample须要先建立文件夹建出来,否则会报错,由于母文件夹不存在
    			boolean mkdir = fileTarget.mkdir();
    		}
    		
    		File[] srcFiles = fileDir.listFiles();
    		
    		for(int i =0 ;i <srcFiles.length;i++){
    			
    			String fileFullName = srcFiles[i].getCanonicalPath(); //CanonicalPath不可是全路径,并且把..或者.这种符号解析出来。

    String fileShortName = srcFiles[i].getName(); //文件名称 if(!new File(fileFullName).isDirectory()){ //确认子文件名称不是文件夹,假设是能够再次递归调用 System.out.println("開始预处理:"+fileFullName); StringBuilder stringBuilder = new StringBuilder(); stringBuilder.append(dirTarget+"/"+fileShortName); createProcessFile(fileFullName,stringBuilder.toString()); }else{ fileFullName = fileFullName.replace("\", "/"); doProcess(fileFullName); } } } /** * 进行文本预处理生成目标文件 * @param srcDir 源文件文件文件夹的绝对路径 * @param targetDir 生成目标文件的绝对路径 * @throws Exception */ private void createProcessFile(String srcDir, String targetDir) throws Exception { FileReader srcFileReader = new FileReader(srcDir); FileWriter targetFileWriter = new FileWriter(targetDir); BufferedReader srcFileBR = new BufferedReader(srcFileReader); String line,resLine; while((line = srcFileBR.readLine()) != null){ resLine = lineProcess(line); if(!resLine.isEmpty()){ //按行写。一行写一个单词 String[] tempStr = resLine.split(" "); for(int i =0; i<tempStr.length ;i++){ if(!tempStr[i].isEmpty()) targetFileWriter.append(tempStr[i]+" "); } } } targetFileWriter.flush(); targetFileWriter.close(); srcFileReader.close(); srcFileBR.close(); } /** * 对每行字符串进行处理,主要是词法分析、去停用词和stemming(去除时态) * @param line 待处理的一行字符串 * @param stopWordsArray 停用词数组 * @return String 处理好的一行字符串,是由处理好的单词又一次生成,以空格为分隔符 */ private String lineProcess(String line) { /* * step1 * 英文词法分析,去除数字、连字符、标点符号、特殊字符, * 全部大写字符转换成小写,能够考虑使用正則表達式 */ String res[] = line.split("[^a-zA-Z]"); //step2 去停用词,大写转换成小写 //step3 Stemmer.run() String resString = new String(); for(int i=0;i<res.length;i++){ if(!res[i].isEmpty() && !stopWordsArray.contains(res[i].toLowerCase())) resString += " " + Stemmer.run(res[i].toLowerCase()) + " "; } return resString; } /** * 用stopWordsArray构造停用词的ArrayList容器 * @param stopwordsPath * @throws Exception */ private static void stopWordsToArray(String stopwordsPath) throws Exception { FileReader stopWordsReader = new FileReader(stopwordsPath); BufferedReader stopWordsBR = new 
BufferedReader(stopWordsReader); String stopWordsLine = null; //用stopWordsArray构造停用词的ArrayList容器 while((stopWordsLine = stopWordsBR.readLine()) != null){ if(!stopWordsLine.isEmpty()) stopWordsArray.add(stopWordsLine); } stopWordsReader.close(); stopWordsBR.close(); } public static void main(String[] args) throws Exception{ DataPreProcess dataPrePro = new DataPreProcess(); String srcDir = "E:/DataMiningSample/orginSample"; String stopwordsPath = "E:/DataMiningSample/stopwords.txt"; stopWordsToArray(stopwordsPath); dataPrePro.doProcess(srcDir); } }


    对于step3中的Porter算法能够网上下载,这里我基于其之上加入了一个run()方法。

    	/**
    	 * Convenience entry point on top of the stemmer: feeds the characters of
    	 * {@code word} to a fresh {@code Stemmer}, runs {@code stem()} and returns
    	 * the stemmer's string form.
    	 *
    	 * @param word the word to stem
    	 * @return the stemmed word
    	 */
    	public static String run(String word){

    		Stemmer stemmer = new Stemmer();

    		for (char letter : word.toCharArray()) {
    			stemmer.add(letter);
    		}

    		stemmer.stem();
    		return stemmer.toString();
    	}

    3、特征项选择

    方法一:保留全部词作为特征词

    方法二:选取出现频率大于某一个数(3或者其它)的词作为特征词

    方法三:计算每一个词的权重tf*idf,依据权重来选取特征词

    本文中选取方法二。

    4、文本向量化

    因为本文中特征词选择采用的是方法二,可以不用对文本进行向量化;但是统计特征词出现次数的方法(countWords)写在 ComputeWordsVector 类中,为了程序能够执行,这里还是把文本向量化的代码贴出来。后面使用 KNN 算法的时候也要用到此类。

    package com.datamine.NaiveBayes;
    
    import java.io.*;
    import java.util.*;
    
    /**
     * Builds the feature dictionary for the pre-processed corpus and turns documents
     * into term-weight vectors. All input files are expected to hold one word per line.
     *
     * @author Administrator
     */
    public class ComputeWordsVector {

        /** Directory that receives the vector files and the dictionary dump. */
        private static final String DOC_VECTOR_DIR = "E:/DataMiningSample/docVector/";

        /**
         * Writes one term-frequency vector per document, split into a training file and a
         * test file. Every output line has the form
         * {@code <category> <file> <word> <weight> <word> <weight> ...}.
         * <p>
         * Within each category the documents whose (name-sorted) index lies in the half-open
         * range {@code [indexOfSample * n * (1 - p), (indexOfSample + 1) * n * (1 - p))} go to
         * the test file, all others to the training file ({@code n} = documents in the
         * category, {@code p} = {@code trainSamplePercent}).
         * <p>
         * The weight is the plain term frequency. {@code iDFPerWordMap} is accepted for
         * interface compatibility but not used: computing IDF was deliberately left out,
         * i.e. every IDF is treated as 1.
         *
         * @param strDir             root of the pre-processed corpus (one sub-directory per category)
         * @param trainSamplePercent fraction of each category used for training
         * @param indexOfSample      which test slice to cut out of every category
         * @param iDFPerWordMap      IDF per word; currently ignored and may be {@code null}
         * @param wordMap            feature dictionary; words not contained in it are skipped
         * @throws IOException if a directory cannot be listed or a file cannot be read or written
         */
        public void computeTFMultiIDF(String strDir, double trainSamplePercent, int indexOfSample,
                Map<String, Double> iDFPerWordMap, Map<String, Double> wordMap) throws IOException {

            File[] sampleDir = listSorted(new File(strDir));
            String trainFileDir = DOC_VECTOR_DIR + "wordTFIDFMapTrainSample" + indexOfSample;
            String testFileDir = DOC_VECTOR_DIR + "wordTFIDFMapTestSample" + indexOfSample;

            try (FileWriter tsTrainWriter = new FileWriter(new File(trainFileDir));
                 FileWriter tsTestWriter = new FileWriter(new File(testFileDir))) {

                for (File cateDir : sampleDir) {
                    if (!cateDir.isDirectory()) {
                        continue;
                    }
                    String cateShortName = cateDir.getName();
                    System.out.println("開始计算: " + cateShortName);

                    File[] sample = listSorted(cateDir);
                    double testBeginIndex = indexOfSample * (sample.length * (1 - trainSamplePercent));
                    double testEndIndex = (indexOfSample + 1) * (sample.length * (1 - trainSamplePercent));
                    System.out.println("文件名称_文件数 :" + cateDir.getCanonicalPath() + "_" + sample.length);
                    System.out.println("训练数:" + sample.length * trainSamplePercent
                            + " 測试文本開始下标:" + testBeginIndex + " 測试文本结束下标:" + testEndIndex);

                    for (int j = 0; j < sample.length; j++) {
                        SortedMap<String, Double> tfPerDoc = new TreeMap<String, Double>();
                        double wordSumPerDoc = countTerms(sample[j], wordMap, tfPerDoc);

                        // Half-open range so neighbouring slices never share a document.
                        boolean isTest = j >= testBeginIndex && j < testEndIndex;
                        FileWriter tsWriter = isTest ? tsTestWriter : tsTrainWriter;

                        tsWriter.append(cateShortName + " ");
                        tsWriter.append(sample[j].getName() + " ");
                        for (Map.Entry<String, Double> me : tfPerDoc.entrySet()) {
                            // TF = occurrences / dictionary words in the document; IDF taken as 1.
                            double wordWeight = me.getValue() / wordSumPerDoc;
                            tsWriter.append(me.getKey() + " " + wordWeight + " ");
                        }
                        tsWriter.append("\n");
                    }
                }
            }
        }

        /**
         * Counts how often each word occurs in all files below {@code strDir} (recursively)
         * and returns the words that occur at least three times.
         *
         * @param strDir  root of the pre-processed corpus
         * @param wordMap receives the raw count of every non-empty word; existing entries are incremented
         * @return a new sorted map with the entries of {@code wordMap} whose count is greater than 2
         * @throws IOException if a directory cannot be listed or a file cannot be read
         */
        public SortedMap<String, Double> countWords(String strDir,
                Map<String, Double> wordMap) throws IOException {

            accumulateWordCounts(new File(strDir), wordMap);

            // Filter once, after the whole tree has been counted.
            SortedMap<String, Double> newWordMap = new TreeMap<String, Double>();
            for (Map.Entry<String, Double> me : wordMap.entrySet()) {
                if (me.getValue() > 2) {
                    newWordMap.put(me.getKey(), me.getValue());
                }
            }

            System.out.println("newWordMap " + newWordMap.size());
            return newWordMap;
        }

        /**
         * Dumps the dictionary as {@code <word> <count>} lines to
         * {@code allDicWordCountMap.txt} in the vector directory.
         *
         * @param wordMap dictionary to write
         * @throws IOException if the file cannot be written
         */
        public void printWordMap(Map<String, Double> wordMap) throws IOException {

            System.out.println("printWordMap:");
            int countLine = 0;
            File outPutFile = new File(DOC_VECTOR_DIR + "allDicWordCountMap.txt");

            try (FileWriter outPutFileWriter = new FileWriter(outPutFile)) {
                for (Map.Entry<String, Double> me : wordMap.entrySet()) {
                    outPutFileWriter.write(me.getKey() + " " + me.getValue() + "\n");
                    countLine++;
                }
            }
            System.out.println("WordMap size : " + countLine);
        }

        /**
         * Computes the inverse document frequency {@code idf = ln(n / docs(w))} of every
         * dictionary word, where {@code n} is the number of documents and {@code docs(w)}
         * the number of documents containing {@code w}. A word that occurs in no document
         * gets {@code Infinity}.
         * <p>
         * Every document is read exactly once, regardless of the dictionary size.
         *
         * @param strDir  root of the pre-processed corpus (one sub-directory per category)
         * @param wordMap feature dictionary; only its keys are used
         * @return IDF per dictionary word, sorted by word
         * @throws IOException if a directory cannot be listed or a file cannot be read
         */
        public SortedMap<String, Double> computeIDF(String strDir, Map<String, Double> wordMap) throws IOException {

            Map<String, Double> docsPerWord = new HashMap<String, Double>();
            double sumDoc = 0.0; // total number of documents

            for (File cateDir : listSorted(new File(strDir))) {
                if (!cateDir.isDirectory()) {
                    continue;
                }
                for (File doc : listSorted(cateDir)) {
                    sumDoc++;
                    Set<String> seenInDoc = new HashSet<String>();
                    try (BufferedReader samBR = new BufferedReader(new FileReader(doc))) {
                        String word;
                        while ((word = samBR.readLine()) != null) {
                            // seenInDoc.add() is false for repeats, so a document counts once per word.
                            if (!word.isEmpty() && wordMap.containsKey(word) && seenInDoc.add(word)) {
                                Double docs = docsPerWord.get(word);
                                docsPerWord.put(word, docs == null ? 1.0 : docs + 1);
                            }
                        }
                    }
                }
            }

            SortedMap<String, Double> IDFPerWordMap = new TreeMap<String, Double>();
            for (String dicWord : wordMap.keySet()) {
                Double docs = docsPerWord.get(dicWord);
                double countDoc = docs == null ? 0.0 : docs;
                IDFPerWordMap.put(dicWord, Math.log(sumDoc / countDoc));
            }
            return IDFPerWordMap;
        }

        /** Adds the occurrences of every non-empty line below {@code dir} (recursively) to {@code wordMap}. */
        private void accumulateWordCounts(File dir, Map<String, Double> wordMap) throws IOException {

            for (File entry : listSorted(dir)) {
                if (entry.isDirectory()) {
                    accumulateWordCounts(entry, wordMap);
                    continue;
                }
                try (BufferedReader samBR = new BufferedReader(new FileReader(entry))) {
                    String word;
                    while ((word = samBR.readLine()) != null) {
                        if (word.isEmpty()) {
                            continue; // blank lines are not words
                        }
                        Double seen = wordMap.get(word);
                        wordMap.put(word, seen == null ? 1.0 : seen + 1);
                    }
                }
            }
        }

        /**
         * Counts the dictionary words of one document into {@code tfPerDoc}.
         *
         * @return the number of dictionary-word occurrences in the document
         */
        private static double countTerms(File doc, Map<String, Double> wordMap,
                Map<String, Double> tfPerDoc) throws IOException {

            double wordSumPerDoc = 0.0;
            try (BufferedReader samBR = new BufferedReader(new FileReader(doc))) {
                String word;
                while ((word = samBR.readLine()) != null) {
                    if (!word.isEmpty() && wordMap.containsKey(word)) {
                        wordSumPerDoc++;
                        Double seen = tfPerDoc.get(word);
                        tfPerDoc.put(word, seen == null ? 1.0 : seen + 1);
                    }
                }
            }
            return wordSumPerDoc;
        }

        /** Lists {@code dir} sorted by path, because listFiles() guarantees no order and may return null. */
        private static File[] listSorted(File dir) throws IOException {

            File[] entries = dir.listFiles();
            if (entries == null) {
                throw new IOException("Not a readable directory: " + dir);
            }
            Arrays.sort(entries);
            return entries;
        }

        public static void main(String[] args) throws IOException {

            ComputeWordsVector wordsVector = new ComputeWordsVector();

            String strDir = "E:\\DataMiningSample\\processedSample";
            Map<String, Double> wordMap = new TreeMap<String, Double>();

            // feature dictionary
            Map<String, Double> newWordMap = wordsVector.countWords(strDir, wordMap);

            //wordsVector.printWordMap(newWordMap);
            //wordsVector.computeIDF(strDir, newWordMap);

            double trainSamplePercent = 0.8;
            int indexOfSample = 1;
            Map<String, Double> iDFPerWordMap = null;

            wordsVector.computeTFMultiIDF(strDir, trainSamplePercent, indexOfSample, iDFPerWordMap, newWordMap);

            //test();
        }

        /** Prints one sample IDF with base-10 and natural logarithm, plus ln(10), for comparison. */
        public static void test() {

            double sumDoc = 18828.0;
            double countDoc = 229.0;

            double IDF1 = Math.log(sumDoc / countDoc) / Math.log(10);
            double IDF2 = Math.log(sumDoc / countDoc);

            System.out.println(IDF1);
            System.out.println(IDF2);

            System.out.println(Math.log(10));
        }

    }
    

    5、对文本分为測试集和训练集

    按指定的比例(0.9或者0.8)对整个文本进行划分。測试集和训练集

    package com.datamine.NaiveBayes;
    
    import java.io.*;
    import java.util.*;
    
    
    public class CreateTrainAndTestSample {
    
    	
    	void filterSpecialWords() throws IOException{
    		
    		String word;
    		ComputeWordsVector cwv = new ComputeWordsVector();
    		String fileDir = "E:\DataMiningSample\processedSample";
    		SortedMap<String, Double> wordMap = new TreeMap<String, Double>();
    		
    		wordMap = cwv.countWords(fileDir, wordMap);
    		cwv.printWordMap(wordMap); //把wordMap输出到文件
    		
    		File[] sampleDir = new File(fileDir).listFiles();
    		for(int i = 0;i<sampleDir.length;i++){
    			
    			File[] sample = sampleDir[i].listFiles();
    			String targetDir = "E:/DataMiningSample/processedSampleOnlySpecial/"+sampleDir[i].getName();
    			File targetDirFile = new File(targetDir);
    			if(!targetDirFile.exists()){
    				targetDirFile.mkdir();
    			}
    			
    			for(int j = 0; j<sample.length;j++){
    				
    				String fileShortName = sample[j].getName();
    				targetDir = "E:/DataMiningSample/processedSampleOnlySpecial/"+sampleDir[i].getName()+"/"+fileShortName;
    				FileWriter tgWriter = new FileWriter(targetDir);
    				FileReader samReader = new FileReader(sample[j]);
    				BufferedReader samBR = new BufferedReader(samReader);
    				while((word = samBR.readLine()) != null){
    					if(wordMap.containsKey(word))
    						tgWriter.append(word+"
    ");
    				}
    				tgWriter.flush();
    				tgWriter.close();
    				samBR.close();
    			}
    		}
    	}
    	
    	/**
    	 * 创建训练集和測试集
    	 * @param fileDir 预处理好的文件路径 E:DataMiningSampleprocessedSampleOnlySpecial
    	 * @param trainSamplePercent 训练集占的百分比0.8
    	 * @param indexOfSample 一个測试集计算规则  1
    	 * @param classifyResultFile 測试例子正确类目记录文件
    	 * @throws IOException
    	 */
    	void createTestSample(String fileDir,double trainSamplePercent,int indexOfSample,String classifyResultFile) throws IOException{
    		
    		String word,targetDir;
    		FileWriter crWriter = new FileWriter(classifyResultFile);//測试例子正确类目记录文件
    		File[] sampleDir = new File(fileDir).listFiles();
    		
    		for(int i =0;i<sampleDir.length;i++){
    			
    			File[] sample = sampleDir[i].listFiles();
    			double testBeginIndex = indexOfSample*(sample.length*(1-trainSamplePercent));
    			double testEndIndex = (indexOfSample + 1)*(sample.length*(1-trainSamplePercent));
    			
    			for(int j = 0;j<sample.length;j++){
    				
    				FileReader samReader = new FileReader(sample[j]);
    				BufferedReader samBR = new BufferedReader(samReader);
    				String fileShortName = sample[j].getName();
    				String subFileName = fileShortName;
    				
    				if(j > testBeginIndex && j < testEndIndex){
    					targetDir = "E:/DataMiningSample/TestSample"+indexOfSample+"/"+sampleDir[i].getName(); 
    					crWriter.append(subFileName + " "+sampleDir[i].getName()+"
    ");
    				}else{
    					targetDir = "E:/DataMiningSample/TrainSample"+indexOfSample+"/"+sampleDir[i].getName();
    				}
    					
    				targetDir = targetDir.replace("\", "/");
    				File trainSamFile = new File(targetDir);
    				if(!trainSamFile.exists()){
    					trainSamFile.mkdir();
    				}
    				
    				targetDir += "/" + subFileName;
    				FileWriter tsWriter = new FileWriter(new File(targetDir));
    				while((word = samBR.readLine()) != null)
    					tsWriter.append(word+"
    ");
    				tsWriter.flush();
    				
    				tsWriter.close();
    				samBR.close();
    			}
    			
    		}
    		crWriter.close();
    	}
    	
    	
    	public static void main(String[] args) throws IOException {
    		
    		CreateTrainAndTestSample test = new CreateTrainAndTestSample();
    		
    		String fileDir = "E:/DataMiningSample/processedSampleOnlySpecial";
    		double trainSamplePercent=0.8;
    		int indexOfSample=1;
    		String classifyResultFile="E:/DataMiningSample/classifyResult";
    		
    		test.createTestSample(fileDir, trainSamplePercent, indexOfSample, classifyResultFile);
    		//test.filterSpecialWords();
    	}
    	
    	
    }
    

    6、朴素贝叶斯算法描写叙述和实现

    依据朴素贝叶斯公式,每一个測试例子属于某个类别的概率 =  全部測试例子包括特征词类条件概率P(tk|c)之积 * 先验概率P(c)
    在详细计算类条件概率和先验概率时,朴素贝叶斯分类器有两种模型:
    (1)多元分布模型(multinomial model)——以单词为粒度,也就是说,考虑每一个文件里面重复出现多次的单词。注意多项分布其实是从二项分布拓展出来的,如果采用多项分布模型,那么每一个单词的表示变量就不再是二值变量(出现/不出现),而是每一个单词在文件中出现的次数
    类条件概率P(tk|c)=(类c下单词tk在各个文档中出现过的次数之和+1)/(类c下单词总数+训练样本中不反复特征词总数)
    先验概率P(c)=类c下的单词总数/整个训练样本的单词总数
    (2)伯努利模型(Bernoulli model) –以文件为粒度,或者说是採用二项分布模型,伯努利实验即N次独立反复随机实验,仅仅考虑事件发生/不发生,所以每一个单词的表示变量是布尔型的
    类条件概率P(tk|c)=(类c下包括单词tk的文件数+1)/(类c下文件总数+2)
    先验概率P(c)=类c下文件总数/整个训练样本的文件总数
    本分类器选用多元分布模型计算。依据《Introduction to Information Retrieval 》,多元分布模型计算准确率更高
    贝叶斯算法的实现有下面注意点:
           (1) 计算概率用到了BigDecimal类实现随意精度计算
           (2) 用交叉验证法做十次分类实验,对准确率取平均值
           (3) 根据正确类目文件和分类结果文件计算混淆矩阵并且输出
           (4) Map<String,Double> cateWordsProb key为“类目_单词”, value为该类目下该单词的出现次数。避免反复计算


    package com.datamine.NaiveBayes;
    
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileNotFoundException;
    import java.io.FileReader;
    import java.io.FileWriter;
    import java.io.IOException;
    import java.math.BigDecimal;
    import java.util.Iterator;
    import java.util.Map;
    import java.util.Set;
    import java.util.SortedSet;
    import java.util.TreeMap;
    import java.util.TreeSet;
    import java.util.Vector;
    
    
    /**
     * 利用朴素贝叶斯算法对newsgroup文档集做分类,採用十组交叉測试取平均值
     * 採用多项式模型
     * 类条件概率 P(tk|c)=(类c下 单词tk 在各个文档中出现过的次数之和 + 1)/(类c下单词的总数 + 训练集总单词数)
     * @author Administrator
     */
    public class NaiveBayesianClassifier {
    
    	/**
    	 * 用朴素贝叶斯算法对測试文档集分类
    	 * @param trainDir 训练文档集文件夹
    	 * @param testDir  測试文档集文件夹
    	 * @param classifyResultFileNew 分类结果文件路径
    	 * @throws Exception 
    	 */
    	private void doProcess(String trainDir,String testDir,
    			String classifyResultFileNew) throws Exception{
    		
    		//保存训练集中每一个类别的总词数      <文件夹名。单词总数> category
    		Map<String,Double> cateWordsNum = new TreeMap<String, Double>();
    		//保存训练样本中每一个类别中每一个属性词的出现次数  <类目_单词,数目> 
    		Map<String,Double> cateWordsProb = new TreeMap<String, Double>();
    		
    		cateWordsNum = getCateWordsNum(trainDir);
    		cateWordsProb = getCateWordsProb(trainDir);
    		
    		double totalWordsNum = 0.0;//记录全部训练集的总词数
    		Set<Map.Entry<String, Double>> cateWordsNumSet = cateWordsNum.entrySet();
    		for(Iterator<Map.Entry<String, Double>> it = cateWordsNumSet.iterator();it.hasNext();){
    			Map.Entry<String, Double> me = it.next();
    			totalWordsNum += me.getValue();
    		}
    		
    		//以下開始读取測试例子做分类
    		Vector<String> testFileWords = new Vector<String>(); //測试样本全部词的容器
    		String word;
    		File[] testDirFiles = new File(testDir).listFiles();
    		FileWriter crWriter = new FileWriter(classifyResultFileNew);
    		for(int i =0;i<testDirFiles.length;i++){
    			
    			File[] testSample = testDirFiles[i].listFiles();
    			
    			for(int j =0;j<testSample.length;j++){
    				
    				testFileWords.clear();
    				FileReader spReader = new FileReader(testSample[j]);
    				BufferedReader spBR = new BufferedReader(spReader);
    				while((word = spBR.readLine()) != null){
    					testFileWords.add(word);
    				}
    				spBR.close();
    				//以下分别计算该測试例子属于二十个类别的概率
    				File[] trainDirFiles = new File(trainDir).listFiles();
    				BigDecimal maxP = new BigDecimal(0);
    				String bestCate = null;
    				
    				for(int k =0; k < trainDirFiles.length; k++){
    					
    					BigDecimal p = computeCateProb(trainDirFiles[k],testFileWords,cateWordsNum,totalWordsNum,cateWordsProb);
    					
    					if( k == 0){
    						maxP = p;
    						bestCate = trainDirFiles[k].getName();
    						continue;
    					}
    					if(p.compareTo(maxP) == 1){
    						maxP = p;
    						bestCate = trainDirFiles[k].getName();
    					}
    				}
    				crWriter.append(testSample[j].getName() + " " + bestCate + "
    ");
    				crWriter.flush();
    			}
    		}
    		crWriter.close();
    		
    	}
    
    	/**
    	 * 类条件概率 P(tk|c)=(类c下 单词tk 在各个文档中出现过的次数之和 + 1)/(类c下单词的总数 + 训练集中总单词数)
    	 * 计算某一个測试样本数据某个类别的概率 使用多项式模型
    	 * @param trainFile 该类别全部的训练样本所在的文件夹
    	 * @param testFileWords  该測试样本中的全部词构成的容器
    	 * @param cateWordsNum  记录每一个文件夹下单词的总数
    	 * @param totalWordsNum  全部训练样本的单词的总数
    	 * @param cateWordsProb  记录每一个文件夹中出现单词和次数
    	 * @return 返回该測试样本在该类别中的概率
    	 */
    	private BigDecimal computeCateProb(File trainFile, Vector<String> testFileWords,
    			Map<String, Double> cateWordsNum, double totalWordsNum, Map<String, Double> cateWordsProb) {
    		
    		BigDecimal probability = new BigDecimal(1);
    		double wordNumInCate = cateWordsNum.get(trainFile.getName());
    		BigDecimal wordNumInCateBD = new BigDecimal(wordNumInCate);
    		BigDecimal totalWordsNumBD = new BigDecimal(totalWordsNum);
    		
    		for(Iterator<String> it = testFileWords.iterator();it.hasNext();){
    			
    			String me = it.next();
    			String key = trainFile.getName()+"_"+me;
    			double testFileWordNumInCate;
    			if(cateWordsProb.containsKey(key))
    				testFileWordNumInCate = cateWordsProb.get(key);
    			else
    				testFileWordNumInCate = 0.0;
    			BigDecimal testFileWordNumInCateBD = new BigDecimal(testFileWordNumInCate);
    			
    			BigDecimal xcProb = (testFileWordNumInCateBD.add(new BigDecimal(0.0001)))
    					.divide(wordNumInCateBD.add(totalWordsNumBD), 10, BigDecimal.ROUND_CEILING);
    			probability = probability.multiply(xcProb);
    		}
    		// P =  P(tk|c)*P(C)
    		BigDecimal result = probability.multiply(wordNumInCateBD.divide(totalWordsNumBD,10, BigDecimal.ROUND_CEILING));
    		
    		return result;
    	}
    
    	/**
    	 * 统计某个类训练样本中每一个单词出现的次数
    	 * @param trainDir 训练样本集文件夹
    	 * @return cateWordsProb 用"类目_单词"来索引map,value就是该类目下该单词出现的次数
    	 * @throws Exception 
    	 */
    	private Map<String, Double> getCateWordsProb(String trainDir) throws Exception {
    
    		Map<String,Double> cateWordsProb = new TreeMap<String, Double>();
    		File sampleFile = new File(trainDir);
    		File[] sampleDir = sampleFile.listFiles();
    		String word;
    		
    		for(int i =0;i < sampleDir.length;i++){
    			
    			File[] sample = sampleDir[i].listFiles();
    			
    			for(int j =0;j<sample.length;j++){
    				
    				FileReader samReader = new FileReader(sample[j]);
    				BufferedReader samBR = new BufferedReader(samReader);
    				while((word = samBR.readLine()) != null){
    					String key = sampleDir[i].getName()+"_"+word;
    					if(cateWordsProb.containsKey(key))
    						cateWordsProb.put(key, cateWordsProb.get(key)+1);
    					else
    						cateWordsProb.put(key, 1.0);
    				}
    				samBR.close();
    			}
    		}
    		
    		return cateWordsProb;
    	}
    
    	/**
    	 * 获得每一个类目下的单词总数
    	 * @param trainDir 训练文档集文件夹
    	 * @return cateWordsNum <文件夹名,单词总数>的map
    	 * @throws IOException 
    	 */
    	private Map<String, Double> getCateWordsNum(String trainDir) throws IOException {
    
    		Map<String, Double> cateWordsNum = new TreeMap<String, Double>();
    		File[] sampleDir = new File(trainDir).listFiles();
    		
    		for(int i =0;i<sampleDir.length;i++){
    			
    			double count = 0;
    			File[] sample = sampleDir[i].listFiles();
    			
    			for(int j =0;j<sample.length;j++){
    				
    				FileReader spReader = new FileReader(sample[j]);
    				BufferedReader spBR = new BufferedReader(spReader);
    				while(spBR.readLine() != null){
    					count++;
    				}
    				spBR.close();
    			}
    			cateWordsNum.put(sampleDir[i].getName(), count);
    		}
    		
    		return cateWordsNum;
    	}
    	
    	/**
    	 * 依据正确类目文件和分类结果文件统计出准确率
    	 * @param classifyRightCate 正确类目文件     <文件名称。类别文件夹名>
    	 * @param classifyResultFileNew 分类结果文件     <文件名称,类别文件夹名>
    	 * @return 分类的准确率
    	 * @throws Exception 
    	 */
    	public double computeAccuracy(String classifyRightCate,
    			String classifyResultFileNew) throws Exception {
    		
    		Map<String,String> rightCate = new TreeMap<String, String>();
    		Map<String,String> resultCate = new TreeMap<String,String>();
    		rightCate = getMapFromResultFile(classifyRightCate);
    		resultCate = getMapFromResultFile(classifyResultFileNew);
    		
    		Set<Map.Entry<String, String>> resCateSet = resultCate.entrySet();
    		double rightCount = 0.0;
    		for(Iterator<Map.Entry<String, String>> it = resCateSet.iterator();it.hasNext();){
    			
    			Map.Entry<String, String> me = it.next();
    			if(me.getValue().equals(rightCate.get(me.getKey())))
    				rightCount++;
    		}
    		
    		computerConfusionMatrix(rightCate,resultCate);
    		
    		return rightCount / resultCate.size();
    	}
    	
    	/**
    	 * 依据正确类目文件和分类结果文件计算混淆矩阵并输出
    	 * @param rightCate 正确类目map 
    	 * @param resultCate 分类结果相应map
    	 */
    	private void computerConfusionMatrix(Map<String, String> rightCate,
    			Map<String, String> resultCate) {
    		
    		int[][] confusionMatrix = new int[20][20];
    		
    		//首先求出类目相应的数组索引
    		SortedSet<String> cateNames = new TreeSet<String>();
    		Set<Map.Entry<String, String>> rightCateSet = rightCate.entrySet();
    		for(Iterator<Map.Entry<String, String>> it = rightCateSet.iterator();it.hasNext();){
    			
    			Map.Entry<String, String> me = it.next();
    			cateNames.add(me.getValue());
    		}
    		cateNames.add("rec.sport.baseball");//防止数少一个类目
    		String[] cateNamesArray = cateNames.toArray(new String[0]);
    		Map<String,Integer> cateNamesToIndex = new TreeMap<String, Integer>();
    		
    		for(int i =0;i<cateNamesArray.length;i++){
    			cateNamesToIndex.put(cateNamesArray[i], i);
    		}
    		
    		for(Iterator<Map.Entry<String, String>> it = rightCateSet.iterator();it.hasNext();){
    			
    			Map.Entry<String, String> me = it.next();
    			confusionMatrix[cateNamesToIndex.get(me.getValue())][cateNamesToIndex.get(resultCate.get(me.getKey()))]++;
    		}
    		
    		//输出混淆矩阵
    		double[] hangSum = new double[20];
    		System.out.print("      ");
    		
    		for(int i=0;i<20;i++){
    			System.out.printf("%-6d",i);
    		}
    		
    		System.out.println("准确率");
    		
    		for(int i =0;i<20;i++){
    			System.out.printf("%-6d",i);
    			for(int j = 0;j<20;j++){
    				System.out.printf("%-6d",confusionMatrix[i][j]);
    				hangSum[i] += confusionMatrix[i][j];
    			}
    			System.out.printf("%-6f
    ",confusionMatrix[i][i]/hangSum[i]);
    		}
    		System.out.println();
    		
    	}
    
    	/**
    	 * 从结果文件里读取Map
    	 * @param file 类目文件
    	 * @return Map<String,String> 由<文件名称,类目名>保存的map
    	 * @throws Exception 
    	 */
    	private Map<String, String> getMapFromResultFile(String file) throws Exception {
    
    		File crFile = new File(file);
    		FileReader crReader = new FileReader(crFile);
    		BufferedReader crBR = new BufferedReader(crReader);
    		Map<String,String> res = new TreeMap<String, String>();
    		String[] s;
    		String line;
    		while((line = crBR.readLine()) != null){
    			s = line.split(" ");
    			res.put(s[0], s[1]);
    		}
    		return res;
    	}
    
    	public static void main(String[] args) throws Exception {
    		
    		CreateTrainAndTestSample ctt = new CreateTrainAndTestSample();
    		NaiveBayesianClassifier nbClassifier = new NaiveBayesianClassifier();
    		
    		//依据包括非特征词的文档集生成仅仅包括特征词的文档集到processedSampleOnlySpecial文件夹下
    		ctt.filterSpecialWords();
    		
    		double[] accuracyOfEveryExp  = new double[10];
    		double accuracyAvg,sum = 0;
    		
    
    		for(int i =0;i<10;i++){//用交叉验证法做十次分类实验。对准确率取平均值	
    			
    			String TrainDir = "E:/DataMiningSample/TrainSample"+i;
    			String TestDir = "E:/DataMiningSample/TestSample"+i;
    			String classifyRightCate = "E:/DataMiningSample/classifyRightCate"+i+".txt";
    			String classifyResultFileNew = "E:/DataMiningSample/classifyResultNew"+i+".txt";
    		
    			ctt.createTestSample("E:/DataMiningSample/processedSampleOnlySpecial", 0.8, i, classifyRightCate);
    			
    			nbClassifier.doProcess(TrainDir, TestDir, classifyResultFileNew);
    		
    			accuracyOfEveryExp[i] = nbClassifier.computeAccuracy(classifyRightCate,classifyResultFileNew);
    			
    			System.out.println("The accuracy for Naive Bayesian Classifier in "+i+"th Exp is :" + accuracyOfEveryExp[i]);
    			
    		}
    		
    		for(int i =0;i<10;i++)
    			sum += accuracyOfEveryExp[i];
    		accuracyAvg = sum/10;
    		
    		System.out.println("The average accuracy for Naive Bayesian Classifier in all Exps is :" + accuracyAvg);
    		
    	}
    
    	
    	
    }
    

    7、实验结果与说明

    这里仅仅列出第一次运行的结果:


    这里使用的多项式模型是经过改进的计算方法:将多项式模型的类条件概率计算公式改进为 类条件概率P(tk|c)=(类c下单词tk在各个文档中出现过的次数之和+0.0001)/(类c下单词总数+训练样本中不重复特征词总数)。也就是说,分子中的平滑项由1改为0.0001(与上面代码中的取值一致),当tk没有出现时分子仅为0.0001,这样能更加精确地描述词的统计分布规律。(注意:上面的代码中,分母的第二项实际使用的是训练样本的单词总数 totalWordsNum。)

    8、算法改进

    为了进一步提高朴素贝叶斯算法的分类能够进行例如以下改进:

    1、优化特征词选取的方法,如方法三,或者更优方法

    2、改进多项式模型的类条件概率的计算公式(上面已经实现)


  • 相关阅读:
    【分享】管理的最高境界是简单
    建立市场化风险评估机制推进地方政府信用评级建设
    手游-神雕侠侣 85侠客纪攻略(已通关)
    使用git的分支功能实现定制功能摘取与组合的想法
    组内正则培训记录
    组内Linq培训记录
    一次代码重构记录
    git代码库误操作还原记录
    关于代码重构的开始
    如何管理高手、大牛?
  • 原文地址:https://www.cnblogs.com/cxchanpin/p/7140468.html
Copyright © 2011-2022 走看看