
Using MALLET in Eclipse

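MALLET is an open-source library for statistical natural language processing, developed by researchers at UMass, and it is an excellent tool. You can use it to learn topic models, train maximum entropy (ME) models, and more. For developers, the technical documentation on its official website is very helpful.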

MALLET can be downloaded from its official website. To browse the developer documentation, click the "Developer's Guide" link there.

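Below, we walk through a simple maximum entropy classifier as an example; see also the corresponding section of the documentation.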
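To make clear what the trained model computes: a maximum entropy classifier is multinomial (polytomous) logistic regression. Each label y gets a score b_y + Σ_f w_{y,f}·x_f, the scores are normalized with a softmax into p(y|x), and the predicted label is the one with the highest probability. The sketch below is a stdlib-only illustration of that scoring step; the class name MaxentScoringSketch and all weights and biases are hypothetical and are not what MALLET's MaxEntTrainer would learn.

```java
import java.util.Arrays;

// Hypothetical illustration of maximum entropy (multinomial logistic regression)
// scoring. The weights are made up; MALLET learns its own parameters.
public class MaxentScoringSketch {

    // p(y|x) = exp(b_y + w_y . x) / sum over y' of exp(b_y' + w_y' . x)
    public static double[] labelProbabilities(double[][] weights, double[] biases, double[] x) {
        int numLabels = weights.length;
        double[] scores = new double[numLabels];
        double max = Double.NEGATIVE_INFINITY;
        for (int y = 0; y < numLabels; y++) {
            double s = biases[y];
            for (int f = 0; f < x.length; f++) {
                s += weights[y][f] * x[f];
            }
            scores[y] = s;
            max = Math.max(max, s);
        }
        // Subtract the largest score before exponentiating for numerical stability.
        double sum = 0.0;
        for (int y = 0; y < numLabels; y++) {
            scores[y] = Math.exp(scores[y] - max);
            sum += scores[y];
        }
        for (int y = 0; y < numLabels; y++) {
            scores[y] /= sum;
        }
        return scores;
    }

    // Index of the most probable label.
    public static int argmax(double[] p) {
        int best = 0;
        for (int i = 1; i < p.length; i++) {
            if (p[i] > p[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        String[] labels = {"positive", "negative", "neutral"};
        // Rows = labels, columns = features f1, f2, f3 (hypothetical values).
        double[][] weights = {
            { 1.0, -0.5, -1.0},   // positive
            {-1.0, -0.5,  1.0},   // negative
            {-0.5,  1.0, -0.5}    // neutral
        };
        double[] biases = {0.0, 0.0, 0.0};
        double[] x = {0.5, 0.5, 6.0};  // same feature values as the article's test instance
        double[] p = labelProbabilities(weights, biases, x);
        System.out.println(Arrays.toString(p));
        System.out.println(labels[argmax(p)]);
    }
}
```

With these made-up weights the third feature dominates, so the vector {0.5, 0.5, 6.0} is assigned to "negative"; the actual weights learned by MALLET will differ.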

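First, download the MALLET package, which contains both the source code and the jar files. For simplicity, we import mallet.jar and mallet-deps.jar from mallet-2.0.7\dist. To add them in Eclipse: right-click the project -> Properties -> Java Build Path -> Libraries, click "Add JARs" (or "Add External JARs" if the jars are outside the workspace), and select the jar files.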
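If Eclipse later reports that the cc.mallet classes cannot be resolved, the jars are probably not on the build path. A small stdlib-only check such as the following tries to load a few of the classes imported by the Maxent example and reports which ones are missing. MalletClasspathCheck is a hypothetical helper, not part of MALLET.

```java
// Hypothetical sanity check: after adding mallet.jar and mallet-deps.jar to the
// build path, confirm that MALLET classes used in this article can be loaded.
public class MalletClasspathCheck {

    // Class names taken from the imports of the Maxent example.
    static final String[] REQUIRED = {
        "cc.mallet.classify.MaxEntTrainer",
        "cc.mallet.types.InstanceList",
        "cc.mallet.pipe.iterator.CsvIterator"
    };

    // Returns true if the named class can be loaded (without initializing it).
    public static boolean isOnClasspath(String className) {
        try {
            Class.forName(className, false, MalletClasspathCheck.class.getClassLoader());
            return true;
        } catch (ClassNotFoundException | LinkageError e) {
            // LinkageError covers a class whose own dependencies are missing.
            return false;
        }
    }

    public static void main(String[] args) {
        for (String name : REQUIRED) {
            System.out.println(name + (isOnClasspath(name) ? " found" : " MISSING"));
        }
    }
}
```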

Create a new class named Maxent with the following code:

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.pipe.iterator.CsvIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labeling;
import cc.mallet.util.Randoms;

public class Maxent implements Serializable {

    // Train a classifier
    public static Classifier trainClassifier(InstanceList trainingInstances) {
        // Here we use a maximum entropy (i.e. polytomous logistic regression) classifier.
        ClassifierTrainer trainer = new MaxEntTrainer();
        return trainer.train(trainingInstances);
    }

    // Save a trained classifier (write it to disk)
    public void saveClassifier(Classifier classifier, String savePath) throws IOException {
        ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(savePath));
        oos.writeObject(classifier);
        oos.flush();
        oos.close();
    }

    // Restore a saved classifier
    public Classifier loadClassifier(String savedPath)
            throws FileNotFoundException, IOException, ClassNotFoundException {
        // Here we load a serialized classifier from a file.
        Classifier classifier;
        ObjectInputStream ois = new ObjectInputStream(new FileInputStream(new File(savedPath)));
        classifier = (Classifier) ois.readObject();
        ois.close();
        return classifier;
    }

    // Predict the most probable label for one instance
    public String predict(Classifier classifier, Instance testInstance) {
        Labeling labeling = classifier.classify(testInstance).getLabeling();
        Label label = labeling.getBestLabel();
        return (String) label.getEntry();
    }

    // Evaluate on a text file. This requires a classifier whose instance pipe turns
    // raw text into FeatureVectors; the classifier built in main() from hand-made
    // FeatureVectors has no such pipe.
    public void evaluate(Classifier classifier, String testFilePath) throws IOException {
        InstanceList testInstances = new InstanceList(classifier.getInstancePipe());
        // Format of input data: [name] [label] [data ... ]
        CsvIterator reader = new CsvIterator(new FileReader(new File(testFilePath)),
                "(\\w+)\\s+(\\w+)\\s+(.*)", 3, 2, 1);  // (data, label, name) field indices
        // Add all instances loaded by the iterator to our instance list
        testInstances.addThruPipe(reader);
        Trial trial = new Trial(classifier, testInstances);
        // Evaluation metrics: accuracy, F1, precision
        System.out.println("Accuracy: " + trial.getAccuracy());
        System.out.println("F1 for class 'good': " + trial.getF1("good"));
        System.out.println("Precision for class '"
                + classifier.getLabelAlphabet().lookupLabel(1) + "': "
                + trial.getPrecision(1));
    }

    // Random 90% / 10% train/test split (a single split, not n-fold cross validation)
    public static Trial testTrainSplit(MaxEntTrainer trainer, InstanceList instances) {
        int TRAINING = 0;
        int TESTING = 1;
        int VALIDATION = 2;  // the validation proportion below is 0.0
        // Split the input list into training (90%) and testing (10%) lists.
        InstanceList[] instanceLists = instances.split(new Randoms(), new double[] {0.9, 0.1, 0.0});
        // Use the trainer passed in (previously it was ignored).
        Classifier classifier = trainer.train(instanceLists[TRAINING]);
        return new Trial(classifier, instanceLists[TESTING]);
    }

    public static void main(String[] args) throws FileNotFoundException, IOException {
        // Define training samples
        Alphabet featureAlphabet = new Alphabet();           // feature dictionary
        LabelAlphabet targetAlphabet = new LabelAlphabet();  // label dictionary
        targetAlphabet.lookupIndex("positive");
        targetAlphabet.lookupIndex("negative");
        targetAlphabet.lookupIndex("neutral");
        targetAlphabet.stopGrowth();
        featureAlphabet.lookupIndex("f1");
        featureAlphabet.lookupIndex("f2");
        featureAlphabet.lookupIndex("f3");
        featureAlphabet.stopGrowth();
        InstanceList trainingInstances = new InstanceList(featureAlphabet, targetAlphabet);  // instance list
        // Feature names (f1, f2, f3) used as FeatureVector keys. They must come from the
        // feature alphabet; the original code passed the label names here by mistake.
        String[] featureNames = (String[]) featureAlphabet.toArray(new String[featureAlphabet.size()]);
        double[] featureValues1 = {1.0, 0.0, 0.0};
        double[] featureValues2 = {2.0, 0.0, 0.0};
        double[] featureValues3 = {0.0, 1.0, 0.0};
        double[] featureValues4 = {0.0, 0.0, 1.0};
        double[] featureValues5 = {0.0, 0.0, 3.0};
        String[] targetValue = {"positive", "positive", "neutral", "negative", "negative"};
        List<double[]> featureValues = Arrays.asList(featureValues1, featureValues2,
                featureValues3, featureValues4, featureValues5);
        int i = 0;
        for (double[] featureValue : featureValues) {
            FeatureVector featureVector = new FeatureVector(featureAlphabet, featureNames, featureValue);
            // new Instance(data, target, name, source)
            Instance instance = new Instance(featureVector,
                    targetAlphabet.lookupLabel(targetValue[i]), "xxx", null);
            i++;
            trainingInstances.add(instance);
        }

        Maxent maxent = new Maxent();
        // trainClassifier is static, so call it on the class
        Classifier maxentclassifier = Maxent.trainClassifier(trainingInstances);

        // Load the test example
        double[] testfeatureValues = {0.5, 0.5, 6.0};
        FeatureVector testfeatureVector = new FeatureVector(featureAlphabet, featureNames, testfeatureValues);
        // new Instance(data, target, name, source)
        Instance testinstance = new Instance(testfeatureVector,
                targetAlphabet.lookupLabel("negative"), "xxx", null);
        System.out.print(maxent.predict(maxentclassifier, testinstance));
        // maxent.evaluate(maxentclassifier, "resource/testdata.txt");
    }
}
```
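The evaluate method (commented out in main) reads a plain-text test file in which each line has the form [name] [label] [data ...]. The regular expression passed to CsvIterator captures these as groups 1, 2 and 3, and the arguments 3, 2, 1 tell CsvIterator which group holds the data, the label and the name. The sketch below applies the same pattern with plain java.util.regex to a made-up line to show how a line is split. It is an illustration only, not MALLET's own parsing code, and the class name TestLineFormat is hypothetical.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical illustration of the "[name] [label] [data ...]" line format
// expected by evaluate(); uses only java.util.regex, not MALLET.
public class TestLineFormat {

    // Same pattern as the one passed to CsvIterator in evaluate().
    static final Pattern LINE = Pattern.compile("(\\w+)\\s+(\\w+)\\s+(.*)");

    // Returns {name, label, data}, or null if the line does not match.
    public static String[] parse(String line) {
        Matcher m = LINE.matcher(line);
        if (!m.find()) {
            return null;
        }
        // group 1 = name, group 2 = label, group 3 = data
        return new String[] {m.group(1), m.group(2), m.group(3)};
    }

    public static void main(String[] args) {
        String[] parts = parse("doc1 good the service was quick and friendly");
        System.out.println("name=" + parts[0] + " label=" + parts[1] + " data=" + parts[2]);
    }
}
```

Because the data group is `(.*)`, everything after the label (spaces included) ends up in the data field, while the name and label must each be a single run of word characters.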

Notes: trainingInstances holds the training samples and testinstance is the test sample. Running the program prints "negative".
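As written, main retrains the model on every run. The class also provides saveClassifier and loadClassifier, which rely on standard Java object serialization. The following sketch shows the same save/load round trip in a generic form with try-with-resources, so the streams are closed even if writing or reading fails. SerializationRoundTrip and its methods are hypothetical helpers, and the demo serializes a list of strings instead of a classifier so it runs without MALLET.

```java
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helpers mirroring saveClassifier/loadClassifier with
// try-with-resources; works for any Serializable object.
public class SerializationRoundTrip {

    // Write one object to the given file.
    public static void save(Serializable obj, Path path) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(Files.newOutputStream(path))) {
            out.writeObject(obj);
        }
    }

    // Read one object back; Class.cast fails clearly if the file holds another type.
    public static <T> T load(Path path, Class<T> type) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(Files.newInputStream(path))) {
            return type.cast(in.readObject());
        }
    }

    public static void main(String[] args) throws Exception {
        Path tmp = Files.createTempFile("model", ".ser");
        try {
            ArrayList<String> labels = new ArrayList<>(Arrays.asList("positive", "negative", "neutral"));
            save(labels, tmp);
            @SuppressWarnings("unchecked")
            List<String> restored = load(tmp, List.class);
            System.out.println(restored);
        } finally {
            Files.deleteIfExists(tmp);
        }
    }
}
```

With a trained model, the calls would presumably be save(maxentclassifier, path) and load(path, Classifier.class); this assumes the classifier object is Serializable, which the article's saveClassifier already relies on.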

 
