首页 > 代码库 > 自己实现的SVM源码

自己实现的SVM源码

首先是Data类:

import java.awt.print.Printable;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

/**
 * Loads the tab-separated training set consumed by the SVM.
 *
 * <p>Each non-blank line holds the feature values followed by an integer
 * class label, separated by tab characters. Because the result is a map keyed
 * by the feature vector, two lines with identical features collapse into one
 * entry (the later label wins).
 */
public class Data {

    /** File read by the no-argument {@link #getTrainData()}. */
    private static final String DEFAULT_TRAIN_FILE = "G://download//testSet.txt";

    /**
     * Reads the training set from the default location.
     *
     * @return feature vector to label; empty when the file cannot be opened
     */
    public Map<List<Double>, Integer> getTrainData() {
        return getTrainData(DEFAULT_TRAIN_FILE);
    }

    /**
     * Reads a training set from {@code path}.
     *
     * @param path location of the tab-separated training file
     * @return feature vector to label; empty when the file cannot be opened
     * @throws NumberFormatException if a field is not a valid number
     */
    public Map<List<Double>, Integer> getTrainData(String path) {
        Map<List<Double>, Integer> data = new HashMap<List<Double>, Integer>();
        // try-with-resources: the Scanner (and the file handle behind it) is always released.
        try (Scanner in = new Scanner(new File(path))) {
            while (in.hasNextLine()) {
                String line = in.nextLine().trim();
                if (line.isEmpty()) {
                    // A trailing or separating blank line carries no sample.
                    continue;
                }
                String[] fields = line.split("\t");
                int labelIndex = fields.length - 1;
                List<Double> point = new ArrayList<>();
                for (int i = 0; i < labelIndex; i++) {
                    point.add(Double.parseDouble(fields[i]));
                }
                data.put(point, Integer.parseInt(fields[labelIndex].trim()));
            }
        } catch (FileNotFoundException e) {
            // Best effort, as in the original: report the problem and return what we have.
            e.printStackTrace();
        }
        return data;
    }

    public static void main(String[] args) {
        Data data = new Data();
        // NOTE(review): everything after `new Data()` was lost when the page was
        // extracted; this smoke check (print the sample count) is a reconstruction.
        System.out.println(data.getTrainData().size());
    }
}

  SVM类:

import java.awt.print.Printable;import java.io.FileNotFoundException;import java.io.ObjectInputStream.GetField;import java.io.PrintWriter;import java.util.ArrayList;import java.util.Iterator;import java.util.List;import java.util.Map;import java.util.Random;import java.util.Map.Entry;public class SVM {	private List<ArrayList<Double>> trainData;	private List<Integer> labelTrainData;	private double sigma;	private double C;	private List<Double> alpha;	private double b;	private List<Double> E;	private int N;	private int dim;	private double tol;	private double eta;	private double eps;	private double eps2;		public boolean satisfyKkt(int id)	{		double ypgx=this.labelTrainData.get(id)*getGx(this.trainData.get(id));//y*g(x)		if(Math.abs(this.alpha.get(id))<=this.eps)		{			if(ypgx-1<-this.tol) return false;		}		else if(Math.abs(this.alpha.get(id)-this.C)<=this.eps)		{			if(ypgx-1>this.tol) return false;		}		else {			if(Math.abs(ypgx-1)>this.tol) return false;		}		return true;	}		public void updateE() {				for(int i=0;i<this.N;i++)		{			double Ei=getGx(this.trainData.get(i))-this.labelTrainData.get(i);			this.E.set(i, Ei);		}	}		public double kernelLinear(List<Double> X,List<Double> Y) {		//linear kernel function		int len=Y.size();		double s=0;		for(int i=0;i<len;i++)			s+=X.get(i)*Y.get(i);		return s;	}				public double kernelRBF(List<Double> X,List<Double> Y)	{		//gauss kernel function				int len=Y.size();		double s=0;		for(int i=0;i<len;i++)			s+=(X.get(i)-Y.get(i))*(X.get(i)-Y.get(i));		s=Math.exp(-s/(2*Math.pow(this.sigma, 2)));		return s;	}			public double getGx(List<Double> X)	{		//calculate wx+b value		double s=0;		for(int i=0;i<this.N;i++)		{			//for debug			double debug1=kernelRBF(X, this.trainData.get(i));			double debug2=this.alpha.get(i);						s+=this.alpha.get(i)*this.labelTrainData.get(i)*kernelRBF(X, this.trainData.get(i));		}		s+=this.b;		return s;	}		public int update(int x1,int x2)	{		double low=0;		double high=0;		
if(this.labelTrainData.get(x1)==this.labelTrainData.get(x2))		{			low=Math.max(0, this.alpha.get(x1)+this.alpha.get(x2)-this.C);			high=Math.min(this.C, this.alpha.get(x2)+this.alpha.get(x1));		}		else		{			low=Math.max(0, this.alpha.get(x2)-this.alpha.get(x1));			high=Math.min(this.C, this.alpha.get(x2)-this.alpha.get(x1)+this.C);		}		double newAlpha2=this.alpha.get(x2)+this.labelTrainData.get(x2)*(this.E.get(x1)-this.E.get(x2))/this.eta;		double newAlpha1=0;				if(newAlpha2>high) newAlpha2=high;		else if(newAlpha2<low) newAlpha2=low;		newAlpha1=this.alpha.get(x1)+this.labelTrainData.get(x1)*this.labelTrainData.get(x2)*(this.alpha.get(x2)-newAlpha2);				if(Math.abs(newAlpha1)<=this.eps)			newAlpha1=0;		if(Math.abs(newAlpha2)<=this.eps)			newAlpha2=0;		if(Math.abs(newAlpha1-this.C)<=this.eps)			newAlpha1=this.C;		if(Math.abs(newAlpha2-this.C)<=this.eps)			newAlpha2=this.C;		if(Math.abs(newAlpha1-this.alpha.get(x1))<=this.eps2)			return 0;		if(Math.abs(newAlpha2-this.alpha.get(x2))<=this.eps2)			return 0;				double b1=-this.E.get(x1)-this.labelTrainData.get(x1)*kernelRBF(this.trainData.get(x1), this.trainData.get(x1))*(newAlpha1-this.alpha.get(x1))-this.labelTrainData.get(x2)*kernelRBF(this.trainData.get(x2), this.trainData.get(x1))*(newAlpha2-this.alpha.get(x2))+this.b;		double b2=-this.E.get(x2)-this.labelTrainData.get(x1)*kernelRBF(this.trainData.get(x1), this.trainData.get(x2))*(newAlpha1-this.alpha.get(x1))-this.labelTrainData.get(x2)*kernelRBF(this.trainData.get(x2), this.trainData.get(x2))*(newAlpha2-this.alpha.get(x2))+this.b;				if(newAlpha1>0&&newAlpha1<this.C)			this.b=b1;		else if(newAlpha2>0&&newAlpha2<this.C)			this.b=b2;		else			this.b=(b1+b2)/2;				this.alpha.set(x1,newAlpha1);		this.alpha.set(x2,newAlpha2);		updateE();		return 1;	}	public int selectAlpha2(int x1) {				int x2=-1;		double maxDiff=-1;		//first select x2 from 0<a<c to max(E(x1)-E(x2))				for(int i=0;i<this.N;++i)		{			
if(Math.abs(this.alpha.get(i))<=this.eps||Math.abs(this.alpha.get(i)-this.C)<=this.eps) continue;			double diff=Math.abs(this.E.get(x1)-this.E.get(i));			if(diff>maxDiff)			{				maxDiff=diff;				x2=i;			}		}				//second calculate eta (eta!=0)		if(x2!=-1)		{			this.eta=kernelRBF(this.trainData.get(x1), this.trainData.get(x1))+kernelRBF(this.trainData.get(x2), this.trainData.get(x2))-2*kernelRBF(this.trainData.get(x1), this.trainData.get(x2));			if(eta!=0) return x2;		}				//third if cannot find in the whole train set		for(int i=0;i<this.N;i++)		{			if(i==x1) continue;			this.eta=kernelRBF(this.trainData.get(x1), this.trainData.get(x1))+kernelRBF(this.trainData.get(i), this.trainData.get(i))-2*kernelRBF(this.trainData.get(x1), this.trainData.get(i));			if(Math.abs(this.eta)>this.eps) return i;		}		return -1;					}		public void SMO() {		//to solve alpha		int numChanged=0;		int cnt=0;		while(true)		{			cnt++;			System.out.println(cnt);						numChanged=0;			for(int x1=0;x1<this.N;++x1)			{				if(Math.abs(this.alpha.get(x1))<=this.eps||Math.abs(this.alpha.get(x1)-this.C)<=this.eps) continue;				if(!satisfyKkt(x1))				{					int x2=selectAlpha2(x1);					if(x2==-1) continue;					numChanged+=update(x1, x2);				}			}			if(numChanged==0)			{				for(int x1=0;x1<this.N;++x1)				{					if(!satisfyKkt(x1))					{						int x2=selectAlpha2(x1);						if(x2==-1) continue;						update(x1, x2);						numChanged++;					}				}			}			if(numChanged==0)				break;						}	}		public SVM() {		//load train data				Data data=http://www.mamicode.com/new Data();"g://download//resultpoints.txt");		for(int i=0;i<s.N;i++)		{			out.write((s.trainData.get(i).get(0)).toString());			out.write("\t");			out.write((s.trainData.get(i).get(1)).toString());			out.write("\t");			out.write(Integer.toString(s.predict(s.trainData.get(i))));			out.write("\n");		}		out.close();		//if is linear kernel ,we can get w,just like wx+b=0,then we can directly get line fuction		double w[]=s.getLinearW();		
System.out.println(w[0]+" "+w[1]+" "+s.b+"======");	}}

  

用线性核函数实现的SVM得到的分类结果:

技术分享

 画图用的是Python代码:

# Plot the classified points written by the Java SVM together with the
# linear decision boundary W1*x1 + W2*x2 + B = 0.
from numpy import *
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Each non-blank line of the result file is "x1<TAB>x2<TAB>label".
with open("g://download/myresult.txt") as f1:
    data = f1.readlines()

plt.figure(figsize=(8, 5), dpi=80)
axes = plt.subplot(111)

type1_x = []
type1_y = []
type2_x = []
type2_y = []

for line in data:
    if not line.strip():
        # A trailing blank line would otherwise crash float('').
        continue
    x = line.strip().split('\t')
    x1 = float(x[0])
    x2 = float(x[1])
    x3 = int(x[2])

    # Points labelled 1 go to the first series, everything else to the second.
    if x3 == 1:
        type1_x.append(x1)
        type1_y.append(x2)
    else:
        type2_x.append(x1)
        type2_y.append(x2)

type1 = axes.scatter(type1_x, type1_y, s=40, c='red')
type2 = axes.scatter(type2_x, type2_y, s=40, c='green')

# Weights and bias copied from the Java program's console output.
W1 = 0.8148005405344305
W2 = -0.27263471796762484
B = -3.8392586254518437
x = np.linspace(-4, 10, 200)
# Solve W1*x + W2*y + B = 0 for y.
y = (-W1 / W2) * x + (-B / W2)
axes.plot(x, y, 'b', lw=3)

plt.xlabel('x1')
plt.ylabel('x2')

# NOTE(review): the series holding label 1 is captioned '0' here; confirm
# that this legend text is what was intended.
axes.legend((type1, type2), ('0', '1'), loc=1)
plt.show()
#0.8148005405344305 -0.27263471796762484 -3.8392586254518437

  用高斯核,当C=6,sigma=1时候

技术分享

高斯核,当C=0.5,sigma=1时候

技术分享

 

当C=0.5,sigma=12时候

技术分享

 

 

说明C和sigma的取值对高斯核SVM的分类结果影响很大

自己实现的SVM源码