zoukankan      html  css  js  c++  java
  • 朴素贝叶斯

    # coding:utf-8
    from numpy import *
    
    def loadDataSet():
        """Return a tiny toy corpus of tokenised posts and their labels.

        Returns:
            A ``(postingList, classVec)`` pair. ``postingList`` is a list of six
            token lists. ``classVec`` holds one label per post: 1 marks an
            abusive post and 0 a normal one.
        """
        posts = [
            ['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'],
            ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'],
            ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'],
            ['stop', 'posting', 'stupid', 'worthless', 'garbage'],
            ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'],
            ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid'],
        ]
        labels = [0, 1, 0, 1, 0, 1]  # 1 = abusive, 0 = not abusive
        return posts, labels
    
    def createVocabList(dataSet):
        """Collect every distinct token that appears in any document.

        Args:
            dataSet: an iterable of documents, each an iterable of tokens.

        Returns:
            A list of the unique tokens. The order is whatever the underlying
            ``set`` yields, so callers should treat it as arbitrary.
        """
        uniqueTokens = set()
        for tokens in dataSet:
            uniqueTokens = uniqueTokens.union(tokens)
        return list(uniqueTokens)
    
    def setOfWords2Vec(vocabList, inputSet):
        """Encode a document as a binary presence vector over ``vocabList``.

        Args:
            vocabList: list of vocabulary tokens. Position i of the result
                corresponds to ``vocabList[i]``.
            inputSet: iterable of the document's tokens.

        Returns:
            A list of ints with the same length as ``vocabList``. It has 1 at
            the position of every token that occurs in ``inputSet`` and 0
            elsewhere. Tokens missing from the vocabulary are reported on
            stdout and otherwise ignored.
        """
        # Map each token to its FIRST position (what list.index returned) once.
        # Each lookup is then O(1), instead of two linear scans of vocabList
        # per word.
        position = {}
        for idx, token in enumerate(vocabList):
            position.setdefault(token, idx)
        returnVec = [0] * len(vocabList)
        for word in inputSet:
            idx = position.get(word)
            if idx is not None:
                returnVec[idx] = 1
            else:
                # A single-argument print() parses and prints the same text on
                # both Python 2 and Python 3.
                print("the word: %s is not in my Vocabluary!" % word)
        return returnVec
    
    def bagOfWords2VecMN(vocabList, inputSet):
        """Encode a document as token counts over ``vocabList``.

        Args:
            vocabList: list of vocabulary tokens. Position i of the result
                corresponds to ``vocabList[i]``.
            inputSet: iterable of the document's tokens (repeats matter).

        Returns:
            A list of ints with the same length as ``vocabList``. Entry i is
            how many times ``vocabList[i]`` occurs in ``inputSet``.
            Out-of-vocabulary tokens are silently skipped.
        """
        # Build a first-occurrence index once (same slot list.index picks), so
        # each word costs O(1) instead of two linear scans of vocabList.
        position = {}
        for idx, token in enumerate(vocabList):
            position.setdefault(token, idx)
        returnVec = [0] * len(vocabList)
        for word in inputSet:
            idx = position.get(word)
            if idx is not None:
                returnVec[idx] += 1
        return returnVec
    
    
    def trainNB0(trainMatrix, trainCategory):
        """Fit a two-class naive Bayes model from document vectors.

        Args:
            trainMatrix: sequence of document vectors (rows), all the same
                length, e.g. ``array`` of setOfWords2Vec/bagOfWords2VecMN
                outputs.
            trainCategory: one label per row. Rows labelled 1 count towards
                class 1; every other label counts towards class 0.

        Returns:
            ``(p0Vect, p1Vect, pAbusive)``. The first two are the element-wise
            log of (per-token count / total count) for class 0 and class 1.
            ``pAbusive`` is the fraction of labels equal to 1.
        """
        nDocs = len(trainMatrix)
        nFeatures = len(trainMatrix[0])
        priorClass1 = sum(trainCategory) / float(nDocs)
        # Numerators start at 1 and denominators at 2.0, so for non-negative
        # count rows no estimate is zero and the log() below stays finite.
        tokenCounts = {0: ones(nFeatures), 1: ones(nFeatures)}
        tokenTotals = {0: 2.0, 1: 2.0}
        for idx, row in enumerate(trainMatrix):
            cls = 1 if trainCategory[idx] == 1 else 0
            tokenCounts[cls] += row
            tokenTotals[cls] += sum(row)
        return (log(tokenCounts[0] / tokenTotals[0]),
                log(tokenCounts[1] / tokenTotals[1]),
                priorClass1)
    
    def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):
        """Pick the higher-scoring class for a single document vector.

        Args:
            vec2Classify: document vector to classify (array).
            p0Vec: per-token log estimates for class 0, as from trainNB0.
            p1Vec: per-token log estimates for class 1, as from trainNB0.
            pClass1: prior probability of class 1.

        Returns:
            1 if the class-1 score is strictly greater than the class-0 score,
            otherwise 0 (ties go to class 0).
        """
        # Each score is the vector-weighted sum of per-token logs plus the log
        # of that class's prior.
        logScore1 = sum(vec2Classify * p1Vec) + log(pClass1)
        logScore0 = sum(vec2Classify * p0Vec) + log(1 - pClass1)
        return 1 if logScore1 > logScore0 else 0
    
    def testingNB():
        """Train on the toy corpus from loadDataSet and classify two sample posts.

        Prints one line per sample in the form
        ``<token list> classified as:  <label>``.
        """
        listOPosts, listClasses = loadDataSet()
        myVocabList = createVocabList(listOPosts)
        trainMat = [setOfWords2Vec(myVocabList, postinDoc) for postinDoc in listOPosts]
        p0V, p1V, pAb = trainNB0(array(trainMat), array(listClasses))
        for testEntry in (['love', 'my', 'dalmation'], ['stupid', 'garbage']):
            thisDoc = array(setOfWords2Vec(myVocabList, testEntry))
            # One pre-formatted argument reproduces the Python 2 output of
            # ``print a, 'classified as: ', b`` (items joined by a single space)
            # and is also valid Python 3.
            print('%s classified as:  %s'
                  % (testEntry, classifyNB(thisDoc, p0V, p1V, pAb)))
    
    def textParse(bigString):
        """Split text into lower-cased tokens longer than two characters.

        Args:
            bigString: raw text to tokenise.

        Returns:
            A list of lower-cased tokens, keeping only those of length > 2.
            Text is split on runs of non-word characters (anything other than
            letters, digits and underscore).
        """
        import re
        # r'\W+' splits on runs of non-word characters. The old pattern r'W*'
        # split on the literal letter 'W'. Using r'\W*' would also be wrong on
        # Python 3.7+, where re.split splits on empty matches too.
        listOfTokens = re.split(r'\W+', bigString)
        return [word.lower() for word in listOfTokens if len(word) > 2]
    
    def spamTest():
        """Hold-out evaluation of the set-of-words classifier on the email corpus.

        Procedure:
            1. Reads ``email/spam/1.txt`` .. ``email/spam/25.txt`` (label 1) and
               ``email/ham/1.txt`` .. ``email/ham/25.txt`` (label 0), relative
               to the current directory.
            2. Moves 10 randomly chosen documents into a test set and trains on
               the rest.
            3. Prints the test error rate.

        Returns:
            The test-set error rate, i.e. the fraction of misclassified test
            documents. The function used to return None; the printed line is
            unchanged.
        """
        docList = []
        classList = []
        for i in range(1, 26):
            for folder, label in (('spam', 1), ('ham', 0)):
                # ``with`` closes each file; the old code leaked the handles.
                with open('email/%s/%d.txt' % (folder, i)) as fh:
                    docList.append(textParse(fh.read()))
                classList.append(label)
        vocabList = createVocabList(docList)
        # list() is required: on Python 3 range() is immutable, so items
        # cannot be removed from it.
        trainingSet = list(range(len(docList)))
        testSet = []
        for _ in range(10):
            # `random` is presumably numpy.random, brought in by the file's
            # `from numpy import *`.
            randIndex = int(random.uniform(0, len(trainingSet)))
            testSet.append(trainingSet.pop(randIndex))
        trainMat = [setOfWords2Vec(vocabList, docList[d]) for d in trainingSet]
        trainClasses = [classList[d] for d in trainingSet]
        p0V, p1V, pSpam = trainNB0(array(trainMat), array(trainClasses))
        errorCount = 0
        for docIndex in testSet:
            wordVector = setOfWords2Vec(vocabList, docList[docIndex])
            if classifyNB(array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:
                errorCount += 1
        errorRate = float(errorCount) / len(testSet)
        # Same text as the old ``print 'the error rate is: ', rate`` statement,
        # but valid on Python 3 as well.
        print('the error rate is:  %s' % errorRate)
        return errorRate
    
    def calcMostFreq(vocabList, fullText, topN=30):
        """Return the ``topN`` vocabulary tokens that occur most often in ``fullText``.

        Args:
            vocabList: tokens to count (a duplicated token is reported once).
            fullText: flat list of every token in the corpus.
            topN: number of ``(token, count)`` pairs to return. Defaults to 30,
                the value this function previously hard-coded.

        Returns:
            A list of up to ``topN`` ``(token, count)`` tuples, highest count
            first. Tokens absent from ``fullText`` get a count of 0.
        """
        import operator
        from collections import Counter
        # Count the corpus once (O(N)) instead of calling fullText.count() for
        # every vocabulary token (O(V * N)).
        counts = Counter(fullText)
        freqDict = {}
        for token in vocabList:
            freqDict[token] = counts[token]
        # items() works on Python 2 and 3; iteritems() exists only on Python 2.
        sortedFreq = sorted(freqDict.items(), key=operator.itemgetter(1), reverse=True)
        return sortedFreq[:topN]
    
    
    def localWords(feed1, feed0, testSize=20):
        """Train and evaluate a bag-of-words classifier on two parsed feeds.

        Args:
            feed1: parsed feed whose entries get label 1. It must support
                ``feed1['entries'][i]['summary']`` (e.g. a feedparser result).
            feed0: parsed feed whose entries get label 0, same shape as
                ``feed1``.
            testSize: number of documents held out at random for testing.
                Defaults to 20, the value previously hard-coded.

        Returns:
            ``(vocabList, p0V, p1V)``: the vocabulary used (with the 30 most
            frequent tokens removed), and the per-token log estimates for
            class 0 and class 1 returned by trainNB0. The test error rate is
            printed.

        Raises:
            ValueError: if ``testSize`` is not at least 1 and smaller than the
                number of collected documents, which would leave no test or no
                training data.
        """
        docList = []
        classList = []
        fullText = []
        # Use the same number of entries from each feed so the classes are balanced.
        minLen = min(len(feed1['entries']), len(feed0['entries']))
        for i in range(minLen):
            for feed, label in ((feed1, 1), (feed0, 0)):
                wordList = textParse(feed['entries'][i]['summary'])
                docList.append(wordList)
                fullText.extend(wordList)
                classList.append(label)
        if not 0 < testSize < len(docList):
            raise ValueError(
                'testSize must be between 1 and %d (documents collected minus one), got %d'
                % (len(docList) - 1, testSize))
        vocabList = createVocabList(docList)
        # Drop the 30 most frequent tokens from the vocabulary (order preserved).
        topWords = set(pairW[0] for pairW in calcMostFreq(vocabList, fullText))
        vocabList = [word for word in vocabList if word not in topWords]
        # list() is required: on Python 3 range() is immutable.
        trainingSet = list(range(len(docList)))
        testSet = []
        for _ in range(testSize):
            # `random` is presumably numpy.random, brought in by the file's
            # `from numpy import *`.
            randIndex = int(random.uniform(0, len(trainingSet)))
            testSet.append(trainingSet.pop(randIndex))
        trainMat = [bagOfWords2VecMN(vocabList, docList[d]) for d in trainingSet]
        trainClasses = [classList[d] for d in trainingSet]
        p0V, p1V, pSpam = trainNB0(array(trainMat), array(trainClasses))
        errorCount = 0
        for docIndex in testSet:
            wordVector = bagOfWords2VecMN(vocabList, docList[docIndex])
            if classifyNB(array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:
                errorCount += 1
        # Same text as the old Python 2 print statement, valid on Python 3 too.
        print('the error rate is:  %s' % (float(errorCount) / len(testSet)))
        return vocabList, p0V, p1V
    View Code
  • 相关阅读:
    上传图片,将图片保存在腾讯云(2种方式)
    由ping所引发的思考~
    php面试上机题(2018-3-3)
    【八】jquery之click事件[添加及删除数据]
    【七】jquery之属性attr、 removeAttr、prop[全选全不选及反选]
    【六】jquery之HTML代码/文本/值[下拉列表框、多选框、单选框的选中]
    【五】jquery之事件(focus事件与blur事件)[提示语的出现及消失时机]
    小白懂算法之基数排序
    mysql_sql199语法介绍
    Python基本编程快速入门
  • 原文地址:https://www.cnblogs.com/cherryMJY/p/8541903.html
Copyright © 2011-2022 走看看