I have recently been working on a big course project: building an information retrieval platform. It uses Bayesian classification, for which I referred to the technical blog of 洞庭散人:
http://www.cnblogs.com/phinecos/archive/2008/10/21/1316044.html
However, his algorithm runs slowly: its IO operations are far too frequent, and some of them can be avoided entirely. Below is my implementation of the Bayes classification algorithm.
The word segmenter used is SharpICTCLAS by 吕震宇 of Hebei Polytechnic University. It has no Lucene interface, so I implemented the Analyzer and Tokenizer classes myself, as follows.
ICTCLASAnalyzer
using System;
using System.Collections.Generic;
using System.Text;
using System.IO;
using Lucene.Net.Analysis;
using Lucene.Net.Analysis.Standard;

namespace Bayes
{
    class ICTCLASAnalyzer : Analyzer
    {
        // Up to 400 Chinese/English stop words, loaded from data\stopwords.txt (one word per line).
        public static readonly System.String[] CHINESE_ENGLISH_STOP_WORDS = new string[400];
        public string NoisePath = Environment.CurrentDirectory + "\\data\\stopwords.txt";

        public ICTCLASAnalyzer()
        {
            using (StreamReader reader = new StreamReader(NoisePath, System.Text.Encoding.Default))
            {
                string noise = reader.ReadLine();
                int i = 0;
                while (!string.IsNullOrEmpty(noise) && i < 400)
                {
                    CHINESE_ENGLISH_STOP_WORDS[i] = noise;
                    noise = reader.ReadLine();
                    i++;
                }
            }
        }

        /// <summary>
        /// Constructs an ICTCLASTokenizer filtered by a StandardFilter,
        /// a LowerCaseFilter and a StopFilter.
        /// </summary>
        public override TokenStream TokenStream(System.String fieldName, System.IO.TextReader reader)
        {
            TokenStream result = new ICTCLASTokenizer(reader);
            result = new StandardFilter(result);
            result = new LowerCaseFilter(result);
            result = new StopFilter(result, CHINESE_ENGLISH_STOP_WORDS);
            return result;
        }
    }
}
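The analyzer above expects data\stopwords.txt to hold one stop word per line, read with the system default encoding, up to 400 entries. For illustration only (these sample entries are mine, not the author's actual list), the file would look like this:

的
了
是
and
the
of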
ICTCLASTokenizer
using System;
using System.Collections.Generic;
using System.Text;
using Lucene.Net.Analysis;
using Lucene.Net.Documents;
using Lucene.Net.Analysis.Standard;
using System.IO;
using SharpICTCLAS;

namespace Bayes
{
    class ICTCLASTokenizer : Tokenizer
    {
        int nKind = 1;                  // number of segmentation results to keep
        List<WordResult[]> result;      // segmentation result returned by SharpICTCLAS
        int startIndex = 0;
        int endIndex = 0;
        int i = 1;                      // starts at 1: index 0 holds SharpICTCLAS's sentence-begin marker

        /// <summary>
        /// The sentence to be segmented.
        /// </summary>
        private string sentence;

        /// <summary>
        /// Constructs a tokenizer for this Reader.
        /// </summary>
        public ICTCLASTokenizer(System.IO.TextReader reader)
        {
            this.input = reader;
            sentence = input.ReadToEnd();
            sentence = sentence.Replace("\r\n", "");
            string DictPath = Path.Combine(Environment.CurrentDirectory, "Data") + Path.DirectorySeparatorChar;
            //Console.WriteLine("Initializing the dictionaries, please wait...");
            WordSegment wordSegment = new WordSegment();
            wordSegment.InitWordSegment(DictPath);
            result = wordSegment.Segment(sentence, nKind);
        }

        /// <summary>
        /// Returns the next token in the stream, or null when the stream is exhausted.
        /// The first and last entries of the result are skipped because SharpICTCLAS
        /// places sentence begin/end markers there.
        /// </summary>
        public override Token Next()
        {
            Token token = null;
            while (i < result[0].Length - 1)
            {
                string word = result[0][i].sWord;
                endIndex = startIndex + word.Length - 1;
                token = new Token(word, startIndex, endIndex);
                startIndex = endIndex + 1;
                i++;
                return token;
            }
            return null;
        }
    }
}
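A quick way to check that the analysis chain works is to feed one sentence through the Analyzer and print each token; this is essentially what the ChineseSpliter class further down does, only without joining the terms into one string. The following is just a small sketch of mine: the helper class, its name and the sample sentence are illustrations, and it assumes the SharpICTCLAS dictionaries sit under the Data directory exactly as the Tokenizer's constructor expects.

using System;
using System.IO;
using Lucene.Net.Analysis;

namespace Bayes
{
    class AnalyzerDemo
    {
        // Prints every token the analysis chain produces for the given text.
        public static void PrintTokens(string text)
        {
            Analyzer analyzer = new ICTCLASAnalyzer();
            // The field name is irrelevant for tokenization, so an empty string is fine.
            TokenStream ts = analyzer.TokenStream("", new StringReader(text));
            Lucene.Net.Analysis.Token token;
            while ((token = ts.Next()) != null)
            {
                Console.WriteLine(token.TermText());
            }
        }
    }
}

Calling AnalyzerDemo.PrintTokens("贝叶斯分类是一种简单而有效的文本分类方法") should print the segmented words one per line, with stop words already removed.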
Now for my implementation. It is organized into five classes: ChineseSpliter for word segmentation, ClassifyResult for holding a classification result, MemoryTrainingDataManager for managing the IO, FastNaiveBayesClassification for the Bayes algorithm itself, plus a StoreClass structure that caches one category's documents in memory. The difference from 洞庭散人's version is that my functions for the prior probability, the conditional probability and the joint probability all live in a single class instead of being spread over several classes, precisely so that unnecessary IO operations can be avoided.
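For reference, this is the score the classifier below actually computes for each category c. It is simply a reading of the three probability functions in FastNaiveBayesClassification: N is the total number of training documents, N_c the number of documents in category c, N_{x_i,c} the number of documents in c that contain term x_i, |C| the number of categories, and ZoomFactor the constant (10) multiplied in once per term so that the product of many small floats does not underflow to zero:

\mathrm{score}(c) = \frac{N_c}{N} \prod_{i=1}^{n} \left( \frac{N_{x_i,c} + 1}{N_c + |C|} \cdot \mathrm{ZoomFactor} \right)

The +1 in the numerator is add-one (Laplace) smoothing, so a term that never appears in a category does not zero out the whole product.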
ClassifyResult
using System;
using System.Collections.Generic;
using System.Text;

namespace Bayes
{
    // Holds one category name together with the score the classifier assigned to it.
    class ClassifyResult
    {
        public string className;
        public float score;

        public ClassifyResult()
        {
            className = "";
            score = 0;
        }
    }
}
ChineseSpliter
using System;
using System.Collections.Generic;
using System.Text;
using System.IO;
using Lucene.Net.Analysis;

namespace Bayes
{
    class ChineseSpliter
    {
        // Segments the text and returns the terms joined by splitToken, e.g. "term1@@@term2@@@term3".
        public string Split(string text, string splitToken)
        {
            StringBuilder sb = new StringBuilder();
            Analyzer an = new ICTCLASAnalyzer();
            //TokenStream ts = an.ReusableTokenStream("", new StringReader(text));
            TokenStream ts = an.TokenStream("", new StringReader(text));
            Lucene.Net.Analysis.Token token;
            while ((token = ts.Next()) != null)
            {
                sb.Append(splitToken + token.TermText());
            }
            // Strip the leading splitToken: the loop prepends one before every term, including the first.
            return sb.ToString().Substring(splitToken.Length);
        }

        // Splits the joined string back into an array of terms.
        public string[] GetTerms(string result, string spliter)
        {
            string[] terms = result.Split(new string[] { spliter }, StringSplitOptions.RemoveEmptyEntries);
            return terms;
        }
    }
}
MemoryTrainingDataManager
using System;
using System.Collections.Generic;
using System.Text;
using System.IO;

namespace Bayes
{
    class MemoryTrainingDataManager
    {
        // Call GetClassifications() to read the category sub-directories from disk into txtClassifications,
        // then GetTotalFileCount() to count the training documents and fill totalFileCount.
        public String[] txtClassifications;     // the category directories of the training corpus
        private static String defaultPath = "F:\\TrainingSet";
        public int totalFileCount;

        public void GetClassifications()
        {
            this.txtClassifications = Directory.GetDirectories(defaultPath);
        }

        public int GetSubClassFileCount(string subclass)
        {
            string[] paths = Directory.GetFiles(subclass);
            return paths.Length;
        }

        public void GetTotalFileCount()
        {
            int count = 0;
            for (int i = 0; i < txtClassifications.Length; i++)
            {
                count += GetSubClassFileCount(txtClassifications[i]);
            }
            totalFileCount = count;
        }

        public string GetText(string filePath)
        {
            StreamReader sr = new StreamReader(filePath, Encoding.Default);
            string text = sr.ReadToEnd();
            sr.Close();
            return text;
        }

        // Loads every document of one category into memory, so that later keyword counting
        // needs no further disk access.
        public void SetMainMemoryStructure(ref StoreClass sc, string subclass)
        {
            string[] paths = Directory.GetFiles(subclass);
            sc.classificationName = subclass;
            sc.classificationCount = paths.Length;
            sc.strFileContentList = new string[sc.classificationCount];
            for (int k = 0; k < paths.Length; k++)
            {
                sc.strFileContentList[k] = GetText(paths[k]);
            }
        }

        // Number of documents of the cached category that contain the given key.
        public int GetKeyCountOfSubClass(string key, ref StoreClass sc)
        {
            int count = 0;
            for (int i = 0; i < sc.classificationCount; i++)
            {
                if (sc.strFileContentList[i].Contains(key))
                {
                    count++;
                }
            }
            return count;
        }
    }
}
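StoreClass itself is not shown in these listings. Judging from how SetMainMemoryStructure and GetKeyCountOfSubClass use it, it is just an in-memory holder for one category; a minimal sketch along those lines (the field names are taken from the calls above, everything else is my assumption) would be:

using System;

namespace Bayes
{
    // Sketch of the StoreClass referenced above: one category's directory name, its document count,
    // and the full text of each of its documents kept in memory.
    class StoreClass
    {
        public string classificationName;
        public int classificationCount;
        public string[] strFileContentList;
    }
}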
FastNaiveBayesClassification
using System;
using System.Collections.Generic;
using System.Text;

namespace Bayes
{
    class FastNaiveBayesClassification
    {
        public MemoryTrainingDataManager mtdm = new MemoryTrainingDataManager();
        private ChineseSpliter spliter = new ChineseSpliter();
        // Each per-term conditional probability is multiplied by this constant
        // so that the product of many small floats does not underflow to 0.
        private static float ZoomFactor = 10;

        public FastNaiveBayesClassification()
        {
            mtdm.GetClassifications();
            mtdm.GetTotalFileCount();
        }

        /// <summary>
        /// Prior probability P(c). Nc is the number of documents in category c, N the total number of documents.
        /// </summary>
        public float CalculatePriorProbability(float Nc, float N)
        {
            return Nc / N;
        }

        /// <summary>
        /// Conditional probability P(x|c) with add-one (Laplace) smoothing.
        /// </summary>
        /// <param name="NxC">number of documents in the category that contain the term</param>
        /// <param name="Nc">total number of documents in the category</param>
        public float CalculateConditionalProbability(float NxC, float Nc)
        {
            float M = 0F;
            return (NxC + 1) / (Nc + M + mtdm.txtClassifications.Length);
        }

        /// <summary>
        /// Joint score: P(c) multiplied by the (zoomed) conditional probability of every term.
        /// </summary>
        public float CalculateJointProbability(float[] NxC, float Nc, float N)
        {
            float ret = 1;
            for (int i = 0; i < NxC.Length; i++)
            {
                ret *= CalculateConditionalProbability(NxC[i], Nc) * ZoomFactor;
            }
            ret = ret * CalculatePriorProbability(Nc, N);
            return ret;
        }

        public string[] SplitTerms(string text)
        {
            string result = spliter.Split(text, "@@@");
            string[] terms = spliter.GetTerms(result, "@@@");
            return terms;
        }

        public ClassifyResult Classify(string text)
        {
            int end = mtdm.txtClassifications.Length;
            ClassifyResult[] results = new ClassifyResult[end];
            for (int i = 0; i < end; i++)
            {
                results[i] = new ClassifyResult();
            }
            string[] terms = SplitTerms(text);
            float N = mtdm.totalFileCount;
            for (int i = 0; i < end; i++)
            {
                // Load the whole category into memory once, then count keywords without further IO.
                StoreClass sc = new StoreClass();
                mtdm.SetMainMemoryStructure(ref sc, mtdm.txtClassifications[i]);
                float Nc = sc.classificationCount;
                float[] Nxc = new float[terms.Length];
                for (int k = 0; k < terms.Length; k++)
                {
                    Nxc[k] = mtdm.GetKeyCountOfSubClass(terms[k], ref sc);
                }
                results[i].score = CalculateJointProbability(Nxc, Nc, N);
                results[i].className = sc.classificationName;
                Console.WriteLine("category {0}, score {1}", results[i].className, results[i].score);
            }
            // Selection sort, descending by score, so the best category ends up at index 0.
            for (int m = 0; m < results.Length - 1; m++)
            {
                int k = m;
                for (int n = m + 1; n < results.Length; n++)
                {
                    if (results[n].score > results[k].score)
                    {
                        k = n;
                    }
                }
                if (k != m)
                {
                    ClassifyResult temp = new ClassifyResult();
                    temp.score = results[k].score;
                    temp.className = results[k].className;
                    results[k].className = results[m].className;
                    results[k].score = results[m].score;
                    results[m].score = temp.score;
                    results[m].className = temp.className;
                }
            }
            return results[0];
        }
    }
}
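To wrap up, here is a minimal sketch of how the classifier is driven end to end. It assumes the layout the code above expects: the SharpICTCLAS dictionaries and data\stopwords.txt under the program's working directory, and training documents arranged as F:\TrainingSet\&lt;category&gt;\*.txt, one plain-text document per file. The Program class and the sample text are my own illustration, not part of the classifier.

using System;

namespace Bayes
{
    class Program
    {
        static void Main(string[] args)
        {
            // Reads the category directories and the total document count once, up front.
            FastNaiveBayesClassification classifier = new FastNaiveBayesClassification();

            // Any piece of text to classify; the string here is only a placeholder.
            string text = "……待分类的文章内容……";

            // Classify() prints every category's score and returns the best one.
            ClassifyResult best = classifier.Classify(text);
            Console.WriteLine("best category: {0}, score: {1}", best.className, best.score);
        }
    }
}

Because Classify loads each category's documents into memory once per call and does all keyword counting against those in-memory strings, the only disk reads per classification are one pass over the training files, which is the IO saving this implementation aims at.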