zoukankan      html  css  js  c++  java
  • 自己实现的SVM源码

    首先是Data类（负责读取训练数据）：

    import java.awt.print.Printable;
    import java.io.File;
    import java.io.FileNotFoundException;
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.Scanner;
    
    public class Data {
    	/** Training file read by {@link #getTrainData()}. */
    	private static final String DEFAULT_TRAIN_FILE = "G://download//testSet.txt";
    
    	/**
    	 * Loads the training set from the default file ({@code G://download//testSet.txt}).
    	 *
    	 * @return map from feature vector to integer label; empty if the file is missing
    	 */
    	public Map<List<Double>, Integer> getTrainData() {
    		return getTrainData(DEFAULT_TRAIN_FILE);
    	}
    
    	/**
    	 * Loads a tab-separated training set. Every non-blank line holds the feature
    	 * values followed by an integer label in the last column.
    	 *
    	 * <p>Samples are keyed by their feature vector, so a point that appears more
    	 * than once keeps only the label of its last occurrence.
    	 *
    	 * @param path file to read
    	 * @return map from feature vector to label; empty (after printing the stack
    	 *         trace) if the file does not exist
    	 * @throws NumberFormatException if a field is not numeric
    	 */
    	public Map<List<Double>, Integer> getTrainData(String path) {
    		Map<List<Double>, Integer> data = new HashMap<List<Double>, Integer>();
    		// try-with-resources closes the Scanner (and its file handle) on every path.
    		try (Scanner in = new Scanner(new File(path))) {
    			while (in.hasNextLine()) {
    				String line = in.nextLine().trim();
    				if (line.isEmpty()) {
    					continue; // e.g. a trailing newline at the end of the file
    				}
    				String[] fields = line.split("\t");
    				List<Double> point = new ArrayList<>();
    				for (int i = 0; i < fields.length - 1; i++) {
    					point.add(Double.parseDouble(fields[i]));
    				}
    				data.put(point, Integer.parseInt(fields[fields.length - 1]));
    			}
    		} catch (FileNotFoundException e) {
    			// Best effort, as before: report the problem and return whatever was read.
    			e.printStackTrace();
    		}
    		return data;
    	}
    
    	public static void main(String[] args) {
    		Data data = new Data();
    		data.getTrainData();
    	}
    }
    

      然后是SVM类：

    import java.awt.print.Printable;
    import java.io.FileNotFoundException;
    import java.io.ObjectInputStream.GetField;
    import java.io.PrintWriter;
    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.List;
    import java.util.Map;
    import java.util.Random;
    import java.util.Map.Entry;
    
    public class SVM {
    	private List<ArrayList<Double>> trainData;
    	private List<Integer> labelTrainData;
    	private double sigma;
    	private double C;
    	private List<Double> alpha;
    	private double b;
    	private List<Double> E;
    	private int N;
    	private int dim;
    	private double tol;
    	private double eta;
    	private double eps;
    	private double eps2;
    	
    	public boolean satisfyKkt(int id)
    	{
    		double ypgx=this.labelTrainData.get(id)*getGx(this.trainData.get(id));//y*g(x)
    		if(Math.abs(this.alpha.get(id))<=this.eps)
    		{
    			if(ypgx-1<-this.tol) return false;
    		}
    		else if(Math.abs(this.alpha.get(id)-this.C)<=this.eps)
    		{
    			if(ypgx-1>this.tol) return false;
    		}
    		else {
    			if(Math.abs(ypgx-1)>this.tol) return false;
    		}
    		return true;
    	}
    	
    	public void updateE() {
    		
    		for(int i=0;i<this.N;i++)
    		{
    			double Ei=getGx(this.trainData.get(i))-this.labelTrainData.get(i);
    			this.E.set(i, Ei);
    		}
    	}
    	
    	public double kernelLinear(List<Double> X,List<Double> Y) {
    		//linear kernel function
    		int len=Y.size();
    		double s=0;
    		for(int i=0;i<len;i++)
    			s+=X.get(i)*Y.get(i);
    		return s;
    	}
    	
    	
    	
    	public double kernelRBF(List<Double> X,List<Double> Y)
    	{
    		//gauss kernel function
    		
    		int len=Y.size();
    		double s=0;
    		for(int i=0;i<len;i++)
    			s+=(X.get(i)-Y.get(i))*(X.get(i)-Y.get(i));
    		s=Math.exp(-s/(2*Math.pow(this.sigma, 2)));
    		return s;
    	}
    	
    	
    	public double getGx(List<Double> X)
    	{
    		//calculate wx+b value
    		double s=0;
    		for(int i=0;i<this.N;i++)
    		{
    			//for debug
    			double debug1=kernelRBF(X, this.trainData.get(i));
    			double debug2=this.alpha.get(i);
    			
    			s+=this.alpha.get(i)*this.labelTrainData.get(i)*kernelRBF(X, this.trainData.get(i));
    		}
    		s+=this.b;
    		return s;
    	}
    	
    	public int update(int x1,int x2)
    	{
    		double low=0;
    		double high=0;
    		if(this.labelTrainData.get(x1)==this.labelTrainData.get(x2))
    		{
    			low=Math.max(0, this.alpha.get(x1)+this.alpha.get(x2)-this.C);
    			high=Math.min(this.C, this.alpha.get(x2)+this.alpha.get(x1));
    		}
    		else
    		{
    			low=Math.max(0, this.alpha.get(x2)-this.alpha.get(x1));
    			high=Math.min(this.C, this.alpha.get(x2)-this.alpha.get(x1)+this.C);
    		}
    		double newAlpha2=this.alpha.get(x2)+this.labelTrainData.get(x2)*(this.E.get(x1)-this.E.get(x2))/this.eta;
    		double newAlpha1=0;
    		
    		if(newAlpha2>high) newAlpha2=high;
    		else if(newAlpha2<low) newAlpha2=low;
    		newAlpha1=this.alpha.get(x1)+this.labelTrainData.get(x1)*this.labelTrainData.get(x2)*(this.alpha.get(x2)-newAlpha2);
    		
    		if(Math.abs(newAlpha1)<=this.eps)
    			newAlpha1=0;
    		if(Math.abs(newAlpha2)<=this.eps)
    			newAlpha2=0;
    		if(Math.abs(newAlpha1-this.C)<=this.eps)
    			newAlpha1=this.C;
    		if(Math.abs(newAlpha2-this.C)<=this.eps)
    			newAlpha2=this.C;
    		if(Math.abs(newAlpha1-this.alpha.get(x1))<=this.eps2)
    			return 0;
    		if(Math.abs(newAlpha2-this.alpha.get(x2))<=this.eps2)
    			return 0;
    		
    		double b1=-this.E.get(x1)-this.labelTrainData.get(x1)*kernelRBF(this.trainData.get(x1), this.trainData.get(x1))*(newAlpha1-this.alpha.get(x1))-this.labelTrainData.get(x2)*kernelRBF(this.trainData.get(x2), this.trainData.get(x1))*(newAlpha2-this.alpha.get(x2))+this.b;
    		double b2=-this.E.get(x2)-this.labelTrainData.get(x1)*kernelRBF(this.trainData.get(x1), this.trainData.get(x2))*(newAlpha1-this.alpha.get(x1))-this.labelTrainData.get(x2)*kernelRBF(this.trainData.get(x2), this.trainData.get(x2))*(newAlpha2-this.alpha.get(x2))+this.b;
    		
    		if(newAlpha1>0&&newAlpha1<this.C)
    			this.b=b1;
    		else if(newAlpha2>0&&newAlpha2<this.C)
    			this.b=b2;
    		else
    			this.b=(b1+b2)/2;
    		
    		this.alpha.set(x1,newAlpha1);
    		this.alpha.set(x2,newAlpha2);
    		updateE();
    		return 1;
    	}
    	public int selectAlpha2(int x1) {
    		
    		int x2=-1;
    		double maxDiff=-1;
    		//first select x2 from 0<a<c to max(E(x1)-E(x2))
    		
    		for(int i=0;i<this.N;++i)
    		{
    			if(Math.abs(this.alpha.get(i))<=this.eps||Math.abs(this.alpha.get(i)-this.C)<=this.eps) continue;
    			double diff=Math.abs(this.E.get(x1)-this.E.get(i));
    			if(diff>maxDiff)
    			{
    				maxDiff=diff;
    				x2=i;
    			}
    		}
    		
    		//second calculate eta (eta!=0)
    		if(x2!=-1)
    		{
    			this.eta=kernelRBF(this.trainData.get(x1), this.trainData.get(x1))+kernelRBF(this.trainData.get(x2), this.trainData.get(x2))-2*kernelRBF(this.trainData.get(x1), this.trainData.get(x2));
    			if(eta!=0) return x2;
    		}
    		
    		//third if cannot find in the whole train set
    		for(int i=0;i<this.N;i++)
    		{
    			if(i==x1) continue;
    			this.eta=kernelRBF(this.trainData.get(x1), this.trainData.get(x1))+kernelRBF(this.trainData.get(i), this.trainData.get(i))-2*kernelRBF(this.trainData.get(x1), this.trainData.get(i));
    			if(Math.abs(this.eta)>this.eps) return i;
    		}
    		return -1;
    		
    		
    	}
    	
    	public void SMO() {
    		//to solve alpha
    		int numChanged=0;
    		int cnt=0;
    		while(true)
    		{
    			cnt++;
    			System.out.println(cnt);
    			
    			numChanged=0;
    			for(int x1=0;x1<this.N;++x1)
    			{
    				if(Math.abs(this.alpha.get(x1))<=this.eps||Math.abs(this.alpha.get(x1)-this.C)<=this.eps) continue;
    				if(!satisfyKkt(x1))
    				{
    					int x2=selectAlpha2(x1);
    					if(x2==-1) continue;
    					numChanged+=update(x1, x2);
    				}
    			}
    			if(numChanged==0)
    			{
    				for(int x1=0;x1<this.N;++x1)
    				{
    					if(!satisfyKkt(x1))
    					{
    						int x2=selectAlpha2(x1);
    						if(x2==-1) continue;
    						update(x1, x2);
    						numChanged++;
    					}
    				}
    			}
    			if(numChanged==0)
    				break;				
    		}
    	}
    	
    	public SVM() {
    		//load train data
    		
    		Data data=new Data();
    		Map<List<Double>, Integer> Datas=data.getTrainData();
    		int totalData=Datas.size();
    		this.trainData=new ArrayList<ArrayList<Double>>();
    		this.labelTrainData=new ArrayList<Integer>();
    		this.alpha=new ArrayList<Double>();
    		this.E=new ArrayList<Double>();
    		
    		int i=0;
    		for(Map.Entry<List<Double>, Integer> entry: Datas.entrySet())
    		{
    			this.trainData.add((ArrayList<Double>) entry.getKey());
    			this.labelTrainData.add(entry.getValue());
    			this.alpha.add(0.0);
    			this.E.add(0.0-this.labelTrainData.get(i));
    			i++;
    		}
    		this.N=this.labelTrainData.size();
    		this.dim=this.trainData.get(0).size();
    		
    		this.sigma=12;//sigma=1
    		this.C=0.5;//c=6
    		this.b=0.0;
    		this.tol=0.001;
    		this.eta=0;
    		this.eps=0.0000001;
    		this.eps2=0.00001;
    	}
    	
    	public double getB() {
    		//get b value
    		return this.b;
    	}
    	public double[] getLinearW() {
    		double []w=new double[this.N];
    		for(int i=0;i<this.N;i++)
    		{
    			for(int j=0;j<this.dim;j++)
    			{
    				w[j]+=this.alpha.get(i)*this.labelTrainData.get(i)*this.trainData.get(i).get(j);
    			}
    		}
    		return w;
    	}
    	
    	public int predict(List<Double> x)
    	{
    		int ans=1;
    		double sum=0;
    		for(int i=0;i<this.N;i++)
    		{
    			sum+=this.alpha.get(i)*this.labelTrainData.get(i)*kernelRBF(x, this.trainData.get(i));
    		}
    		sum+=b;
    		if(sum>0)
    			ans=1;
    		else
    			ans=-1;
    		
    		return ans;
    	}
    	public static void main(String[] args) throws FileNotFoundException {
    		
    		SVM s=new SVM();
    		s.SMO();
    		PrintWriter out=new PrintWriter("g://download//resultpoints.txt");
    		for(int i=0;i<s.N;i++)
    		{
    			out.write((s.trainData.get(i).get(0)).toString());
    			out.write("	");
    			out.write((s.trainData.get(i).get(1)).toString());
    			out.write("	");
    			out.write(Integer.toString(s.predict(s.trainData.get(i))));
    			out.write("
    ");
    		}
    		out.close();
    		//if is linear kernel ,we can get w,just like wx+b=0,then we can directly get line fuction
    		double w[]=s.getLinearW();
    		System.out.println(w[0]+" "+w[1]+" "+s.b+"======");
    	}
    
    }
    

      

    用线性核函数实现的SVM得到的分类结果：

     画图使用的是Python代码：

    import numpy as np
    
    # Tab-separated "x1<TAB>x2<TAB>label" file to plot. SVM.main writes this format
    # to g://download//resultpoints.txt; adjust the path to the file you want to plot.
    RESULT_FILE = "g://download/myresult.txt"
    
    # Coefficients of the separating line W1*x1 + W2*x2 + B = 0, taken from the
    # "w[0] w[1] b======" line printed by SVM.main after a linear-kernel run.
    W1 = 0.8148005405344305
    W2 = -0.27263471796762484
    B = -3.8392586254518437
    
    
    def split_by_label(lines):
        """Split result lines into coordinates of label-1 points and all other points.
    
        Args:
            lines: iterable of strings "x1<TAB>x2<TAB>label"; blank lines are skipped.
    
        Returns:
            (pos_x, pos_y, neg_x, neg_y): float lists for label == 1 and for every
            other label, respectively.
        """
        pos_x, pos_y, neg_x, neg_y = [], [], [], []
        for line in lines:
            line = line.strip()
            if not line:
                continue  # tolerate a trailing empty line
            fields = line.split('\t')
            x1, x2, label = float(fields[0]), float(fields[1]), int(fields[2])
            if label == 1:
                pos_x.append(x1)
                pos_y.append(x2)
            else:
                neg_x.append(x1)
                neg_y.append(x2)
        return pos_x, pos_y, neg_x, neg_y
    
    
    def main():
        """Scatter the classified points and draw the learned separating line."""
        import matplotlib.pyplot as plt  # local import: only needed when plotting
    
        with open(RESULT_FILE) as f:
            type1_x, type1_y, type2_x, type2_y = split_by_label(f)
    
        plt.figure(figsize=(8, 5), dpi=80)
        axes = plt.subplot(111)
        type1 = axes.scatter(type1_x, type1_y, s=40, c='red')
        type2 = axes.scatter(type2_x, type2_y, s=40, c='green')
    
        # Solve W1*x + W2*y + B = 0 for y.
        x = np.linspace(-4, 10, 200)
        y = (-W1 / W2) * x + (-B / W2)
        axes.plot(x, y, 'b', lw=3)
    
        plt.xlabel('x1')
        plt.ylabel('x2')
        # type1 holds label 1; type2 holds every other label (-1 in SVM.predict output).
        axes.legend((type1, type2), ('1', '-1'), loc=1)
        plt.show()
    
    
    if __name__ == "__main__":
        main()
    

      用高斯核，当C=6、sigma=1时的分类结果：

    用高斯核，当C=0.5、sigma=1时的分类结果：

    用高斯核，当C=0.5、sigma=12时的分类结果：

    可见，惩罚参数C和高斯核参数sigma的取值对分类结果的影响都很大。

     其中sigma是高斯核函数 K(x,y)=exp(-||x-y||²/(2σ²)) 中的宽度参数。

  • 相关阅读:
    Centos7 tomcat 启动权限
    Tomcat下post请求大小设置
    postgres安装时locale的选择
    flink 1.11.2 学习笔记(1)-wordCount
    prometheus学习笔记(3)-使用exporter监控mysql
    prometheus学习笔记(2)-利用java client写入数据
    mock测试及jacoco覆盖率
    shading-jdbc 4.1.1 + tk.mybatis + pagehelper 1.3.x +spring boot 2.x 使用注意事项
    prometheus学习笔记(1)-mac单机版环境搭建
    redis数据类型HyperLogLog的使用
  • 原文地址:https://www.cnblogs.com/wuxiangli/p/6275112.html
Copyright © 2011-2022 走看看