首先是 Data 类（负责读取训练数据）：
import java.awt.print.Printable;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

/**
 * Loads a labelled training set from a whitespace-separated text file.
 *
 * <p>Every non-blank line holds the feature values followed by an integer class label, for
 * example {@code 3.542485<TAB>1.977398<TAB>-1}. Tabs and runs of spaces are both accepted as
 * separators.
 */
public class Data {
    /** Location of the training file read by {@link #getTrainData()}. */
    private static final String DEFAULT_PATH = "G://download//testSet.txt";

    /**
     * Reads the training set from the default location.
     *
     * @return feature vector to label, in file order
     * @throws UncheckedIOException if the file does not exist
     */
    public Map<List<Double>, Integer> getTrainData() {
        return getTrainData(DEFAULT_PATH);
    }

    /**
     * Reads the training set from {@code path}.
     *
     * <p>Keys are {@link ArrayList} instances (callers such as {@code SVM} may rely on this), and
     * iteration order follows the file. Because points are map keys, a point that appears twice
     * keeps only the label of its last occurrence.
     *
     * @param path file to read
     * @return feature vector to label, in file order
     * @throws UncheckedIOException if the file does not exist
     * @throws NumberFormatException if a value or label cannot be parsed
     */
    public Map<List<Double>, Integer> getTrainData(String path) {
        Map<List<Double>, Integer> data = new LinkedHashMap<List<Double>, Integer>();
        try (Scanner in = new Scanner(new File(path))) {
            while (in.hasNextLine()) {
                String line = in.nextLine().trim();
                if (line.isEmpty()) {
                    continue; // tolerate blank and trailing lines
                }
                String[] fields = line.split("\\s+");
                List<Double> point = new ArrayList<>();
                for (int i = 0; i < fields.length - 1; i++) {
                    point.add(Double.parseDouble(fields[i]));
                }
                data.put(point, Integer.parseInt(fields[fields.length - 1]));
            }
        } catch (FileNotFoundException e) {
            // Fail loudly: an empty training set only surfaces later as an obscure error.
            throw new UncheckedIOException("Training data file not found: " + path, e);
        }
        return data;
    }

    public static void main(String[] args) {
        Data data = new Data();
        data.getTrainData();
    }
}
然后是 SVM 类：
import java.awt.print.Printable; import java.io.FileNotFoundException; import java.io.ObjectInputStream.GetField; import java.io.PrintWriter; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Map.Entry; public class SVM { private List<ArrayList<Double>> trainData; private List<Integer> labelTrainData; private double sigma; private double C; private List<Double> alpha; private double b; private List<Double> E; private int N; private int dim; private double tol; private double eta; private double eps; private double eps2; public boolean satisfyKkt(int id) { double ypgx=this.labelTrainData.get(id)*getGx(this.trainData.get(id));//y*g(x) if(Math.abs(this.alpha.get(id))<=this.eps) { if(ypgx-1<-this.tol) return false; } else if(Math.abs(this.alpha.get(id)-this.C)<=this.eps) { if(ypgx-1>this.tol) return false; } else { if(Math.abs(ypgx-1)>this.tol) return false; } return true; } public void updateE() { for(int i=0;i<this.N;i++) { double Ei=getGx(this.trainData.get(i))-this.labelTrainData.get(i); this.E.set(i, Ei); } } public double kernelLinear(List<Double> X,List<Double> Y) { //linear kernel function int len=Y.size(); double s=0; for(int i=0;i<len;i++) s+=X.get(i)*Y.get(i); return s; } public double kernelRBF(List<Double> X,List<Double> Y) { //gauss kernel function int len=Y.size(); double s=0; for(int i=0;i<len;i++) s+=(X.get(i)-Y.get(i))*(X.get(i)-Y.get(i)); s=Math.exp(-s/(2*Math.pow(this.sigma, 2))); return s; } public double getGx(List<Double> X) { //calculate wx+b value double s=0; for(int i=0;i<this.N;i++) { //for debug double debug1=kernelRBF(X, this.trainData.get(i)); double debug2=this.alpha.get(i); s+=this.alpha.get(i)*this.labelTrainData.get(i)*kernelRBF(X, this.trainData.get(i)); } s+=this.b; return s; } public int update(int x1,int x2) { double low=0; double high=0; if(this.labelTrainData.get(x1)==this.labelTrainData.get(x2)) { low=Math.max(0, 
this.alpha.get(x1)+this.alpha.get(x2)-this.C); high=Math.min(this.C, this.alpha.get(x2)+this.alpha.get(x1)); } else { low=Math.max(0, this.alpha.get(x2)-this.alpha.get(x1)); high=Math.min(this.C, this.alpha.get(x2)-this.alpha.get(x1)+this.C); } double newAlpha2=this.alpha.get(x2)+this.labelTrainData.get(x2)*(this.E.get(x1)-this.E.get(x2))/this.eta; double newAlpha1=0; if(newAlpha2>high) newAlpha2=high; else if(newAlpha2<low) newAlpha2=low; newAlpha1=this.alpha.get(x1)+this.labelTrainData.get(x1)*this.labelTrainData.get(x2)*(this.alpha.get(x2)-newAlpha2); if(Math.abs(newAlpha1)<=this.eps) newAlpha1=0; if(Math.abs(newAlpha2)<=this.eps) newAlpha2=0; if(Math.abs(newAlpha1-this.C)<=this.eps) newAlpha1=this.C; if(Math.abs(newAlpha2-this.C)<=this.eps) newAlpha2=this.C; if(Math.abs(newAlpha1-this.alpha.get(x1))<=this.eps2) return 0; if(Math.abs(newAlpha2-this.alpha.get(x2))<=this.eps2) return 0; double b1=-this.E.get(x1)-this.labelTrainData.get(x1)*kernelRBF(this.trainData.get(x1), this.trainData.get(x1))*(newAlpha1-this.alpha.get(x1))-this.labelTrainData.get(x2)*kernelRBF(this.trainData.get(x2), this.trainData.get(x1))*(newAlpha2-this.alpha.get(x2))+this.b; double b2=-this.E.get(x2)-this.labelTrainData.get(x1)*kernelRBF(this.trainData.get(x1), this.trainData.get(x2))*(newAlpha1-this.alpha.get(x1))-this.labelTrainData.get(x2)*kernelRBF(this.trainData.get(x2), this.trainData.get(x2))*(newAlpha2-this.alpha.get(x2))+this.b; if(newAlpha1>0&&newAlpha1<this.C) this.b=b1; else if(newAlpha2>0&&newAlpha2<this.C) this.b=b2; else this.b=(b1+b2)/2; this.alpha.set(x1,newAlpha1); this.alpha.set(x2,newAlpha2); updateE(); return 1; } public int selectAlpha2(int x1) { int x2=-1; double maxDiff=-1; //first select x2 from 0<a<c to max(E(x1)-E(x2)) for(int i=0;i<this.N;++i) { if(Math.abs(this.alpha.get(i))<=this.eps||Math.abs(this.alpha.get(i)-this.C)<=this.eps) continue; double diff=Math.abs(this.E.get(x1)-this.E.get(i)); if(diff>maxDiff) { maxDiff=diff; x2=i; } } //second calculate eta 
(eta!=0) if(x2!=-1) { this.eta=kernelRBF(this.trainData.get(x1), this.trainData.get(x1))+kernelRBF(this.trainData.get(x2), this.trainData.get(x2))-2*kernelRBF(this.trainData.get(x1), this.trainData.get(x2)); if(eta!=0) return x2; } //third if cannot find in the whole train set for(int i=0;i<this.N;i++) { if(i==x1) continue; this.eta=kernelRBF(this.trainData.get(x1), this.trainData.get(x1))+kernelRBF(this.trainData.get(i), this.trainData.get(i))-2*kernelRBF(this.trainData.get(x1), this.trainData.get(i)); if(Math.abs(this.eta)>this.eps) return i; } return -1; } public void SMO() { //to solve alpha int numChanged=0; int cnt=0; while(true) { cnt++; System.out.println(cnt); numChanged=0; for(int x1=0;x1<this.N;++x1) { if(Math.abs(this.alpha.get(x1))<=this.eps||Math.abs(this.alpha.get(x1)-this.C)<=this.eps) continue; if(!satisfyKkt(x1)) { int x2=selectAlpha2(x1); if(x2==-1) continue; numChanged+=update(x1, x2); } } if(numChanged==0) { for(int x1=0;x1<this.N;++x1) { if(!satisfyKkt(x1)) { int x2=selectAlpha2(x1); if(x2==-1) continue; update(x1, x2); numChanged++; } } } if(numChanged==0) break; } } public SVM() { //load train data Data data=new Data(); Map<List<Double>, Integer> Datas=data.getTrainData(); int totalData=Datas.size(); this.trainData=new ArrayList<ArrayList<Double>>(); this.labelTrainData=new ArrayList<Integer>(); this.alpha=new ArrayList<Double>(); this.E=new ArrayList<Double>(); int i=0; for(Map.Entry<List<Double>, Integer> entry: Datas.entrySet()) { this.trainData.add((ArrayList<Double>) entry.getKey()); this.labelTrainData.add(entry.getValue()); this.alpha.add(0.0); this.E.add(0.0-this.labelTrainData.get(i)); i++; } this.N=this.labelTrainData.size(); this.dim=this.trainData.get(0).size(); this.sigma=12;//sigma=1 this.C=0.5;//c=6 this.b=0.0; this.tol=0.001; this.eta=0; this.eps=0.0000001; this.eps2=0.00001; } public double getB() { //get b value return this.b; } public double[] getLinearW() { double []w=new double[this.N]; for(int i=0;i<this.N;i++) { 
for(int j=0;j<this.dim;j++) { w[j]+=this.alpha.get(i)*this.labelTrainData.get(i)*this.trainData.get(i).get(j); } } return w; } public int predict(List<Double> x) { int ans=1; double sum=0; for(int i=0;i<this.N;i++) { sum+=this.alpha.get(i)*this.labelTrainData.get(i)*kernelRBF(x, this.trainData.get(i)); } sum+=b; if(sum>0) ans=1; else ans=-1; return ans; } public static void main(String[] args) throws FileNotFoundException { SVM s=new SVM(); s.SMO(); PrintWriter out=new PrintWriter("g://download//resultpoints.txt"); for(int i=0;i<s.N;i++) { out.write((s.trainData.get(i).get(0)).toString()); out.write(" "); out.write((s.trainData.get(i).get(1)).toString()); out.write(" "); out.write(Integer.toString(s.predict(s.trainData.get(i)))); out.write(" "); } out.close(); //if is linear kernel ,we can get w,just like wx+b=0,then we can directly get line fuction double w[]=s.getLinearW(); System.out.println(w[0]+" "+w[1]+" "+s.b+"======"); } }
用线性核函数实现的 SVM 得到的分类结果：
画图使用的是下面的 Python 代码：
from numpy import *
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Plot the points classified by the Java SVM together with the linear decision
# boundary W1*x1 + W2*x2 + B = 0.
# Input: whitespace-separated "x1 x2 label" triples (label 1 or -1). Tokens are
# grouped in threes, so the file may hold one triple per line or all triples on
# a single line, and blank lines are harmless.
# NOTE(review): SVM.main writes g://download//resultpoints.txt; confirm this
# path refers to the same output file.
with open("g://download/myresult.txt") as f1:
    tokens = f1.read().split()

plt.figure(figsize=(8, 5), dpi=80)
axes = plt.subplot(111)

type1_x = []  # points labelled 1
type1_y = []
type2_x = []  # points with any other label (-1)
type2_y = []
for k in range(0, len(tokens) - 2, 3):
    x1 = float(tokens[k])
    x2 = float(tokens[k + 1])
    x3 = int(tokens[k + 2])
    if x3 == 1:
        type1_x.append(x1)
        type1_y.append(x2)
    else:
        type2_x.append(x1)
        type2_y.append(x2)

type1 = axes.scatter(type1_x, type1_y, s=40, c='red')
type2 = axes.scatter(type2_x, type2_y, s=40, c='green')

# Linear-kernel model parameters, copied from the Java program's output
# ("0.8148005405344305 -0.27263471796762484 -3.8392586254518437").
W1 = 0.8148005405344305
W2 = -0.27263471796762484
B = -3.8392586254518437

# Solve W1*x + W2*y + B = 0 for y (requires W2 != 0).
x = np.linspace(-4, 10, 200)
y = (-W1 / W2) * x + (-B / W2)
axes.plot(x, y, 'b', lw=3)

plt.xlabel('x1')
plt.ylabel('x2')
# Legend entries match the labels actually plotted: red = 1, green = -1.
axes.legend((type1, type2), ('1', '-1'), loc=1)
plt.show()
使用高斯核，C=6、sigma=1 时的分类结果：
使用高斯核，C=0.5、sigma=1 时的分类结果：
使用高斯核，C=0.5、sigma=12 时的分类结果：
对比可以看出，C 和 sigma 的取值对高斯核 SVM 的分类结果影响很大。
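其中 sigma 是高斯核函数 $K(x,y)=\exp\left(-\frac{\lVert x-y\rVert^2}{2\sigma^2}\right)$ 的宽度参数；C 是软间隔的惩罚系数，即每个 $\alpha_i$ 的上界（$0 \le \alpha_i \le C$）。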