zoukankan      html  css  js  c++  java
  • 贝叶斯决策

    贝叶斯决策

    • 简单例子引入
    • 先验概率
    • 后验概率
    • 最小错误率决策
    • 最小风险贝叶斯决策

    简单的例子

      正常情况下,我们可以快速地将街上的人分成男和女两类。这里街上的人就是我们观测到的样本,将每一个人分成男、女两类就是我们做决策的过程。上面的问题就是一个分类问题。

      分类可以看作是一种决策,即我们根据观测对样本做出应归属哪一类的决策。

      假定我手里握着一枚硬币,让你猜是多少钱的硬币,这其实就可以看作一个分类决策的问题:你需要从各种可能的硬币中做出一个决策。假设硬币的面值有1角、5角、1元三种。

      如果事先告知这枚硬币只可能是1角或者5角,那么问题就是一个二分类问题。

    先验概率

          


    最小错误率

            

    后验概率

    决策

     

     最小错误率决策

     

    最小风险贝叶斯决策

    最小风险决策

     

     

     

     

     贝叶斯决策理论的分类方法

    总结

                  

    Bayes.java

    package byas;
    
    import com.google.common.collect.Lists;
    import com.google.common.collect.Sets;
    import org.apache.commons.math3.linear.MatrixUtils;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.commons.math3.linear.RealVector;
    import org.lionsoul.jcseg.ASegment;
    import org.lionsoul.jcseg.core.*;
    
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.HashSet;
    
    import static org.apache.commons.math3.util.FastMath.log;
    
    
    /**
     * Two-class naive Bayes text classifier built on bag-of-words count vectors.
     *
     * <p>Documents are word lists; {@link #createVocabSet} derives the vocabulary,
     * {@link #bagOfWords2VecMN} turns one document into a count vector,
     * {@link #trainNB} estimates per-class log word probabilities with add-one
     * smoothing, and {@link #classify} picks the class with the larger log score.
     *
     * <p>The matrix helpers ({@link #sumMatrix}, {@link #logMatrix},
     * {@link #multiply}) only ever look at row 0 of their arguments.
     */
    public class Bayes {
        // Jcseg tokenizer task configuration, i.e. the configuration initialised
        // from the jcseg.properties file.
        public static JcsegTaskConfig config = new JcsegTaskConfig();
        // Default Jcseg dictionary created from the configuration above.
        public static ADictionary dic = DictionaryFactory
                .createDefaultDictionary(config);

        /**
         * Builds the training data set.
         *
         * <p>Currently this only instantiates a Jcseg segmenter and returns two
         * empty lists; the sample documents and the tokenising loop are kept
         * below as commented-out reference code.
         *
         * @return a two-element array: index 0 is the {@code ArrayList<ArrayList<String>>}
         *         of documents, index 1 is the {@code ArrayList<Integer>} of labels
         * @throws IOException declared for the (currently disabled) tokenising code
         */
        public static Object[] createdata() throws IOException {
            ArrayList<ArrayList<String>> retList = Lists.newArrayList();
            ArrayList<Integer> labels = Lists.newArrayList();
            // The segmenter is not consumed yet (see the disabled code below), so a
            // creation failure is reported but does not abort the method.
            ASegment seg = null;
            try {
                seg = (ASegment) SegmentFactory
                        .createJcseg(JcsegTaskConfig.SIMPLE_MODE,
                                new Object[]{config, dic});
            } catch (JcsegException e) {
                e.printStackTrace();
            }

            /* Disabled reference code: iterate over the segmenter output.
            IWord word;
            while ( (word = seg.next()) != null ) {
                System.out.println(word.getValue());
            }
            String title = article.getTitle();
            String content = article.getContent();

            List<Term> termList = new ArrayList<Term>();
            List<String> wordList = new ArrayList<String>();
            Map<String,Set<String>> words = new HashMap<String, Set<String>>();
            Queue<String> que = new LinkedList<String>();
            try {
                if(seg!=null){
                    seg.reset(new StringReader(title + content));
                    IWord word;
                    while ( (word = seg.next()) != null ) {
                        if(shouldInclude(word.getValue())){
                            wordList.add(word.getValue());
                        }
                    }
                }

            } catch (IOException e) {
                e.printStackTrace();
            }*/

            /* Disabled sample data (label 1 = abusive, label 0 = normal).
            retList.add(Lists.newArrayList("my", "dog", "has", "flea", "problems", "help", "please"));
            retList.add(Lists.newArrayList("maybe", "not", "take", "him", "to", "dog", "park", "stupid"));
            retList.add(Lists.newArrayList("my", "dalmation", "is", "so", "cute", "I", "love", "him"));
            retList.add(Lists.newArrayList("stop", "posting", "stupid", "worthless", "garbage"));
            retList.add(Lists.newArrayList("mr", "licks", "ate", "my", "steak", "how", "to", "stop", "him"));
            retList.add(Lists.newArrayList("quit", "buying", "worthless", "dog", "food", "stupid"));
            labels.addAll(Lists.newArrayList(0,1,0,1,0,1));*/
            return new Object[]{retList, labels};
        }


        /**
         * Collects the distinct words of all documents into a vocabulary list.
         *
         * @param lists the documents, each a list of words
         * @return every distinct word exactly once, in the iteration order of a {@code HashSet}
         */
        public static ArrayList<String> createVocabSet(ArrayList<ArrayList<String>> lists) {
            HashSet<String> vocabulary = Sets.newHashSet();
            for (ArrayList<String> document : lists) {
                vocabulary.addAll(document);
            }
            return Lists.newArrayList(vocabulary);
        }

        /**
         * Converts a document into a bag-of-words count vector.
         *
         * @param set the vocabulary; position {@code i} of the result counts {@code set.get(i)}
         * @param inputData the words of the document
         * @return a vector of length {@code set.size()} holding the number of occurrences of
         *         each vocabulary word in {@code inputData}; unknown words are ignored
         */
        public static double[] bagOfWords2VecMN(ArrayList<String> set, ArrayList<String> inputData) {
            double[] returnVec = new double[set.size()];
            for (String word : inputData) {
                // A single indexOf both tests membership and yields the slot.
                int slot = set.indexOf(word);
                if (slot >= 0) {
                    returnVec[slot]++;
                }
            }
            return returnVec;
        }

        /**
         * Trains the two-class naive Bayes model.
         *
         * <p>Word counts start at 1 and each class denominator at 2 (smoothing), so
         * no word gets probability zero. Rows labelled {@code 1} feed class 1; every
         * other row feeds class 0.
         *
         * @param realMatrix one bag-of-words row per training document
         * @param labels the class label of each row, in row order
         * @return a three-element array: index 0 is the 1 x numWords {@code RealMatrix} of
         *         log word probabilities for class 0, index 1 the same for class 1, and
         *         index 2 a {@code Float} with the fraction of documents labelled 1
         * @throws IllegalArgumentException if the number of labels differs from the number
         *         of rows in {@code realMatrix}
         */
        public static Object[] trainNB(RealMatrix realMatrix, ArrayList<Integer> labels) {
            int numTrainDocs = realMatrix.getRowDimension();
            int numWords = realMatrix.getColumnDimension();
            if (labels.size() != numTrainDocs) {
                throw new IllegalArgumentException("Expected " + numTrainDocs
                        + " labels (one per matrix row) but got " + labels.size());
            }

            // Smoothed per-word counts for each class, all starting at 1.
            double[] p0Counts = new double[numWords];
            double[] p1Counts = new double[numWords];
            for (int w = 0; w < numWords; w++) {
                p0Counts[w] = 1.0;
                p1Counts[w] = 1.0;
            }
            double p0Denom = 2.0;
            double p1Denom = 2.0;
            int class1Docs = 0;

            // Add each document's word counts to its class and grow that class's total.
            for (int i = 0; i < numTrainDocs; i++) {
                double[] row = realMatrix.getRow(i);
                boolean isClass1 = labels.get(i) == 1;
                double[] target = isClass1 ? p1Counts : p0Counts;
                double rowTotal = 0.0;
                for (int w = 0; w < numWords; w++) {
                    target[w] += row[w];
                    rowTotal += row[w];
                }
                if (isClass1) {
                    p1Denom += rowTotal;
                    class1Docs++;
                } else {
                    p0Denom += rowTotal;
                }
            }

            // Counted with the same "== 1" test used above so prior and counts agree.
            float pAbusive = (float) class1Docs / numTrainDocs;

            // Log word-probability row matrices.
            RealMatrix p0 = MatrixUtils.createRealMatrix(1, numWords);
            p0.setRow(0, logProbabilities(p0Counts, p0Denom));
            RealMatrix p1 = MatrixUtils.createRealMatrix(1, numWords);
            p1.setRow(0, logProbabilities(p1Counts, p1Denom));
            return new Object[]{p0, p1, pAbusive};
        }

        /** Returns {@code log(counts[i] / denom)} for every element of {@code counts}. */
        private static double[] logProbabilities(double[] counts, double denom) {
            double[] result = new double[counts.length];
            for (int i = 0; i < counts.length; i++) {
                result[i] = log(counts[i] / denom);
            }
            return result;
        }


        /**
         * Fills a matrix with ones, in place.
         *
         * @param realMatrix the matrix to overwrite; every entry is set to 1
         * @return the same {@code realMatrix} instance
         */
        public static RealMatrix oneNums(RealMatrix realMatrix) {
            for (int r = 0; r < realMatrix.getRowDimension(); r++) {
                for (int c = 0; c < realMatrix.getColumnDimension(); c++) {
                    realMatrix.setEntry(r, c, 1.0);
                }
            }
            return realMatrix;
        }

        /**
         * Sums the elements of the first row of a matrix.
         *
         * @param realMatrix the matrix whose row 0 is summed
         * @return the sum of row 0, accumulated in double precision and narrowed to float
         */
        public static float sumMatrix(RealMatrix realMatrix) {
            return (float) sumFirstRow(realMatrix);
        }

        /** Sums row 0 of {@code realMatrix} in double precision. */
        private static double sumFirstRow(RealMatrix realMatrix) {
            double total = 0.0;
            for (double value : realMatrix.getRow(0)) {
                total += value;
            }
            return total;
        }

        /**
         * Replaces every element of the first row with its natural logarithm, in place.
         *
         * @param realMatrix the matrix whose row 0 is overwritten
         * @return the same {@code realMatrix} instance
         */
        public static RealMatrix logMatrix(RealMatrix realMatrix) {
            double[] values = realMatrix.getRow(0);
            for (int i = 0; i < values.length; i++) {
                values[i] = log(values[i]);
            }
            realMatrix.setRow(0, values);
            return realMatrix;
        }

        /**
         * Multiplies the first rows of two matrices element by element.
         *
         * @param m1 left operand; its dimensions determine the size of the result
         * @param m2 right operand
         * @return a new matrix with the dimensions of {@code m1} whose row 0 holds the
         *         element-wise product of the two first rows; any further rows are zero
         */
        public static RealMatrix multiply(RealMatrix m1, RealMatrix m2) {
            RealVector left = m1.getRowVector(0);
            RealVector right = m2.getRowVector(0);
            RealMatrix product = MatrixUtils.createRealMatrix(m1.getRowDimension(), m1.getColumnDimension());
            product.setRowVector(0, left.ebeMultiply(right));

            return product;
        }


        /**
         * Classifies a document assuming both classes are equally likely (prior 0.5).
         *
         * @param realMatrix 1 x numWords bag-of-words row of the document
         * @param p0M log word probabilities of class 0 (a {@code RealMatrix}, as returned by trainNB)
         * @param p1M log word probabilities of class 1 (a {@code RealMatrix}, as returned by trainNB)
         * @return 0 if class 0 scores strictly higher, otherwise 1
         */
        public static int classify(RealMatrix realMatrix, Object p0M, Object p1M) {
            return classify(realMatrix, p0M, p1M, 0.5);
        }

        /**
         * Classifies a document using an explicit prior for class 1, e.g. the
         * value stored at index 2 of the array returned by {@link #trainNB}.
         *
         * @param realMatrix 1 x numWords bag-of-words row of the document
         * @param p0M log word probabilities of class 0 (a {@code RealMatrix})
         * @param p1M log word probabilities of class 1 (a {@code RealMatrix})
         * @param pClass1 prior probability of class 1, in the range [0, 1]
         * @return 0 if class 0 scores strictly higher, otherwise 1
         * @throws IllegalArgumentException if {@code pClass1} is NaN or outside [0, 1]
         */
        public static int classify(RealMatrix realMatrix, Object p0M, Object p1M, double pClass1) {
            if (!(pClass1 >= 0.0 && pClass1 <= 1.0)) {
                throw new IllegalArgumentException("pClass1 must be within [0, 1] but was " + pClass1);
            }
            // Scores stay in double: sums of log-probabilities lose digits quickly in float.
            double p0 = sumFirstRow(multiply(realMatrix, (RealMatrix) p0M)) + log(1.0 - pClass1);
            double p1 = sumFirstRow(multiply(realMatrix, (RealMatrix) p1M)) + log(pClass1);
            return p0 > p1 ? 0 : 1;
        }



        /**
         * Entry point. Only calls {@link #createdata()}; the end-to-end demo is kept
         * below as commented-out code because createdata() currently returns no documents.
         */
        public static void main(String[] args) throws IOException {
            /*Object[] retData = createdata();
            ArrayList<String> set = createVocabSet((ArrayList<ArrayList<String>>) retData[0]);
            ArrayList<ArrayList<String>> lists = (ArrayList<ArrayList<String>>) retData[0];
            RealMatrix m = MatrixUtils.createRealMatrix(lists.size(),set.size());
            for (int i = 0; i < lists.size(); i++) {
                m.setRow(i,bagOfWords2VecMN(set,lists.get(i)));
            }

            Object[] retP = trainNB(m, (ArrayList<Integer>) retData[1]);

            ArrayList<String> test = Lists.newArrayList("love");
            RealMatrix m1 = MatrixUtils.createRealMatrix(1,set.size());
            m1.setRow(0,bagOfWords2VecMN(set,test));
            System.out.println(classify(m1,retP[0],retP[1],(Float) retP[2]));*/
            createdata();


        }
    }
  • 相关阅读:
    洛谷—— P2234 [HNOI2002]营业额统计
    BZOJ——3555: [Ctsc2014]企鹅QQ
    CodeVs——T 4919 线段树练习4
    python(35)- 异常处理
    August 29th 2016 Week 36th Monday
    August 28th 2016 Week 36th Sunday
    August 27th 2016 Week 35th Saturday
    August 26th 2016 Week 35th Friday
    August 25th 2016 Week 35th Thursday
    August 24th 2016 Week 35th Wednesday
  • 原文地址:https://www.cnblogs.com/zlslch/p/6789129.html
Copyright © 2011-2022 走看看