zoukankan      html  css  js  c++  java
  • 简单的贝叶斯分类器的python实现

# -*- coding: utf-8 -*-
'''
Minimal multinomial naive Bayes text classifier.

Train it with one labelled document (a list of terms) per ``train`` call,
then ask ``classify`` for the most likely class of a new document.
``classify`` returns a ``(class_id, score)`` tuple whose second element is
the winning class's log-space score.

>>> c = Classy()
>>> c.train(['cpu', 'RAM', 'ALU', 'io', 'bridge', 'disk'], 'architecture')
True
>>> c.train(['monitor', 'mouse', 'keyboard', 'microphone', 'headphones'], 'input_devices')
True
>>> c.train(['desk', 'chair', 'cabinet', 'lamp'], 'office furniture')
True
>>> my_office = ['cpu', 'monitor', 'mouse', 'chair']
>>> c.classify(my_office)[0]
'input_devices'
'''
     26 
     27 from collections import Counter
     28 import math
     29 
     30 class ClassifierNotTrainedException(Exception):
     31     
     32     def __str__(self):
     33         return "Classifier is not trained."
     34 
     35 class Classy(object):
     36     
     37     def __init__(self):
     38         self.term_count_store = {}
     39         self.data = {
     40             'class_term_count': {},
     41             'beta_priors': {},
     42             'class_doc_count': {},
     43         }
     44         self.total_term_count = 0
     45         self.total_doc_count = 0
     46         
     47     def train(self, document_source, class_id):
     48     
     49         '''
     50         Trains the classifier.
     51         
     52         '''
     53         count = Counter(document_source)
     54         try:
     55             self.term_count_store[class_id]
     56         except KeyError:
     57             self.term_count_store[class_id] = {}
     58         for term in count:
     59             try:
     60                 self.term_count_store[class_id][term] += count[term]
     61             except KeyError:
     62                 self.term_count_store[class_id][term] = count[term]
     63         try:
     64             self.data['class_term_count'][class_id] += document_source.__len__()
     65         except KeyError:
     66             self.data['class_term_count'][class_id] = document_source.__len__()
     67         try:
     68             self.data['class_doc_count'][class_id] += 1
     69         except KeyError:
     70             self.data['class_doc_count'][class_id] = 1
     71         self.total_term_count += document_source.__len__()
     72         self.total_doc_count += 1
     73         self.compute_beta_priors()
     74         return True
     75         
     76     def classify(self, document_input):
     77         if not self.total_doc_count: raise ClassifierNotTrainedException()
     78         
     79         term_freq_matrix = Counter(document_input)
     80         arg_max_matrix = []
     81         for class_id in self.data['class_doc_count']:
     82             summation = 0
     83             for term in document_input:
     84                 try:
     85                     conditional_probability = (self.term_count_store[class_id][term] + 1)
     86                     conditional_probability = conditional_probability / (self.data['class_term_count'][class_id] + self.total_doc_count)
     87                     summation += term_freq_matrix[term] * math.log(conditional_probability)
     88                 except KeyError:
     89                     break
     90             arg_max = summation + self.data['beta_priors'][class_id]
     91             arg_max_matrix.insert(0, (class_id, arg_max))
     92         arg_max_matrix.sort(key=lambda x:x[1])
     93         return (arg_max_matrix[-1][0], arg_max_matrix[-1][1])
     94         
     95     def compute_beta_priors(self):
     96         if not self.total_doc_count: raise ClassifierNotTrainedException()
     97         
     98         for class_id in self.data['class_doc_count']:
     99             tmp = self.data['class_doc_count'][class_id] / self.total_doc_count
    100             self.data['beta_priors'][class_id] = math.log(tmp)
  • 相关阅读:
    Myeclipse新建 配置Hibernate
    MyEclipse从数据库表反向生成实体类之Hibernate方式(反向工程)
    简单使用JSON,JavaScript读取JSON文本(三)
    简单使用JSON,通过JSON 字符串来创建对象(二)
    简单使用JSON,JavaScript中创建 JSON 对象(一)
    【某deed网测题】D
    【题解】ACMICPC 2015 final L 哈弗曼树
    TC SRM 659 DIV1 500pt 插头DP
    BC#40D GCD值统计
    MS电面3轮
  • 原文地址:https://www.cnblogs.com/hhh5460/p/4319427.html
Copyright © 2011-2022 走看看