  • Naive Bayes
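    A naive Bayes text classifier built on NumPy: set-of-words and bag-of-words
    feature vectors, training with Laplace smoothing and log-probabilities, a toy
    abusive-post demo, a spam filter over local email files, and an RSS-feed
    classifier.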

    # coding:utf-8
    from numpy import *
    
    def loadDataSet():
        postingList=[['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'],
                     ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'],
                     ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'],
                     ['stop', 'posting', 'stupid', 'worthless', 'garbage'],
                     ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'],
                     ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid']]
        classVec = [0,1,0,1,0,1]    #1 is abusive, 0 not
        return postingList, classVec
    
    def createVocabList(dataSet):
        vocabSet = set()
        for document in dataSet:
            vocabSet = vocabSet | set(document)   # union of the words in each document
        return list(vocabSet)
    
    def setOfWords2Vec(vocabList, inputSet):
        returnVec = [0] * len(vocabList)
        for word in inputSet:
            if word in vocabList:
                returnVec[vocabList.index(word)] = 1   # set-of-words model: presence flag only
            else:
                print("the word: %s is not in my Vocabulary!" % word)
        return returnVec
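
    # A minimal usage sketch (illustrative only; createVocabList builds the
    # vocabulary from an unordered set, so the column order varies per run):
    #   posts, labels = loadDataSet()
    #   vocab = createVocabList(posts)
    #   vec = setOfWords2Vec(vocab, ['my', 'dog', 'stupid'])
    #   sum(vec) == 3   # all three words are in the vocabulary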
    
    def bagOfWords2VecMN(vocabList, inputSet):
        returnVec = [0] * len(vocabList)
        for word in inputSet:
            if word in vocabList:
                returnVec[vocabList.index(word)] += 1  # bag-of-words model: count occurrences
        return returnVec
    
    
    def trainNB0(trainMatrix, trainCategory):
        numTrainDocs = len(trainMatrix)
        numWords = len(trainMatrix[0])
        pAbusive = sum(trainCategory) / float(numTrainDocs)   # prior p(class = 1)
        p0Num = ones(numWords)   # Laplace smoothing: start every word count at 1...
        p1Num = ones(numWords)
        p0Denom = 2.0            # ...and both denominators at 2
        p1Denom = 2.0
        for i in range(numTrainDocs):
            if trainCategory[i] == 1:
                p1Num += trainMatrix[i]
                p1Denom += sum(trainMatrix[i])
            else:
                p0Num += trainMatrix[i]
                p0Denom += sum(trainMatrix[i])
        p1Vect = log(p1Num / p1Denom)   # log-probabilities guard against underflow
        p0Vect = log(p0Num / p0Denom)
        return p0Vect, p1Vect, pAbusive
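
    # The model behind trainNB0/classifyNB: p(c|w) is proportional to
    # p(c) * prod_i p(w_i|c), with the "naive" assumption that words are
    # conditionally independent given the class. Starting the counts at ones()
    # and the denominators at 2.0 is Laplace smoothing, so a word unseen in one
    # class cannot zero out the whole product; log() turns the product into a
    # sum and avoids floating-point underflow on long documents.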
    
    def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):
        # sum of log p(w_i|c) plus the log prior; the larger log-posterior wins
        p1 = sum(vec2Classify * p1Vec) + log(pClass1)
        p0 = sum(vec2Classify * p0Vec) + log(1 - pClass1)
        if p1 > p0:
            return 1
        else:
            return 0
    
    def testingNB():
        listOPosts, listClasses = loadDataSet()
        myVocabList = createVocabList(listOPosts)
        trainMat = []
        for postinDoc in listOPosts:
            trainMat.append(setOfWords2Vec(myVocabList, postinDoc))
        p0V, p1V, pAb = trainNB0(array(trainMat), array(listClasses))
        testEntry = ['love', 'my', 'dalmation']
        thisDoc = array(setOfWords2Vec(myVocabList, testEntry))
        print(testEntry, 'classified as:', classifyNB(thisDoc, p0V, p1V, pAb))
        testEntry = ['stupid', 'garbage']
        thisDoc = array(setOfWords2Vec(myVocabList, testEntry))
        print(testEntry, 'classified as:', classifyNB(thisDoc, p0V, p1V, pAb))
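
    # Expected behaviour on the toy data (the two classifications are
    # deterministic, whatever the vocabulary order):
    #   ['love', 'my', 'dalmation'] classified as: 0
    #   ['stupid', 'garbage'] classified as: 1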
    
    def textParse(bigString):
        import re
        listOfTokens = re.split(r'\W+', bigString)   # split on runs of non-word characters
        return [word.lower() for word in listOfTokens if len(word) > 2]
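
    # e.g. textParse('Hello, World! ab') -> ['hello', 'world']
    # (tokens shorter than three characters are dropped)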
    
    def spamTest():
        docList = []
        classList = []
        fullText = []
        for i in range(1, 26):
            wordList = textParse(open('email/spam/%d.txt' % i).read())
            docList.append(wordList)
            fullText.extend(wordList)
            classList.append(1)
            wordList = textParse(open('email/ham/%d.txt' % i).read())
            docList.append(wordList)
            fullText.extend(wordList)
            classList.append(0)
        vocabList = createVocabList(docList)
        trainingSet = list(range(50))   # a list, so indices can be deleted below
        testSet = []
        for i in range(10):             # hold out 10 random documents for testing
            randIndex = int(random.uniform(0, len(trainingSet)))
            testSet.append(trainingSet[randIndex])
            del trainingSet[randIndex]
        trainMat = []
        trainClasses = []
        for docIndex in trainingSet:
            trainMat.append(setOfWords2Vec(vocabList, docList[docIndex]))
            trainClasses.append(classList[docIndex])
        p0V, p1V, pSpam = trainNB0(array(trainMat), array(trainClasses))
        errorCount = 0
        for docIndex in testSet:
            wordVector = setOfWords2Vec(vocabList, docList[docIndex])
            if classifyNB(array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:
                errorCount += 1
        print('the error rate is:', float(errorCount) / len(testSet))
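
    # Note: the 10-document test set is drawn at random, so the printed error
    # rate varies from run to run; averaging over several calls gives a
    # steadier estimate.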
    
    def calcMostFreq(vocabList, fullText):
        import operator
        freqDict = {}
        for token in vocabList:
            freqDict[token] = fullText.count(token)
        sortedFreq = sorted(freqDict.items(), key=operator.itemgetter(1), reverse=True)
        return sortedFreq[:30]   # the 30 most frequent tokens
    
    
    def localWords(feed1, feed0):
        # feed1/feed0 are already-parsed feeds, e.g. feedparser.parse(url)
        docList = []
        classList = []
        fullText = []
        minLen = min(len(feed1['entries']), len(feed0['entries']))
        for i in range(minLen):
            wordList = textParse(feed1['entries'][i]['summary'])
            docList.append(wordList)
            fullText.extend(wordList)
            classList.append(1)
            wordList = textParse(feed0['entries'][i]['summary'])
            docList.append(wordList)
            fullText.extend(wordList)
            classList.append(0)
        vocabList = createVocabList(docList)
        top30Words = calcMostFreq(vocabList, fullText)
        for pairW in top30Words:        # drop the most frequent words; they act as stop words
            if pairW[0] in vocabList:
                vocabList.remove(pairW[0])
        trainingSet = list(range(2 * minLen))
        testSet = []
        for i in range(20):             # hold out 20 random documents for testing
            randIndex = int(random.uniform(0, len(trainingSet)))
            testSet.append(trainingSet[randIndex])
            del trainingSet[randIndex]
        trainMat = []
        trainClasses = []
        for docIndex in trainingSet:
            trainMat.append(bagOfWords2VecMN(vocabList, docList[docIndex]))
            trainClasses.append(classList[docIndex])
        p0V, p1V, pSpam = trainNB0(array(trainMat), array(trainClasses))
        errorCount = 0
        for docIndex in testSet:
            wordVector = bagOfWords2VecMN(vocabList, docList[docIndex])
            if classifyNB(array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:
                errorCount += 1
        print('the error rate is:', float(errorCount) / len(testSet))
        return vocabList, p0V, p1V
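
    # A minimal driver (a sketch): testingNB() needs only the code above;
    # spamTest() assumes email/spam/1..25.txt and email/ham/1..25.txt exist;
    # localWords() expects two feeds parsed with the feedparser package.
    if __name__ == '__main__':
        testingNB()
        # spamTest()
        # import feedparser
        # ny = feedparser.parse('...')   # fill in two real feed URLs
        # sf = feedparser.parse('...')
        # vocabList, p0V, p1V = localWords(ny, sf)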
  • Original article: https://www.cnblogs.com/cherryMJY/p/8541903.html