zoukankan      html  css  js  c++  java
  • 【LDA】lda模型和java代码

    几个问题:

    1、停用词应该去到什么程度??

    2、比如我选了参数topicNumber=100,结果中,其中有80个topic,每个的前几个words很好地描述了一个topic。另外的20个topic的前几个words没有描述好。这样是否说明了topicNumber=100已经足够了?

    3、LDA考虑了多少文件之间的关系?

    4、参数 alpha,beta怎么取?? 常用经验值 alpha=50/K ??  beta=0.1(0.01) ??

    ========================================

    看了几篇LDA的文档,实在写的太好了,我只能贴点代码,表示我做过lda了

    public class LdaModel {

        int[][] doc; // doc[m][n] = dictionary index (in indexToTermMap) of the n-th word of document m
        int V, K, M; // vocabulary size, topic number, document number
        int[][] z; // z[m][n] = topic label assigned to the n-th word of document m
        float alpha; // doc-topic Dirichlet prior parameter
        float beta; // topic-word Dirichlet prior parameter
        int[][] nmk; // nmk[m][k] = count of words in document m assigned to topic k (M*K)
        int[][] nkt; // nkt[k][t] = count of term t assigned to topic k (K*V)
        int[] nmkSum; // nmkSum[m] = total word count of document m (row sum of nmk)
        int[] nktSum; // nktSum[k] = total words assigned to topic k (row sum of nkt)
        double[][] phi; // topic-word distribution parameters, K*V
        double[][] theta; // doc-topic distribution parameters, M*K
        int iterations; // total number of Gibbs sampling iterations
        int saveStep; // number of iterations between two model saves
        int beginSaveIters; // iteration at which model saving begins

        /**
         * Builds an LDA model from externally supplied hyper-parameters.
         * The count arrays and distributions are allocated later, in
         * {@link #initializeModel(Documents)}.
         */
        public LdaModel(LdaGibbsSampling.modelparameters modelparam) {
            alpha = modelparam.alpha;
            beta = modelparam.beta;
            iterations = modelparam.iteration;
            K = modelparam.topicNum;
            saveStep = modelparam.saveStep;
            beginSaveIters = modelparam.beginSaveIters;
        }

        /**
         * Allocates all count arrays and assigns a uniformly random initial
         * topic to every word of every document, updating the counters so
         * Gibbs sampling can start from a consistent state.
         */
        public void initializeModel(Documents docSet) {
            M = docSet.docs.size();
            V = docSet.termToIndexMap.size();
            nmk = new int[M][K];
            nkt = new int[K][V];
            nmkSum = new int[M];
            nktSum = new int[K];
            phi = new double[K][V];
            theta = new double[M][K];

            // Copy each document's word-index array.
            // Notice the limit of memory for large corpora.
            doc = new int[M][];
            for (int m = 0; m < M; m++) {
                int N = docSet.docs.get(m).docWords.length;
                doc[m] = new int[N];
                for (int n = 0; n < N; n++) {
                    doc[m][n] = docSet.docs.get(m).docWords[n];
                }
            }

            // Random initialisation: every word gets a topic drawn uniformly
            // from [0, K), and all counters are updated accordingly.
            z = new int[M][];
            for (int m = 0; m < M; m++) {
                int N = docSet.docs.get(m).docWords.length;
                z[m] = new int[N];
                for (int n = 0; n < N; n++) {
                    int initTopic = (int) (Math.random() * K); // From 0 to K - 1
                    z[m][n] = initTopic;
                    // number of words in doc m assigned to topic initTopic add 1
                    nmk[m][initTopic]++;
                    // number of terms doc[m][n] assigned to topic initTopic add 1
                    nkt[initTopic][doc[m][n]]++;
                    // total number of words assigned to topic initTopic add 1
                    nktSum[initTopic]++;
                }
                // total number of words in document m is N
                nmkSum[m] = N;
            }
        }

        /**
         * Runs collapsed Gibbs sampling for the configured number of
         * iterations, periodically re-estimating phi/theta and persisting
         * the model to disk (starting at beginSaveIters, every saveStep
         * iterations).
         *
         * @throws IOException if a model snapshot cannot be written
         */
        public void inferenceModel(Documents docSet) throws IOException {
            if (iterations < saveStep + beginSaveIters) {
                System.err
                        .println("Error: the number of iterations should be larger than "
                                + (saveStep + beginSaveIters));
                // Exit with a non-zero status: this is an error, not success.
                System.exit(1);
            }
            for (int i = 0; i < iterations; i++) {
                System.out.println("Iteration " + i);
                if ((i >= beginSaveIters)
                        && (((i - beginSaveIters) % saveStep) == 0)) {
                    System.out.println("Saving model at iteration " + i + " ... ");
                    // Firstly update parameters from the current counts,
                    // then print/persist the model variables.
                    updateEstimatedParameters();
                    saveIteratedModel(i, docSet);
                }

                // One Gibbs sweep: resample the topic label of every word.
                for (int m = 0; m < M; m++) {
                    int N = docSet.docs.get(m).docWords.length;
                    for (int n = 0; n < N; n++) {
                        // Sample from p(z_i|z_-i, w)
                        z[m][n] = sampleTopicZ(m, n);
                    }
                }
            }
        }

        /**
         * Computes the point estimates of phi (topic-word) and theta
         * (doc-topic) from the current Gibbs counts, smoothed by the
         * Dirichlet priors beta and alpha respectively.
         */
        private void updateEstimatedParameters() {
            for (int k = 0; k < K; k++) {
                for (int t = 0; t < V; t++) {
                    phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);
                }
            }

            for (int m = 0; m < M; m++) {
                for (int k = 0; k < K; k++) {
                    theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
                }
            }
        }

        /**
         * Resamples the topic of word n in document m from the full
         * conditional p(z_i = k | z_-i, w) (collapsed Gibbs update rule)
         * and updates all counters.
         *
         * @return the newly sampled topic label in [0, K)
         */
        private int sampleTopicZ(int m, int n) {
            // Remove the current topic label of w_{m,n} from the counters.
            int oldTopic = z[m][n];
            nmk[m][oldTopic]--;
            nkt[oldTopic][doc[m][n]]--;
            nmkSum[m]--;
            nktSum[oldTopic]--;

            // Unnormalised full conditional p(z_i = k|z_-i, w) per topic.
            double[] p = new double[K];
            for (int k = 0; k < K; k++) {
                p[k] = (nkt[k][doc[m][n]] + beta) / (nktSum[k] + V * beta)
                        * (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
            }

            // Roulette-wheel sampling: build the cumulative distribution,
            // then locate the first bucket exceeding a uniform draw.
            for (int k = 1; k < K; k++) {
                p[k] += p[k - 1];
            }
            double u = Math.random() * p[K - 1]; // p[] is unnormalised
            int newTopic;
            for (newTopic = 0; newTopic < K; newTopic++) {
                if (u < p[newTopic]) {
                    break;
                }
            }
            if (newTopic == K) {
                // Defensive: floating-point rounding could in principle leave
                // u >= p[K - 1]; fall back to the last topic rather than
                // indexing out of bounds below.
                newTopic = K - 1;
            }

            // Re-add w_{m,n} with its new topic label.
            nmk[m][newTopic]++;
            nkt[newTopic][doc[m][n]]++;
            nmkSum[m]++;
            nktSum[newTopic]++;
            return newTopic;
        }

        /**
         * Persists a snapshot of the model at the given iteration as five
         * files: lda.params, lda.phi, lda.theta, lda.tassign and lda.twords.
         *
         * @throws IOException if any output file cannot be written
         */
        public void saveIteratedModel(int iters, Documents docSet)
                throws IOException {
            String resPath = LdaConfig.OUTPUTFILE_PATH;
            String modelName = "lda_" + iters;

            // lda.params: hyper-parameters and corpus statistics.
            ArrayList<String> lines = new ArrayList<String>();
            lines.add("alpha = " + alpha);
            lines.add("beta = " + beta);
            lines.add("topicNum = " + K);
            lines.add("docNum = " + M);
            lines.add("termNum = " + V);
            lines.add("iterations = " + iterations);
            lines.add("saveStep = " + saveStep);
            lines.add("beginSaveIters = " + beginSaveIters);
            FileUtil.writeLines(resPath + modelName + ".params", lines);

            // lda.phi: K*V topic-word distribution, one topic per line.
            // try-with-resources guarantees the writer is closed even if a
            // write fails part-way through.
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(
                    resPath + modelName + ".phi"))) {
                for (int i = 0; i < K; i++) {
                    for (int j = 0; j < V; j++) {
                        writer.write(phi[i][j] + "\t");
                    }
                    writer.write("\n");
                }
            }

            // lda.theta: M*K doc-topic distribution, one document per line.
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(
                    resPath + modelName + ".theta"))) {
                for (int i = 0; i < M; i++) {
                    for (int j = 0; j < K; j++) {
                        writer.write(theta[i][j] + "\t");
                    }
                    writer.write("\n");
                }
            }

            // lda.tassign: per-word "termIndex:topic" assignments.
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(
                    resPath + modelName + ".tassign"))) {
                for (int m = 0; m < M; m++) {
                    for (int n = 0; n < doc[m].length; n++) {
                        writer.write(doc[m][n] + ":" + z[m][n] + "\t");
                    }
                    writer.write("\n");
                }
            }

            // lda.twords: the topNum most probable words of each topic.
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(
                    resPath + modelName + ".twords"))) {
                int topNum = Math.min(15, V); // never request more words than the vocabulary holds
                for (int i = 0; i < K; i++) {
                    List<Integer> tWordsIndexArray = sortedIndicesDesc(phi[i]);
                    writer.write("topic " + i + "\t:\t");
                    for (int t = 0; t < topNum; t++) {
                        writer.write(docSet.indexToTermMap.get(tWordsIndexArray.get(t))
                                + " " + phi[i][tWordsIndexArray.get(t)] + "\t");
                    }
                    writer.write("\n");
                }
            }
        }

        /**
         * Saves each topic to the database as "word1:f1;word2:f2;..." built
         * from the topNum most probable words of the topic.
         */
        public void saveTopic(Documents docSet) {
            int topNum = Math.min(15, V); // guard against vocabularies smaller than 15
            for (int i = 0; i < K; i++) {
                List<Integer> tWordsIndexArray = sortedIndicesDesc(phi[i]);
                TbTopic tbTopic = new TbTopic();
                tbTopic.setId(i);
                StringBuilder bf = new StringBuilder();
                for (int t = 0; t < topNum; t++) {
                    bf.append(docSet.indexToTermMap.get(tWordsIndexArray.get(t)));
                    bf.append(":");
                    bf.append(phi[i][tWordsIndexArray.get(t)]);
                    bf.append(";");
                }
                tbTopic.setWords(bf.toString());
                DocDBUtil.saveTbTopic(tbTopic);
            }
        }

        /**
         * Saves, for each document, its most probable topics to the database
         * as "topicId1:f1;topicId2:f2;...".
         */
        public void saveDisTopic(Documents docSet) {
            int topicNum = Math.min(3, K); // guard against K < 3
            for (int i = 0; i < M; i++) {
                // The document name carries the disease id in the database.
                int disId = Integer.parseInt(docSet.docs.get(i).docName);
                TbDisTopic tbDisTopic = new TbDisTopic();
                tbDisTopic.setId(disId);
                List<Integer> topicIndexArray = sortedIndicesDesc(theta[i]);
                StringBuilder topicIds = new StringBuilder();
                for (int t = 0; t < topicNum; t++) {
                    topicIds.append(topicIndexArray.get(t));
                    topicIds.append(":");
                    topicIds.append(theta[i][topicIndexArray.get(t)]);
                    topicIds.append(";");
                }
                tbDisTopic.setTopic(topicIds.toString());
                DocDBUtil.saveTbDisTopic(tbDisTopic);
            }
        }

        /**
         * Returns the indices 0..prob.length-1 sorted so that the entries of
         * prob they point at are in descending order. Shared by the three
         * save methods above.
         */
        private List<Integer> sortedIndicesDesc(double[] prob) {
            List<Integer> indices = new ArrayList<Integer>(prob.length);
            for (int j = 0; j < prob.length; j++) {
                indices.add(Integer.valueOf(j));
            }
            Collections.sort(indices, new ArrayDoubleComparator(prob));
            return indices;
        }

        /**
         * Orders indices by descending probability in the wrapped array.
         * Declared static so instances do not retain a hidden reference to
         * the enclosing LdaModel.
         */
        public static class ArrayDoubleComparator implements Comparator<Integer> {
            private final double[] sortProb; // probability of each item being compared

            public ArrayDoubleComparator(double[] sortProb) {
                this.sortProb = sortProb;
            }

            @Override
            public int compare(Integer o1, Integer o2) {
                // Descending order; Double.compare handles ties consistently.
                return Double.compare(sortProb[o2], sortProb[o1]);
            }
        }
    }

    核心代码还是写的不错的

    public class Documents {

        ArrayList<Document> docs; // the corpus; docs.size() == M
        Map<String, Integer> termToIndexMap; // term -> dictionary index; size() == V, no duplicates
        ArrayList<String> indexToTermMap; // dictionary: position in the list is the term's index
        Map<String, Integer> termCountMap; // term -> corpus-wide frequency

        public Documents() {
            docs = new ArrayList<Document>();
            termToIndexMap = new HashMap<String, Integer>();
            indexToTermMap = new ArrayList<String>();
            termCountMap = new HashMap<String, Integer>();
        }

        /**
         * Loads every disease record from the database and turns its symptom
         * description into a Document, growing the shared dictionary and term
         * counts as new words are encountered.
         */
        public void getDocsFromDB() {
            List<TbDiseases> diseases = DocDBUtil.getTbDiseases();
            for (TbDiseases disea : diseases) {
                String content = disea.getSymptomDetail();
                String docName = disea.getId() + "";
                docs.add(new Document(content, termToIndexMap, indexToTermMap,
                        termCountMap, docName));
            }
        }

        public static class Document {
            String docName;
            // Words of this document (stop words and noise already removed);
            // each int is the word's index in indexToTermMap.
            int[] docWords;

            /**
             * Tokenises content and stores each word as its dictionary index,
             * registering previously unseen words in the shared maps and
             * updating the corpus-wide term frequencies.
             */
            public Document(String content, Map<String, Integer> termToIndexMap,
                    ArrayList<String> indexToTermMap,
                    Map<String, Integer> termCountMap, String docName) {
                this.setDocName(docName);
                ArrayList<String> words = DocDBUtil.getWordsFromSentence(content);
                // Transfer each word to its dictionary index; a single get()
                // replaces the previous containsKey()+get() double lookup.
                this.docWords = new int[words.size()];
                for (int i = 0; i < words.size(); i++) {
                    String word = words.get(i);
                    Integer index = termToIndexMap.get(word);
                    if (index == null) {
                        // First occurrence: append the word to the dictionary.
                        index = Integer.valueOf(termToIndexMap.size());
                        termToIndexMap.put(word, index);
                        indexToTermMap.add(word);
                        termCountMap.put(word, Integer.valueOf(1));
                    } else {
                        termCountMap.put(word, termCountMap.get(word) + 1);
                    }
                    docWords[i] = index;
                }
            }

            public void setDocName(String docName) {
                this.docName = docName;
            }

            public String getDocName() {
                return docName;
            }
        }//Document
    }

    文档来自数据库的

  • 相关阅读:
    安装.NET FRAMEWORK 4.5安装进度条回滚之后发生严重错误 代码0x80070643
    C#远程时间同步助手软件设计
    Win7+Ubuntu双系统安装完成后时间不一致相差大概8小时
    php中类的不定参数使用示例
    php读写xml基于DOMDocument方法
    php写的非常简单的文件浏览器
    php封装的sqlite操作类
    phpstudy中apache的默认根目录的配置
    实现基于最近邻内插和双线性内插的图像缩放C++实现
    【STL深入理解】vector
  • 原文地址:https://www.cnblogs.com/549294286/p/3019473.html
Copyright © 2011-2022 走看看