首页 > 代码库 > 京东评论情感分类器(基于bag-of-words模型)

京东评论情感分类器(基于bag-of-words模型)

京东评论情感分类器(基于bag-of-words模型)


最近本来在研究paraVector(Paragraph Vector)模型,想拿bag-of-words模型来做对比。

数据集是京东的评论,经过人工挑选,选出一批正面和负面的评论。

实验的数据量不大:340条正面评论,314条负面评论。我一般拿200条正面和200条负面评论做训练,剩下的做测试。


做着做着,领悟了一些机器学习的道理:同一个模型在不同的数据集上,效果是不同的。

对于特定的数据集,随便拿来一套模型可能并不适用。


对于这些评论,我感觉就是bag-of-words模型靠谱点。

因为这些评论的特点是语句简短,关键词重要。

paraVector模型感觉比较擅长长文本的分析,注重上下文。


其实我还把两个模型结合起来做了一个新模型,准确率略有提高,但幅度不大,可能是因为数据量太少。


整理了一下思路,做了个评论情感分类的demo。

特征抽取是bag-of-words模型。

分类器是我自己设计的一个模型,结合了kNN和k-means的思想(实际上就是最近质心分类):先对正、负样本的训练集分别求出两个聚类中心(即各自特征向量的均值),每来一个新样本,就分别计算它与两个中心的距离,离哪个中心近就判为哪一类。


以下是demo的代码:

import java.util.Scanner;


public class BowInterTest {

	/**
	 * Trains the bag-of-words sentiment model, then classifies each line typed on stdin.
	 *
	 * <p>A line is reported as "good" when its feature vector is closer to the good centre
	 * than to the bad centre (negative margin from {@code predict}), otherwise as "bad".
	 *
	 * @param args optional; {@code args[0]} overrides the data directory that contains the
	 *             "all", "good" and "bad" files (defaults to the original author's path)
	 */
	public static void main(String[] args) throws Exception 
	{
		String dataDir = args.length > 0 ? args[0] : "/media/linger/G/sources/comment/test";
		BowModel bm = new BowModel(dataDir + "/all");//all=good+bad
		double[][] good = bm.generateFeature(dataDir + "/good",340);
		double[][] bad = bm.generateFeature(dataDir + "/bad",314);
		// Train on the first 200 documents of each class.
		bm.train(good,0,200,bad,0,200);
		// Uncomment to evaluate on the remaining documents.
		//bm.test(good, 200, 340, bad, 200, 314);

		// Interactive mode: one document per input line.
		Scanner sc = new Scanner(System.in);
		try
		{
			// hasNextLine() matches nextLine(); hasNext() looks for a token and is the wrong test.
			while(sc.hasNextLine())
			{
				String doc = sc.nextLine();
				// A blank line has no terms and hence an all-zero feature vector; skip it.
				if(doc.trim().isEmpty())
					continue;
				double[] fea = bm.docFea(doc);
				// Normalise the same way the training rows were normalised.
				Norm.arrayNorm2(fea);
				double re = bm.predict(fea);
				if(re<0)
				{
					System.out.println("good:"+re);
				}
				else 
				{
					System.out.println("bad:"+re);
				}
			}
		}
		finally
		{
			sc.close();
		}
	}

}

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.UnsupportedEncodingException;
import java.util.StringTokenizer;



/**
 * Bag-of-words feature extraction combined with the two-centre classifier inherited from
 * KnnCoreModel.
 */
public class BowModel extends KnnCoreModel
{

	/** Vocabulary built from the corpus passed to the constructor. */
	Dict dict;
	/** Maps a space-separated document to a term-count vector over {@link #dict}'s vocabulary. */
	DocFeatureFactory dff;

	/**
	 * Builds the vocabulary from a corpus file.
	 *
	 * @param path UTF-8 corpus file, read line by line and split on spaces by Dict
	 * @throws IOException if the corpus cannot be read
	 */
	public BowModel(String path) throws IOException 
	{
		dict = new Dict();
		dict.loadFromLocalFile(path);		
		dff = new DocFeatureFactory(dict.getWord2Index());
	}

	/**
	 * Returns the raw (not normalised) term-count vector of one document.
	 *
	 * @param doc space-separated document text
	 * @return vector whose length equals the vocabulary size
	 */
	public double[] docFea(String doc)
	{
		return dff.getFeature(doc);
	}

	/**
	 * Reads a UTF-8 file holding one document per line and converts every line to a
	 * term-count vector.
	 *
	 * @param docsFile path of the document file
	 * @param docNum   number of rows to allocate; must be at least the file's line count
	 * @return table of {@code docNum} rows; if the file has fewer lines, the trailing rows
	 *         are left {@code null}
	 * @throws IOException              if the file cannot be read
	 * @throws IllegalArgumentException if the file has more than {@code docNum} lines
	 */
	public double[][] generateFeature(String docsFile,int docNum) throws IOException
	{
		double[][] featureTable = new double[docNum][];
		int docIndex=0;
		BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(docsFile)),"utf-8"));
		try
		{
			String line;
			while((line = br.readLine()) != null)
			{
				// Fail with a clear message instead of a bare ArrayIndexOutOfBoundsException.
				if(docIndex >= docNum)
					throw new IllegalArgumentException("file " + docsFile
							+ " has more than " + docNum + " lines");
				featureTable[docIndex++] = dff.getFeature(line);
			}
		}
		finally
		{
			// Close even when reading or feature extraction throws.
			br.close();
		}
		return featureTable;
	}


}


import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.StringTokenizer;
import java.util.Map.Entry;

/**
 * Vocabulary built from a corpus file: assigns each distinct term an index and counts how
 * often it occurs.
 */
public class Dict 
{
	/** Term to 0-based index, assigned in order of first appearance. */
	HashMap<String,Integer> word2Index =null;
	/** Term to number of occurrences in the corpus. */
	Hashtable<String,Integer> word2Count = null;

	/**
	 * Replaces the current vocabulary with the one built from {@code path}.
	 *
	 * @param path UTF-8 text file; every line is split on single spaces (runs of spaces
	 *             produce no empty tokens)
	 * @throws IOException if the file cannot be opened or read
	 */
	void loadFromLocalFile(String path) throws IOException
	{
		word2Index = new HashMap<String,Integer>();
		word2Count = new Hashtable<String,Integer>();
		int index = 0;
		BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(path)),"utf-8"));
		try
		{
			String line;
			while((line = br.readLine()) != null)
			{
				StringTokenizer tokenizer=new StringTokenizer(line," ");
				while(tokenizer.hasMoreTokens())
				{
					String term=tokenizer.nextToken();
					// Hashtable never stores null values, so null means "not seen yet".
					Integer count = word2Count.get(term);
					if(count == null)
					{
						word2Count.put(term, 1);
						word2Index.put(term, index++);
					}
					else
					{
						word2Count.put(term, count + 1);
					}
				}
			}
		}
		finally
		{
			// Close even if reading fails part-way.
			br.close();
		}
	}

	/** Returns the live term-to-index map (not a copy). */
	public HashMap<String,Integer> getWord2Index() 
	{
		return word2Index;
	}

	/** Prints every term that occurs more than 30 times, as "term,count". */
	public void print()
	{
		print(30);
	}

	/**
	 * Prints every term whose count is strictly greater than {@code minCount}, one
	 * "term,count" line each, in hash-table iteration order.
	 *
	 * @param minCount exclusive lower bound on the occurrence count
	 */
	public void print(int minCount)
	{
		for(Entry<String,Integer> item : word2Count.entrySet())
		{
			if(item.getValue()>minCount)
				System.out.printf("%s,%d\n",item.getKey(),item.getValue());
		}
	}

	/** Loads the author's sample corpus and prints its frequent terms. */
	public static void main(String[] args) throws IOException 
	{
		Dict dict = new Dict();
		dict.loadFromLocalFile("/media/linger/G/sources/comment/test/all");
		dict.print();

	}

}


import java.util.HashMap;
import java.util.StringTokenizer;

/**
 * Converts a space-separated document into a bag-of-words term-count vector.
 */
public class DocFeatureFactory 
{
	/** Term to vector index; terms absent from this map are ignored. */
	HashMap<String,Integer> word2Index;
	/**
	 * The vector most recently returned by {@link #getFeature}. It is not read internally and
	 * is kept only for compatibility with code that accesses the field.
	 */
	double[] feature;
	/** Vector length, fixed to the vocabulary size at construction time. */
	int dim;

	/**
	 * @param w2i term-to-index map; indices are expected to lie in {@code [0, w2i.size())}
	 */
	public DocFeatureFactory(HashMap<String,Integer> w2i)
	{
		word2Index = w2i;
		dim = w2i.size();
	}

	/**
	 * Counts how often each known term occurs in {@code doc}.
	 *
	 * @param doc document whose tokens are separated by spaces
	 * @return a fresh vector of length {@link #dim}; entry {@code i} is the count of the term
	 *         with index {@code i}. Out-of-vocabulary tokens are skipped.
	 */
	double[] getFeature(String doc)
	{
		double[] counts = new double[dim];
		StringTokenizer tokenizer=new StringTokenizer(doc," ");
		while(tokenizer.hasMoreTokens())
		{
			Integer index = word2Index.get(tokenizer.nextToken());
			if(index != null)
				counts[index]++;
		}
		feature = counts;
		return counts;
	}

	public static void main(String[] args) 
	{
	}

}



/**
 * Two-class nearest-centre classifier for "good" (positive) and "bad" (negative) documents.
 *
 * <p>{@link #train} averages the normalised training rows of each class into a centre vector.
 * {@link #predict} returns the difference between the squared Euclidean distances of a vector
 * to the good centre and to the bad centre: negative means closer to "good".
 */
public class KnnCoreModel 
{
	/** Mean of the normalised "good" training rows; null until train() succeeds. */
	double[] good_standard ;
	/** Mean of the normalised "bad" training rows; null until train() succeeds. */
	double[] bad_standard ;

	/**
	 * Computes the two class centres from the given training ranges.
	 *
	 * <p>Side effect: Norm.tableNorm2 (2-norm normalisation) is applied to the WHOLE of both
	 * tables, not just the training ranges. Its return value is unused, so it presumably
	 * normalises in place; confirm against Norm.
	 *
	 * @throws IllegalArgumentException if a table is null or a range is empty or out of bounds
	 */
	public void train(double[][] good,int train_good_start,int train_good_end,
					  double[][] bad,int train_bad_start,int train_bad_end) 
	{
		// Validate first so that a bad call leaves the tables untouched.
		checkTrainRange(good, train_good_start, train_good_end, "good");
		checkTrainRange(bad, train_bad_start, train_bad_end, "bad");

		Norm.tableNorm2(good);
		Norm.tableNorm2(bad);

		// Compute both centres before publishing either, so a failure cannot leave them mismatched.
		double[] goodCentre = centroid(good, train_good_start, train_good_end);
		double[] badCentre = centroid(bad, train_bad_start, train_bad_end);
		good_standard = goodCentre;
		bad_standard = badCentre;
	}

	/**
	 * Classifies rows {@code [start, end)} of each table and prints one "+:margin" (correct)
	 * or "-:margin" (wrong) line per row, followed by the error and accuracy rates.
	 *
	 * <p>A good row counts as correct unless its margin is &gt; 0; a bad row counts as correct
	 * only if its margin is &gt; 0. Both tables are passed through Norm.tableNorm2 first.
	 *
	 * @throws IllegalStateException if {@link #train} has not been called
	 */
	public void test(double[][] good,int test_good_start,int test_good_end,
	        double[][] bad,int test_bad_start,int test_bad_end)
	{	
		requireTrained();
		Norm.tableNorm2(good);
		Norm.tableNorm2(bad);
		// Good rows are evaluated (and printed) before bad rows.
		int error = countErrors(good, test_good_start, test_good_end, true);
		error += countErrors(bad, test_bad_start, test_bad_end, false);

		int count = (test_good_end-test_good_start+test_bad_end-test_bad_start);
		System.out.println("\nerror:"+error+",total:"+count);
		System.out.println("error rate:"+(double)error/count);
		System.out.println("acc rate:"+(double)(count-error)/count);
	}

	/**
	 * Returns {@code dist(fea, goodCentre) - dist(fea, badCentre)} using squared Euclidean
	 * distance. Negative values mean {@code fea} is closer to the good centre.
	 *
	 * @param fea feature vector, normalised like the training rows and of the same length
	 * @throws IllegalStateException    if {@link #train} has not been called
	 * @throws IllegalArgumentException if {@code fea} has the wrong length
	 */
	public double predict(double[] fea)
	{
		requireTrained();
		return margin(fea);
	}

	/** Signed distance difference used by predict() and test(); assumes a trained model. */
	private double margin(double[] fea)
	{
		return distance(fea,good_standard) - distance(fea,bad_standard);
	}

	/** Counts misclassified rows in {@code [start, end)} and prints each row's verdict. */
	private int countErrors(double[][] rows, int start, int end, boolean expectGood)
	{
		int errors = 0;
		for(int i=start;i<end;i++)
		{
			double dis = margin(rows[i]);
			// A positive margin means "bad"; anything else (0 or NaN) counts as "good".
			boolean correct = (dis > 0) != expectGood;
			if(!correct)
				errors++;
			System.out.println((correct ? "+:" : "-:")+dis);
		}
		return errors;
	}

	/** Returns the element-wise mean of rows {@code [start, end)}; the range must be non-empty. */
	private static double[] centroid(double[][] rows, int start, int end)
	{
		double[] centre = new double[rows[start].length];
		for(int i=start;i<end;i++)
		{
			double[] row = rows[i];
			for(int j=0;j<row.length;j++)
				centre[j] += row[j];
		}
		int n = end - start;
		for(int j=0;j<centre.length;j++)
			centre[j] /= n;
		return centre;
	}

	/** Rejects null tables and empty or out-of-bounds training ranges (these would yield NaN centres). */
	private static void checkTrainRange(double[][] rows, int start, int end, String label)
	{
		if(rows == null)
			throw new IllegalArgumentException(label + " table is null");
		if(start < 0 || end > rows.length || start >= end)
			throw new IllegalArgumentException(label + " training range [" + start + ", " + end
					+ ") is empty or outside [0, " + rows.length + ")");
	}

	/** Fails fast when the centres have not been computed yet. */
	private void requireTrained()
	{
		if(good_standard == null || bad_standard == null)
			throw new IllegalStateException("model is not trained; call train() first");
	}

	/**
	 * Squared Euclidean distance. The square root is skipped because it is monotonic on
	 * non-negative values, so the sign of a difference of distances is unchanged.
	 *
	 * @throws IllegalArgumentException if the vectors differ in length (previously this
	 *         silently returned 0, making mismatched vectors look identical)
	 */
	private double distance(double[] src,double[] dst)
	{
		if(src.length!=dst.length)
			throw new IllegalArgumentException("size not right! " + src.length + " vs " + dst.length);
		double sum=0;
		for(int i=0;i<src.length;i++)
		{
			double d = dst[i]-src[i];
			sum+=d*d;
		}
		return sum;
	}

}


本文作者:linger

本文链接:http://blog.csdn.net/lingerlanlan/article/details/38418277