zoukankan      html  css  js  c++  java
  • Eclipse下mallet使用的方法

    Mallet 是 UMass(马萨诸塞大学阿默斯特分校)的研究者开发的一个用于统计自然语言处理的开源 Java 工具包,非常好用。它可以用来训练主题模型(topic model)、最大熵(ME)分类模型等。对于开发者来说,其官网的技术文档非常有参考价值。

    Mallet 可从其官网下载;如需浏览开发者文档,只需点击官网上相应的“Developer's Guide”链接。

    下面以开发一个简单的最大熵分类模型为例进行说明,详细内容可参考官方开发者文档。

    首先下载 Mallet 工具包,该工具包中包含源代码和 jar 包。简单起见,我们只导入 mallet-2.0.7/dist 目录下的 mallet.jar 和 mallet-deps.jar。导入 jar 包的过程为:右击项目 -> Properties -> Java Build Path -> Libraries,点击“Add JARs”,在路径中选取相应的 jar 包即可。

    新建Maxent类,代码如下:

    import java.io.File;
    import java.io.FileInputStream;
    import java.io.FileNotFoundException;
    import java.io.FileOutputStream;
    import java.io.FileReader;
    import java.io.IOException;
    import java.io.ObjectInputStream;
    import java.io.ObjectOutputStream;
    import java.io.Serializable;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    
    import cc.mallet.classify.Classifier;
    import cc.mallet.classify.ClassifierTrainer;
    import cc.mallet.classify.MaxEntTrainer;
    import cc.mallet.classify.Trial;
    import cc.mallet.pipe.iterator.CsvIterator;
    import cc.mallet.types.Alphabet;
    import cc.mallet.types.FeatureVector;
    import cc.mallet.types.Instance;
    import cc.mallet.types.InstanceList;
    import cc.mallet.types.Label;
    import cc.mallet.types.LabelAlphabet;
    import cc.mallet.types.Labeling;
    import cc.mallet.util.Randoms;
    
    public class Maxent implements Serializable{
        
        //Train a classifier
        public static Classifier trainClassifier(InstanceList trainingInstances) {
            // Here we use a maximum entropy (ie polytomous logistic regression) classifier.                                                 
            ClassifierTrainer trainer = new MaxEntTrainer();
            return trainer.train(trainingInstances);
        }
        
        //save a trained classifier/write a trained classifier to disk
        public void saveClassifier(Classifier classifier,String savePath) throws IOException{
            ObjectOutputStream oos=new ObjectOutputStream(new FileOutputStream(savePath));
            oos.writeObject(classifier);
            oos.flush();
            oos.close();        
        }
        
        //restore a saved classifier
        public Classifier loadClassifier(String savedPath) throws FileNotFoundException, IOException, ClassNotFoundException{                                              
            // Here we load a serialized classifier from a file.
            Classifier classifier;
            ObjectInputStream ois = new ObjectInputStream (new FileInputStream (new File(savedPath)));
            classifier = (Classifier) ois.readObject();
            ois.close();
            return classifier;
        }
        
        //predict & evaluate
        public String predict(Classifier classifier,Instance testInstance){
            Labeling labeling = classifier.classify(testInstance).getLabeling();
            Label label = labeling.getBestLabel();
            return (String)label.getEntry();
        }
        
        public void evaluate(Classifier classifier, String testFilePath) throws IOException {
            InstanceList testInstances = new InstanceList(classifier.getInstancePipe());                                                                                                                                                                
            
            //format of input data:[name] [label] [data ... ]                                                                    
            CsvIterator reader = new CsvIterator(new FileReader(new File(testFilePath)),"(\w+)\s+(\w+)\s+(.*)",3, 2, 1);  // (data, label, name) field indices               
    
            // Add all instances loaded by the iterator to our instance list
            testInstances.addThruPipe(reader);
            Trial trial = new Trial(classifier, testInstances);
    
            //evaluation metrics.precision, recall, and F1
            System.out.println("Accuracy: " + trial.getAccuracy());                                                      
            System.out.println("F1 for class 'good': " + trial.getF1("good"));
            System.out.println("Precision for class '" +
                               classifier.getLabelAlphabet().lookupLabel(1) + "': " +
                               trial.getPrecision(1));
        }
    
        //perform n-fold cross validation
         public static Trial testTrainSplit(MaxEntTrainer trainer, InstanceList instances) {
             int TRAINING = 0;
             int TESTING = 1;
             int VALIDATION = 2;
         
             // Split the input list into training (90%) and testing (10%) lists.
             InstanceList[] instanceLists = instances.split(new Randoms(), new double[] {0.9, 0.1, 0.0});
             Classifier classifier = trainClassifier(instanceLists[TRAINING]);
             return new Trial(classifier, instanceLists[TESTING]);
          }
         
        public static void main(String[] args) throws FileNotFoundException,IOException{
            //define training samples
            Alphabet featureAlphabet = new Alphabet();//特征词典
            LabelAlphabet targetAlphabet = new LabelAlphabet();//类标词典
            targetAlphabet.lookupIndex("positive");
            targetAlphabet.lookupIndex("negative");
            targetAlphabet.lookupIndex("neutral");
            targetAlphabet.stopGrowth();
            featureAlphabet.lookupIndex("f1");
            featureAlphabet.lookupIndex("f2");
            featureAlphabet.lookupIndex("f3");
            InstanceList trainingInstances = new InstanceList (featureAlphabet,targetAlphabet);//实例集对象
            final int size = targetAlphabet.size();
            double[] featureValues1 = {1.0, 0.0, 0.0};
            double[] featureValues2 = {2.0, 0.0, 0.0};
            double[] featureValues3 = {0.0, 1.0, 0.0};
            double[] featureValues4 = {0.0, 0.0, 1.0};
            double[] featureValues5 = {0.0, 0.0, 3.0};
            String[] targetValue = {"positive","positive","neutral","negative","negative"};
            List<double[]> featureValues = Arrays.asList(featureValues1,featureValues2,featureValues3,featureValues4,featureValues5); 
            int i = 0;
            for(double[]featureValue:featureValues){
                FeatureVector featureVector = new FeatureVector(featureAlphabet,
                        (String[])targetAlphabet.toArray(new String[size]),featureValue);//change list to array
                Instance instance = new Instance (featureVector,targetAlphabet.lookupLabel(targetValue[i]), "xxx",null);
                i++;
                trainingInstances.add(instance);
            }
             
            Maxent maxent = new Maxent();
            Classifier maxentclassifier = maxent.trainClassifier(trainingInstances);
            //loading test examples
            double[] testfeatureValues = {0.5, 0.5, 6.0};
            FeatureVector testfeatureVector = new FeatureVector(featureAlphabet,
                    (String[])targetAlphabet.toArray(new String[size]),testfeatureValues);
            //new instance(data,target,name,source)
            Instance testinstance = new Instance (testfeatureVector,targetAlphabet.lookupLabel("negative"), "xxx",null);
            System.out.print(maxent.predict(maxentclassifier, testinstance));
            //maxent.evaluate(maxentclassifier, "resource/testdata.txt");
        }
    }

    说明:trainingInstances为训练样本,testinstance为测试样本,该程序的执行结果为“negative”。

  • 相关阅读:
    Cookie 干货
    element-ui 框架中使用 NavMenu 导航菜单组件时,点击一个子菜单会出现多个子菜单同时展开或折叠?
    数组遍历的方法
    前端网页字体
    样式小收藏:完成、错误、提示动态图标样式
    多语言网站利器 rel="alternate" hreflang="x"
    网页中文章显示一部分,然后“查看全文”
    仿水滴筹中快捷留言祝福、随机生成祝福
    TypeScript知识点
    前端项目经验
  • 原文地址:https://www.cnblogs.com/tec-vegetables/p/4182705.html
Copyright © 2011-2022 走看看