zoukankan      html  css  js  c++  java
  • 机器学习实战---朴素贝叶斯算法使用K折交叉验证

    一:训练模型、实现预测函数

    import numpy as np
    import re
    import random
    
    def loadDataSet():
        DataSetList=[]  #全部数据集
        classVec = []    #标签值
    
        #获取数据
        SpamPath = "email/spam/{}.txt"  #获取文件路径
        HamPath = "email/ham/{}.txt"
    
        for i in range(1,26): #两个路径各有25个文件
            documentData = open(SpamPath.format(i),'r', encoding='ISO-8859-1').read()
            #使用正则进行分割,除了空格、还有标点都可以用于分割
            WordVec = re.split(r"W+",documentData)  #W*表示匹配多个非字母、数字、下划线的字符
            DataSetList.append([item for item in WordVec if len(item)>0])
            classVec.append(1)
            documentData = open(HamPath.format(i), 'r', encoding='ISO-8859-1').read()
            # 使用正则进行分割,除了空格、还有标点都可以用于分割
            WordVec = re.split(r"W+", documentData)  # W*表示匹配多个非字母、数字、下划线的字符
            DataSetList.append([item for item in WordVec if len(item)>0])
            classVec.append(0)
        return DataSetList,classVec
    
    def createVocabList(dataSet):   #创建词汇列表 将多个文档单词归为一个
        VocabSet = set([]) #使用集合,方便去重
        for document in dataSet:
            VocabSet = VocabSet | set(document)   #保持同类型
        VocabList = list(VocabSet)
        VocabList.sort()
        return VocabList
    
    def WordSet2Vec(VocaList,inputSet):  #根据我们上面得到的全部词汇列表,将我们输入得inputSet文档向量化
        returnVec = [0]*len(VocaList)
        for word in inputSet:
            if word in VocaList:
                returnVec[VocaList.index(word)] = 1
            else:
                print("the word: %s is not in my Vocabulary!"%word)
        return returnVec
    
    def trainNB0(trainMatrix,trainCategory):    #训练朴素贝叶斯模型 传入numpy数组类型 trainMatrix所有文档词汇向量矩阵(m*n矩阵 m个文档,每个文档都是n列,代表词汇向量大小),trainCategory每篇文档得标签
        numTrainDoc = len(trainMatrix)  #文档数量
        pC1 = np.sum(trainCategory)/numTrainDoc    #p(c1)得概率   p(c0)=1-p(c1)
        wordVecNum = len(trainMatrix[0])    #因为每个文档转换为词汇向量后都是一样得长度
    
        #初始化p1,p0概率向量---改进为拉普拉斯平滑
        p1VecNum,p0VecNum = np.ones(wordVecNum),np.ones(wordVecNum)
        p1Sum,p0Sum = 2.0,2.0   #N*1    N表示分类数
    
        #循环每一个文档
        for i in range(numTrainDoc):
            if trainCategory[i] == 1: #侮辱性文档
                p1VecNum += trainMatrix[i]  #统计侮辱性文档中,每个单词出现频率
                p1Sum += np.sum(trainMatrix[i]) #统计侮辱性文档中出现得全部单词数   每个单词出现概率就是单词出现频率/全部单词
            else:   #正常文档
                p0VecNum += trainMatrix[i]
                p0Sum += np.sum(trainMatrix[i])
    
        p1Vect = np.log(p1VecNum / p1Sum)   #统计各类文档中的单词出现频率
        p0Vect = np.log(p0VecNum / p0Sum)   #使用对数避免下溢
    
        return p1Vect,p0Vect,pC1
    
    def classifyNB(testVec,p0Vec,p1Vec,pC1):
        p1 = sum(testVec*p1Vec)+np.log(pC1) #使用对数之后变为求和
        p0 = sum(testVec*p0Vec)+np.log(1-pC1)
        if p1 > p0:
            return 1
        else:
            return 0

    二:实现K折交叉验证法---k=5

    def OneCrossValidate(trainSet,trainCls,testSet,testCls):
        #训练模型
        p1Vect,p0Vect,pC1 = trainNB0(np.array(trainSet),np.array(trainCls))
        err_count = 0
        #验证集进行测试
        for i in range(10):
            c = classifyNB(np.array(testSet[i]),p0Vect,p1Vect,pC1)
            if c != testCls[i]:
                err_count += 1
    
        return err_count/10
    
    
    def KCrossValidate(trainMat,trainClassVec):    #K折交叉验证 5
        randIdx = list(range(50))
        random.shuffle(randIdx)
        error_radio = 0.0
    
        for i in range(5):  #5次
            rdInd = randIdx #随机索引
            #选取训练集、验证集索引
            trainSet = []
            trainCls = []
            testSet = []
            testCls = []
            testSetIdx = set(randIdx[10*i:10*i+10])  # 训练集
            trainSetIdx = set(rdInd)-testSetIdx  # 验证集
            #选取训练集、验证集数据
            for idx in trainSetIdx:
                trainSet.append(trainMat[idx])
                trainCls.append(trainClassVec[idx])
    
            for idx in testSetIdx:
                testSet.append(trainMat[idx])
                testCls.append(trainClassVec[idx])
            print(OneCrossValidate(trainSet,trainCls,testSet,testCls))
            error_radio += OneCrossValidate(trainSet,trainCls,testSet,testCls)
    
        return error_radio/5
    
    DocData,classVec = loadDataSet()
    voclist = createVocabList(DocData)
    
    #获取全部文档词汇向量矩阵
    trainMulList = []
    for doc in DocData:
        trainMulList.append(WordSet2Vec(voclist,doc))
    
    print(KCrossValidate(trainMulList,classVec))

  • 相关阅读:
    spring boot多数据源配置示例
    Java 8 Concurrency Tutorial--转
    ibatis annotations 注解方式返回刚插入的自增长主键ID的值--转
    mysql 字符串的处理
    How To Do @Async in Spring--转
    Resolving Problems installing the Java JCE Unlimited Strength Jurisdiction Policy Files package--转
    mysql导入数据,涉及到时间转换,乱码问题解决
    @Query Annotation in Spring Data JPA--转
    hive表信息查询:查看表结构、表操作等--转
    python时间戳
  • 原文地址:https://www.cnblogs.com/ssyfj/p/13252431.html
Copyright © 2011-2022 走看看