zoukankan      html  css  js  c++  java
  • 《统计学习方法》——朴素贝叶斯代码实现

    朴素贝叶斯分类原理

    对于给定的训练数据集,首先基于特征条件独立假设学习输入/输出的联合概率分布;然后基于此模型,对给定的输入(x),利用贝叶斯定理求出后验概率最大的输出(y)

    特征独立性假设:在利用贝叶斯定理进行预测时,我们需要求解条件概率(P(x|y_k)=P(x_1,x_2,...,x_n|y_k)P(x|y_k)=P(x_1,x_2,...,x_n|y_k)),它的参数规模是指数数量级别的,假设第i维特征可取值的个数有(T_i)个,类别取值个数为k个,那么参数个数为:(kprod_{i=1}^nT_i)。这显然不可行,所以朴素贝叶斯算法对条件概率分布作出了独立性的假设,实际上是为了简化计算。

    import numpy as np
    import math
    from sklearn.datasets import load_iris 
    from sklearn.model_selection import train_test_split
    from collections import Counter
    

    从sklearn数据集中加载鸢尾花分类数据集

    iris = load_iris()
    X, Y = iris.data, iris.target
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3)
    print('X_train[0]: {}'.format(X_train[0]))
    print('Y_train[0]: {}'.format(Y_train[0]))
    # 查看训练集各个类别的数量
    for l in set(Y_train):
        print('label: %s ,count: %d' % (l, len(Y_train[Y_train==l])))
    

    代码输出:

    X_train[0]: [5.2 3.5 1.5 0.2]
    Y_train[0]: 0
    label: 0 ,count: 35
    label: 1 ,count: 32
    label: 2 ,count: 38
    

    高斯模型的朴素贝叶斯:

    对于取值是连续型的特征变量,用离散型特征的求解方法时会有很多特征取值的条件概率为0,所以我们使用高斯模型的朴素贝叶斯,它假设每一维特征都服从高斯分布。即:

    [P(x_i | y_k)=frac{1}{sqrt{2pi}sigma_{y_k,i}}exp(-frac{(x_i-mu_{y_k,i})^2}{2sigma^2_{y_k,i}}) ]

    (mu_{y_k,i})是分类为(y_k)的样本中,第(i)维特征取值的均值;(sigma_{y_k,i}^2)为其方差

    class GaussianNaiveBayes:
        def __init__(self):
            self.parameters = {}
            self.prior = {}
            
        # 训练过程就是求解先验概率和高斯分布参数的过程
        # X:(样本数,特征维度) Y:(样本数,)
        def fit(self, X, Y):
            self._get_prior(Y)  # 计算先验概率
            labels = set(Y)
            for label in labels:
                samples = X[Y==label]
                # 计算高斯分布的参数:均值和标准差
                means = np.mean(samples, axis=0)
                stds = np.std(samples, axis=0)
     
                self.parameters[label] = {
                    'means': means,
                    'stds': stds
                }
             
        # x:单个样本
        def predict(self, x):
            probs = sorted(self._cal_likelihoods(x).items(), key=lambda x:x[-1])  # 按概率从小到大排序
            return probs[-1][0]
        
        # 计算模型在测试集的准确率
        # X_test:(测试集样本个数,特征维度)
        def evaluate(self, X_test, Y_test):
            true_pred = 0
            for i, x in enumerate(X_test):
                label = self.predict(x)
                if label == Y_test[i]:
                    true_pred += 1
        
            return true_pred / len(X_test)
                
        # 计算每个类别的先验概率
        def _get_prior(self, Y):
            cnt = Counter(Y)
            for label, count in cnt.items():
                self.prior[label] = count / len(Y)
            
        # 高斯分布
        def _gaussian(self, x, mean, std):
            exponent = math.exp(-(math.pow(x - mean, 2)/(2 * math.pow(std, 2))))
            return (1 / (math.sqrt(2 * math.pi) * std)) * exponent
        
        # 计算样本x属于每个类别的似然概率
        def _cal_likelihoods(self, x):
            likelihoods = {}
            for label, params in self.parameters.items():
                means = params['means']
                stds = params['stds']
                prob = self.prior[label]
                # 计算每个特征的条件概率,P(xi|yk)
                for i in range(len(means)):
                    prob *= self._gaussian(x[i], means[i], stds[i])
                likelihoods[label] = prob
            return likelihoods
    

    在测试集上评估分类器:

    gussian_nb = GaussianNaiveBayes()
    gussian_nb.fit(X_train, Y_train)
    print('样本[4.4, 3.2, 1.3, 0.2]的预测结果: %d' % gussian_nb.predict([4.4, 3.2, 1.3, 0.2]))
    print('测试集的准确率: %f' % gussian_nb.evaluate(X_test, Y_test))
    

    代码输出:

    样本[4.4, 3.2, 1.3, 0.2]的预测结果: 0
    测试集的准确率: 0.955556
    

    与scikit-learn的实现对比

    from sklearn.naive_bayes import GaussianNB
    
    clf = GaussianNB()
    clf.fit(X_train, Y_train)
    print('(sklearn)样本[4.4, 3.2, 1.3, 0.2]的预测结果: %d' % clf.predict([[4.4,  3.2,  1.3,  0.2]])[0])
    print('(sklearn)测试集的准确率: %f' % clf.score(X_test, Y_test))
    

    代码输出:

    (sklearn)样本[4.4, 3.2, 1.3, 0.2]的预测结果: 0
    (sklearn)测试集的准确率: 0.955556
  • 相关阅读:
    Java日志体系(1) —— 那些年那些事,那些日志的历史
    直播工作原理
    【PAT乙级 】1003. 我要通过!
    [牛客网刷题]被3整除
    [牛客网刷题]牛牛找工作
    Mybatis的简单分析
    数位DP
    正则表达式
    能量球
    从此,我们相伴,不离不弃
  • 原文地址:https://www.cnblogs.com/irvingluo/p/14460406.html
Copyright © 2011-2022 走看看