zoukankan      html  css  js  c++  java
  • 贝叶斯文本分类 java实现

      昨天实现了一个基于贝叶斯定理的的文本分类,贝叶斯定理假设特征属性(在文本中就是词汇)对待分类项的影响都是独立的,道理比较简单,在中文分类系统中,分类的准确性与分词系统的好坏有很大的关系,这段代码也是试验不同分词系统才顺手写的一个。 

      试验数据用的sogou实验室的文本分类样本,一共分为9个类别,每个类别文件夹下大约有2000篇文章。由于文本数据量确实较大,所以得想办法让每次训练的结果都能保存起来,以便于下次直接使用,我这里使用序列化的方式保存在硬盘。 


      训练代码如下:  

      1 /**
      2  * 训练器
      3  * 
      4  * <a href="http://my.oschina.net/arthor" target="_blank" rel="nofollow">@author</a>  duyf
      5  * 
      6  */
      7 class Train implements Serializable {
      8 
      9     /**
     10      * 
     11      */
     12     private static final long serialVersionUID = 1L;
     13 
     14     public final static String SERIALIZABLE_PATH = "D:\\workspace\\Test\\SogouC.mini\\Sample\\Train.ser";
     15     // 训练集的位置
     16     private String trainPath = "D:\\workspace\\Test\\SogouC.mini\\Sample";
     17 
     18     // 类别序号对应的实际名称
     19     private Map<String, String> classMap = new HashMap<String, String>();
     20 
     21     // 类别对应的txt文本数
     22     private Map<String, Integer> classP = new ConcurrentHashMap<String, Integer>();
     23 
     24     // 所有文本数
     25     private AtomicInteger actCount = new AtomicInteger(0);
     26 
     27     
     28 
     29     // 每个类别对应的词典和频数
     30     private Map<String, Map<String, Double>> classWordMap = new ConcurrentHashMap<String, Map<String, Double>>();
     31 
     32     // 分词器
     33     private transient Participle participle;
     34 
     35     private static Train trainInstance = new Train();
     36 
     37     public static Train getInstance() {
     38         trainInstance = new Train();
     39 
     40         // 读取序列化在硬盘的本类对象
     41         FileInputStream fis;
     42         try {
     43             File f = new File(SERIALIZABLE_PATH);
     44             if (f.length() != 0) {
     45                 fis = new FileInputStream(SERIALIZABLE_PATH);
     46                 ObjectInputStream oos = new ObjectInputStream(fis);
     47                 trainInstance = (Train) oos.readObject();
     48                 trainInstance.participle = new IkParticiple();
     49             } else {
     50                 trainInstance = new Train();
     51             }
     52         } catch (Exception e) {
     53             e.printStackTrace();
     54         }
     55 
     56         return trainInstance;
     57     }
     58 
     59     private Train() {
     60         this.participle = new IkParticiple();
     61     }
     62 
     63     public String readtxt(String path) {
     64         BufferedReader br = null;
     65         StringBuilder str = null;
     66         try {
     67             br = new BufferedReader(new FileReader(path));
     68 
     69             str = new StringBuilder();
     70 
     71             String r = br.readLine();
     72 
     73             while (r != null) {
     74                 str.append(r);
     75                 r = br.readLine();
     76 
     77             }
     78 
     79             return str.toString();
     80         } catch (IOException ex) {
     81             ex.printStackTrace();
     82         } finally {
     83             if (br != null) {
     84                 try {
     85                     br.close();
     86                 } catch (IOException e) {
     87                     e.printStackTrace();
     88                 }
     89             }
     90             str = null;
     91             br = null;
     92         }
     93 
     94         return "";
     95     }
     96 
     97     /**
     98      * 训练数据
     99      */
    100     public void realTrain() {
    101         // 初始化
    102         classMap = new HashMap<String, String>();
    103         classP = new HashMap<String, Integer>();
    104         actCount.set(0);
    105         classWordMap = new HashMap<String, Map<String, Double>>();
    106 
    107         // classMap.put("C000007", "汽车");
    108         classMap.put("C000008", "财经");
    109         classMap.put("C000010", "IT");
    110         classMap.put("C000013", "健康");
    111         classMap.put("C000014", "体育");
    112         classMap.put("C000016", "旅游");
    113         classMap.put("C000020", "教育");
    114         classMap.put("C000022", "招聘");
    115         classMap.put("C000023", "文化");
    116         classMap.put("C000024", "军事");
    117 
    118         // 计算各个类别的样本数
    119         Set<String> keySet = classMap.keySet();
    120 
    121         // 所有词汇的集合,是为了计算每个单词在多少篇文章中出现,用于后面计算df
    122         final Set<String> allWords = new HashSet<String>();
    123 
    124         // 存放每个类别的文件词汇内容
    125         final Map<String, List<String[]>> classContentMap = new ConcurrentHashMap<String, List<String[]>>();
    126 
    127         for (String classKey : keySet) {
    128 
    129             Participle participle = new IkParticiple();
    130             Map<String, Double> wordMap = new HashMap<String, Double>();
    131             File f = new File(trainPath + File.separator + classKey);
    132             File[] files = f.listFiles(new FileFilter() {
    133 
    134                 @Override
    135                 public boolean accept(File pathname) {
    136                     if (pathname.getName().endsWith(".txt")) {
    137                         return true;
    138                     }
    139                     return false;
    140                 }
    141 
    142             });
    143 
    144             // 存储每个类别的文件词汇向量
    145             List<String[]> fileContent = new ArrayList<String[]>();
    146             if (files != null) {
    147                 for (File txt : files) {
    148                     String content = readtxt(txt.getAbsolutePath());
    149                     // 分词
    150                     String[] word_arr = participle.participle(content, false);
    151                     fileContent.add(word_arr);
    152                     // 统计每个词出现的个数
    153                     for (String word : word_arr) {
    154                         if (wordMap.containsKey(word)) {
    155                             Double wordCount = wordMap.get(word);
    156                             wordMap.put(word, wordCount + 1);
    157                         } else {
    158                             wordMap.put(word, 1.0);
    159                         }
    160                         
    161                     }
    162                 }
    163             }
    164 
    165             // 每个类别对应的词典和频数
    166             classWordMap.put(classKey, wordMap);
    167 
    168             // 每个类别的文章数目
    169             classP.put(classKey, files.length);
    170             actCount.addAndGet(files.length);
    171             classContentMap.put(classKey, fileContent);
    172 
    173         }
    174 
    175         
    176 
    177         
    178 
    179         // 把训练好的训练器对象序列化到本地 (空间换时间)
    180         FileOutputStream fos;
    181         try {
    182             fos = new FileOutputStream(SERIALIZABLE_PATH);
    183             ObjectOutputStream oos = new ObjectOutputStream(fos);
    184             oos.writeObject(this);
    185         } catch (Exception e) {
    186             e.printStackTrace();
    187         }
    188 
    189     }
    190 
    191     /**
    192      * 分类
    193      * 
    194      * @param text
    195      * <a href="http://my.oschina.net/u/556800" target="_blank" rel="nofollow">@return</a>  返回各个类别的概率大小
    196      */
    197     public Map<String, Double> classify(String text) {
    198         // 分词,并且去重
    199         String[] text_words = participle.participle(text, false);
    200 
    201         Map<String, Double> frequencyOfType = new HashMap<String, Double>();
    202         Set<String> keySet = classMap.keySet();
    203         for (String classKey : keySet) {
    204             double typeOfThis = 1.0;
    205             Map<String, Double> wordMap = classWordMap.get(classKey);
    206             for (String word : text_words) {
    207                 Double wordCount = wordMap.get(word);
    208                 int articleCount = classP.get(classKey);
    209 
    210                 /*
    211                  * Double wordidf = idfMap.get(word); if(wordidf==null){
    212                  * wordidf=0.001; }else{ wordidf = Math.log(actCount / wordidf); }
    213                  */
    214 
    215                 // 假如这个词在类别下的所有文章中木有,那么给定个极小的值 不影响计算
    216                 double term_frequency = (wordCount == null) ? ((double) 1 / (articleCount + 1))
    217                         : (wordCount / articleCount);
    218 
    219                 // 文本在类别的概率 在这里按照特征向量独立统计,即概率=词汇1/文章数 * 词汇2/文章数 。。。
    220                 // 当double无限小的时候会归为0,为了避免 *10
    221 
    222                 typeOfThis = typeOfThis * term_frequency * 10;
    223                 typeOfThis = ((typeOfThis == 0.0) ? Double.MIN_VALUE
    224                         : typeOfThis);
    225                 // System.out.println(typeOfThis+" : "+term_frequency+" :
    226                 // "+actCount);
    227             }
    228 
    229             typeOfThis = ((typeOfThis == 1.0) ? 0.0 : typeOfThis);
    230 
    231             // 此类别文章出现的概率
    232             double classOfAll = classP.get(classKey) / actCount.doubleValue();
    233 
    234             // 根据贝叶斯公式 $(A|B)=S(B|A)*S(A)/S(B),由于$(B)是常数,在这里不做计算,不影响分类结果
    235             frequencyOfType.put(classKey, typeOfThis * classOfAll);
    236         }
    237 
    238         return frequencyOfType;
    239     }
    240 
    241     public void pringAll() {
    242         Set<Entry<String, Map<String, Double>>> classWordEntry = classWordMap
    243                 .entrySet();
    244         for (Entry<String, Map<String, Double>> ent : classWordEntry) {
    245             System.out.println("类别: " + ent.getKey());
    246             Map<String, Double> wordMap = ent.getValue();
    247             Set<Entry<String, Double>> wordMapSet = wordMap.entrySet();
    248             for (Entry<String, Double> wordEnt : wordMapSet) {
    249                 System.out.println(wordEnt.getKey() + ":" + wordEnt.getValue());
    250             }
    251         }
    252     }
    253 
    254     public Map<String, String> getClassMap() {
    255         return classMap;
    256     }
    257 
    258     public void setClassMap(Map<String, String> classMap) {
    259         this.classMap = classMap;
    260     }
    261 
    262 }

      在试验过程中,发觉某篇文章的分类不太准,某篇IT文章分到招聘类别下了,在仔细对比了训练数据后,发觉这是由于招聘类别每篇文章下面都带有“搜狗”的标志,而待分类的这篇IT文章里面充斥这搜狗这类词汇,结果招聘类下的概率比较大。由此想到,在除了做常规的贝叶斯计算时,需要把不同文本中出现次数多的词汇权重降低甚至删除(好比关键词搜索中的tf-idf),通俗点讲就是,在所有训练文本中某词汇(如的,地,得)出现的次数越多,这个词越不重要,比如IT文章中“软件”和“应用”这两个词汇,“应用”应该是很多文章类别下都有的,反而不太重要,但是“软件”这个词汇大多只出现在IT文章里,出现在大量文章的概率并不大。 我这里原本打算计算每个词的idf,然后给定一个阀值来判断是否需要纳入计算,但是由于词汇太多,计算量较大(等待结果时间较长),所以暂时注释掉了。 

    来源:http://my.oschina.net/duyunfei/blog/80283

  • 相关阅读:
    Longest Palindromic Substring
    PayPal MLSE job description
    Continuous Median
    Remove Duplicates From Linked List
    Valid IP Address
    Longest substring without duplication
    Largest range
    Subarray sort
    Multi String Search
    Suffix Trie Construction
  • 原文地址:https://www.cnblogs.com/94julia/p/3103115.html
Copyright © 2011-2022 走看看