首页 > 代码库 > 自己实现的SVM源码
自己实现的SVM源码
首先是Data类:
import java.awt.print.Printable;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

/**
 * Loads labelled training samples from a tab-separated text file.
 *
 * <p>Each non-blank line has the form {@code f1<TAB>f2<TAB>...<TAB>label}: every
 * column except the last is parsed as a {@code double} feature and the last
 * column is parsed as an {@code int} class label.
 */
public class Data {

    /** File read by {@link #getTrainData()}. */
    private static final String DEFAULT_TRAIN_FILE = "G://download//testSet.txt";

    /**
     * Reads the training set from the default file.
     *
     * @return feature vector to label, in file order; empty if the file is missing
     */
    public Map<List<Double>, Integer> getTrainData() {
        return getTrainData(DEFAULT_TRAIN_FILE);
    }

    /**
     * Reads a training set from {@code path}.
     *
     * <p>Blank lines are skipped. Because the result is keyed by the feature
     * vector, two lines with identical features collapse into one entry (the
     * later label wins). A missing file is reported on {@code System.err} and
     * yields an empty map rather than an exception.
     *
     * @param path location of the tab-separated sample file
     * @return feature vector to label, in file order; never {@code null}
     * @throws NumberFormatException if a column cannot be parsed as a number
     */
    public Map<List<Double>, Integer> getTrainData(String path) {
        // LinkedHashMap keeps the samples in file order so runs are reproducible.
        Map<List<Double>, Integer> data = new LinkedHashMap<>();
        // try-with-resources guarantees the Scanner (and the file handle) is closed.
        try (Scanner in = new Scanner(new File(path), "UTF-8")) {
            while (in.hasNextLine()) {
                String str = in.nextLine().trim();
                if (str.isEmpty()) {
                    // A blank (e.g. trailing) line would otherwise fail in parseInt("").
                    continue;
                }
                String[] strs = str.split("\t");
                List<Double> pointTmp = new ArrayList<>(strs.length - 1);
                for (int i = 0; i < strs.length - 1; i++) {
                    pointTmp.add(Double.parseDouble(strs[i]));
                }
                data.put(pointTmp, Integer.parseInt(strs[strs.length - 1].trim()));
            }
        } catch (FileNotFoundException e) {
            // Best effort, as before: report the problem and return what we have.
            System.err.println("Training file not found: " + path + " (" + e.getMessage() + ")");
        }
        return data;
    }

    /**
     * Small manual check: loads the default file and prints the sample count.
     *
     * <p>NOTE(review): the original body of this method was truncated in the
     * published listing; this is a minimal stand-in, not a recovery of it.
     */
    public static void main(String[] args) {
        Data data = new Data();
        System.out.println(data.getTrainData().size());
    }
}
SVM类:
import java.awt.print.Printable;import java.io.FileNotFoundException;import java.io.ObjectInputStream.GetField;import java.io.PrintWriter;import java.util.ArrayList;import java.util.Iterator;import java.util.List;import java.util.Map;import java.util.Random;import java.util.Map.Entry;public class SVM { private List<ArrayList<Double>> trainData; private List<Integer> labelTrainData; private double sigma; private double C; private List<Double> alpha; private double b; private List<Double> E; private int N; private int dim; private double tol; private double eta; private double eps; private double eps2; public boolean satisfyKkt(int id) { double ypgx=this.labelTrainData.get(id)*getGx(this.trainData.get(id));//y*g(x) if(Math.abs(this.alpha.get(id))<=this.eps) { if(ypgx-1<-this.tol) return false; } else if(Math.abs(this.alpha.get(id)-this.C)<=this.eps) { if(ypgx-1>this.tol) return false; } else { if(Math.abs(ypgx-1)>this.tol) return false; } return true; } public void updateE() { for(int i=0;i<this.N;i++) { double Ei=getGx(this.trainData.get(i))-this.labelTrainData.get(i); this.E.set(i, Ei); } } public double kernelLinear(List<Double> X,List<Double> Y) { //linear kernel function int len=Y.size(); double s=0; for(int i=0;i<len;i++) s+=X.get(i)*Y.get(i); return s; } public double kernelRBF(List<Double> X,List<Double> Y) { //gauss kernel function int len=Y.size(); double s=0; for(int i=0;i<len;i++) s+=(X.get(i)-Y.get(i))*(X.get(i)-Y.get(i)); s=Math.exp(-s/(2*Math.pow(this.sigma, 2))); return s; } public double getGx(List<Double> X) { //calculate wx+b value double s=0; for(int i=0;i<this.N;i++) { //for debug double debug1=kernelRBF(X, this.trainData.get(i)); double debug2=this.alpha.get(i); s+=this.alpha.get(i)*this.labelTrainData.get(i)*kernelRBF(X, this.trainData.get(i)); } s+=this.b; return s; } public int update(int x1,int x2) { double low=0; double high=0; if(this.labelTrainData.get(x1)==this.labelTrainData.get(x2)) { low=Math.max(0, 
this.alpha.get(x1)+this.alpha.get(x2)-this.C); high=Math.min(this.C, this.alpha.get(x2)+this.alpha.get(x1)); } else { low=Math.max(0, this.alpha.get(x2)-this.alpha.get(x1)); high=Math.min(this.C, this.alpha.get(x2)-this.alpha.get(x1)+this.C); } double newAlpha2=this.alpha.get(x2)+this.labelTrainData.get(x2)*(this.E.get(x1)-this.E.get(x2))/this.eta; double newAlpha1=0; if(newAlpha2>high) newAlpha2=high; else if(newAlpha2<low) newAlpha2=low; newAlpha1=this.alpha.get(x1)+this.labelTrainData.get(x1)*this.labelTrainData.get(x2)*(this.alpha.get(x2)-newAlpha2); if(Math.abs(newAlpha1)<=this.eps) newAlpha1=0; if(Math.abs(newAlpha2)<=this.eps) newAlpha2=0; if(Math.abs(newAlpha1-this.C)<=this.eps) newAlpha1=this.C; if(Math.abs(newAlpha2-this.C)<=this.eps) newAlpha2=this.C; if(Math.abs(newAlpha1-this.alpha.get(x1))<=this.eps2) return 0; if(Math.abs(newAlpha2-this.alpha.get(x2))<=this.eps2) return 0; double b1=-this.E.get(x1)-this.labelTrainData.get(x1)*kernelRBF(this.trainData.get(x1), this.trainData.get(x1))*(newAlpha1-this.alpha.get(x1))-this.labelTrainData.get(x2)*kernelRBF(this.trainData.get(x2), this.trainData.get(x1))*(newAlpha2-this.alpha.get(x2))+this.b; double b2=-this.E.get(x2)-this.labelTrainData.get(x1)*kernelRBF(this.trainData.get(x1), this.trainData.get(x2))*(newAlpha1-this.alpha.get(x1))-this.labelTrainData.get(x2)*kernelRBF(this.trainData.get(x2), this.trainData.get(x2))*(newAlpha2-this.alpha.get(x2))+this.b; if(newAlpha1>0&&newAlpha1<this.C) this.b=b1; else if(newAlpha2>0&&newAlpha2<this.C) this.b=b2; else this.b=(b1+b2)/2; this.alpha.set(x1,newAlpha1); this.alpha.set(x2,newAlpha2); updateE(); return 1; } public int selectAlpha2(int x1) { int x2=-1; double maxDiff=-1; //first select x2 from 0<a<c to max(E(x1)-E(x2)) for(int i=0;i<this.N;++i) { if(Math.abs(this.alpha.get(i))<=this.eps||Math.abs(this.alpha.get(i)-this.C)<=this.eps) continue; double diff=Math.abs(this.E.get(x1)-this.E.get(i)); if(diff>maxDiff) { maxDiff=diff; x2=i; } } //second calculate eta 
(eta!=0) if(x2!=-1) { this.eta=kernelRBF(this.trainData.get(x1), this.trainData.get(x1))+kernelRBF(this.trainData.get(x2), this.trainData.get(x2))-2*kernelRBF(this.trainData.get(x1), this.trainData.get(x2)); if(eta!=0) return x2; } //third if cannot find in the whole train set for(int i=0;i<this.N;i++) { if(i==x1) continue; this.eta=kernelRBF(this.trainData.get(x1), this.trainData.get(x1))+kernelRBF(this.trainData.get(i), this.trainData.get(i))-2*kernelRBF(this.trainData.get(x1), this.trainData.get(i)); if(Math.abs(this.eta)>this.eps) return i; } return -1; } public void SMO() { //to solve alpha int numChanged=0; int cnt=0; while(true) { cnt++; System.out.println(cnt); numChanged=0; for(int x1=0;x1<this.N;++x1) { if(Math.abs(this.alpha.get(x1))<=this.eps||Math.abs(this.alpha.get(x1)-this.C)<=this.eps) continue; if(!satisfyKkt(x1)) { int x2=selectAlpha2(x1); if(x2==-1) continue; numChanged+=update(x1, x2); } } if(numChanged==0) { for(int x1=0;x1<this.N;++x1) { if(!satisfyKkt(x1)) { int x2=selectAlpha2(x1); if(x2==-1) continue; update(x1, x2); numChanged++; } } } if(numChanged==0) break; } } public SVM() { //load train data Data data=http://www.mamicode.com/new Data();"g://download//resultpoints.txt"); for(int i=0;i<s.N;i++) { out.write((s.trainData.get(i).get(0)).toString()); out.write("\t"); out.write((s.trainData.get(i).get(1)).toString()); out.write("\t"); out.write(Integer.toString(s.predict(s.trainData.get(i)))); out.write("\n"); } out.close(); //if is linear kernel ,we can get w,just like wx+b=0,then we can directly get line fuction double w[]=s.getLinearW(); System.out.println(w[0]+" "+w[1]+" "+s.b+"======"); }}
用线性核函数实现的SVM得到的分类结果
画图,是用python代码
# Plot the classified points written by the SVM program together with the
# separating line W1*x1 + W2*x2 + B = 0 of the linear-kernel run.
from numpy import *
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Each line of the result file is "x1<TAB>x2<TAB>label".
with open("g://download/myresult.txt") as f1:
    data = f1.readlines()

plt.figure(figsize=(8, 5), dpi=80)
axes = plt.subplot(111)

type1_x = []
type1_y = []
type2_x = []
type2_y = []

for line in data:
    line = line.strip()
    if not line:
        # A blank (e.g. trailing) line would otherwise fail in float('').
        continue
    x = line.split('\t')
    x1 = float(x[0])
    x2 = float(x[1])
    x3 = int(x[2])
    # Samples labelled 1 go to the first series, everything else to the second.
    if x3 == 1:
        type1_x.append(x1)
        type1_y.append(x2)
    else:
        type2_x.append(x1)
        type2_y.append(x2)

type1 = axes.scatter(type1_x, type1_y, s=40, c='red')
type2 = axes.scatter(type2_x, type2_y, s=40, c='green')

# Weights and bias copied by hand from the program's printed output (see the
# last line of this script).
W1 = 0.8148005405344305
W2 = -0.27263471796762484
B = -3.8392586254518437

# Solve W1*x + W2*y + B = 0 for y to draw the decision boundary.
x = np.linspace(-4, 10, 200)
y = (-W1 / W2) * x + (-B / W2)
axes.plot(x, y, 'b', lw=3)

plt.xlabel('x1')
plt.ylabel('x2')
# NOTE(review): the legend texts are kept as published; the first series holds
# the samples whose label is 1 — confirm the intended naming.
axes.legend((type1, type2), ('0', '1'), loc=1)
plt.show()
# 0.8148005405344305 -0.27263471796762484 -3.8392586254518437
用高斯核,当C=6,sigma=1时候
用高斯核,当C=0.5,sigma=1时候
用高斯核,当C=0.5,sigma=12时候
说明C的大小和sigma的大小对高斯核影响是很大的
自己实现的SVM源码
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。