zoukankan      html  css  js  c++  java
  • 统计学习方法 | 朴素贝叶斯 | python实现

    本文实现了李航教授的《统计学习方法》一书中第4章朴素贝叶斯法中的算法,包括算法4.1(朴素贝叶斯算法)和在此基础上改进的贝叶斯估计。文末整理了在实现算法过程中遇到问题记录的笔记。

    朴素贝叶斯法


    首先训练朴素贝叶斯模型,对应算法4.1(1),分别计算先验概率及条件概率,分别存在字典priorP和condP中(初始化函数中定义)。其中,计算一个向量各元素频率的操作反复出现,定义为count函数。

    # 初始化函数定义了先验概率和条件概率字典,并训练模型
    def __init__(self, data, label):
        self.priorP = {}
        self.condP = {}
        self.train(data, label)
    

    count函数,输入一个向量,输出一个字典,包含各元素频率

    # 给一个向量,返回字典,包含不同元素的频率。可以引用collections中的Counter函数来实现
    # 这个函数可以改进,懒得弄了
    def count(self, vec):
        np.array(vec)
        keys = np.unique(vec)
        p = {}
        for key in keys:
            n = np.sum(np.isin(vec, key) + 0)  # 加0可以使布尔向量变为0-1向量
            p[key] = n/len(vec)  # 计算频率
        return p
    

    训练函数,关于condP的保存下面有详细说明

    def train(self, data, label):
        m, n = np.shape(data)
        # 计算先验概率
        self.priorP = self.count(label)
        print("priorP:", self.priorP)
        # 计算条件概率
        classes = np.unique(label)
        for c in classes:
            subset = [data[i] for i in range(m) if label[i] == c]  # 取Y=ck的子集
            for j in range(n):  # 遍历每一个特征,分别求条件概率
                self.condP[str(c)+" "+str(j)] = self.count([x[j] for x in subset])
        print("condP:", self.condP)
    

    对于条件概率condP的保存,将每个特征关于Y=ck的条件概率都存为一个字典,再存入字典condP中,key设为 “ck j” ,其中 ck 为 Y 的类别,j 表示第 j 个特征。训练例4.1得到的模型如下, condp中的'-1 0' 项即表示Y=-1条件下,P(x0=1)=0.5, P(x0=2)=0.333, P(x0=3)=0.166. 为了显示方便,只给出了小数点后3位。

    priorP: {-1: 0.4, 1: 0.6}
    condP: {'-1 0': {1: 0.5, 2: 0.333, 3: 0.166}, 
            '-1 1': {'L': 0.166, 'M': 0.333, 'S': 0.5}, 
            '1 0': {1: 0.222, 2: 0.333, 3: 0.444}, 
            '1 1': {'L': 0.444, 'M': 0.444, 'S': 0.111}}
    

    训练之后,对给定X进行预测,结果保存在字典preP中

    def predict(self, x):
        preP = {}
        for c in self.priorP.keys():
            preP[c] = self.priorP[c]
            for i, features in enumerate(x):
                preP[c] *= self.condP[str(c)+" "+str(i)][features]
        print("probability: ", preP)
        print("prediction: ", max(preP, key=preP.get))
    

    结果:

    probability:  {-1: 0.06666666666666667, 1: 0.02222222222222222}
    prediction:  -1
    

    贝叶斯分类

    考虑到概率可能为0,在随机变量各个取值的频数上赋予一个正数lamda,lamda为0时即为极大似然估计;lamda取1时称为拉普拉斯平滑。

    先验概率变为(a)

    条件概率变为(b)

    在之前的算法基础上改进,添加一个字典变量rangeOfFeature来保存每个特征的取值个数,定义在初始化函数中。

    self.rangeOfFeature = {}  # 保存每个特征的取值个数
    

    公式(a)和(b)形式相同,将 K 或Sj 作为参数传入count函数,lamda缺省为0:

    def count(self, vec, classNum, lamda=0):
        keys = set(vec)
        p = {}
        for key in keys:
            n = np.sum(np.isin(vec, key) + 0)
            p[key] = (n+lamda)/(len(vec)+classNum*lamda) 
        return p
    

    训练函数变为

    def train(self, data, label, lamda=0):
        m, n = np.shape(data)
        # 计算rangeOfFeature
        for j in range(n):
            self.rangeOfFeature[j] = len(set([x[j] for x in data]))
        classes = set(label)
        # 计算先验概率
        self.priorP = self.count(label, len(classes), lamda)
        print("priorP:", self.priorP)
        # 计算条件概率
        for c in classes:
            subset = [data[i] for i in range(m) if label[i] == c]
            for j in range(n):
                self.condP[str(c)+" "+str(j)] = self.count([x[j] for x in subset], self.rangeOfFeature[j], lamda)
        print("condP:", self.condP)
    

    其他不变,对之前的实例运行

    bayes = Bayes(dataSet, labels, 1)
    bayes.predict([2, "S"])
    

    结果如下:

    probability:  {1: 0.0326797385620915, -1: 0.06100217864923746}
    prediction:  -1
    
    笔记
    • 贝叶斯定理

      img

    • max(dict, key = dict.get) 获得字典dictvalue最大的值的键

    • bool型列表转换为0-1

      • 变量后加0
      • booldata.astype(int) 类型转换
    • for i, value in enumerate(['A', 'B', 'C']): 把list变成索引-元素对用enumerate函数,这样就可以在for循环中同时迭代索引和元素本身

    • [a[j] for a in data] 取data第 j 列

    • np.isin(a,b) 判断a中每个元素是否在b中,返回与a形状相同的bool数组

    • np.unique(a) 去重并排序

    代码下载:3-Bayes.py

  • 相关阅读:
    VC内存泄露检查工具:Visual Leak Detector
    ArcGIS Server 开发系列(五)自定义 Toolbar 工具2 (转载于Flyingis)
    Arcgis Server系列 安装与配置
    ArcGIS Server 开发系列(四)ArcGIS Server data sources 开发 (转载于Flyingis)
    ArcGIS Server 开发系列(五)自定义 Toolbar 工具 (转载于Flyingis)
    ArcMap的地图缓存MapCache
    C# 字符串 合并时 + 和 stringbulilder 的区别是什么?
    ArcGIS Server 体系结构
    ArcGIS Server 开发系列(三)漫游 Graphics data sources (转载于Flyingis)
    ArcGIS Server .Net Web ADF体系结构
  • 原文地址:https://www.cnblogs.com/hellozy/p/11243567.html
Copyright © 2011-2022 走看看