zoukankan      html  css  js  c++  java
  • Fisher线性判别

    Fisher线性判别(Fisher Linear Discrimination,FLD),也称线性判别式分析(Linear Discriminant Analysis, LDA)。FLD是基于样本类别进行整体特征提取的有效方法。它在使用PCA方法进行降维的基础上考虑到训练样本的类间信息。FLD的基本原理就是找到一个最合适的投影轴,使各类样本在该轴上投影之间的距离尽可能远,而每一类内的样本的投影尽可能紧凑,从而使分类效果达到最佳,即在最大化类间距离的同时最小化类内距离。FLD方法在进行图像整体特征提取方面有着广泛的应用。

    在应用统计方法解决模式识别问题时,经常会遇到所谓的“维数灾难”的问题,在低维空间里适用的方法在高维空间里可能完全不适用。因此压缩特征空间的维数有时是很重要的。Fisher方法实际上涉及维数压缩的问题。如果把多维特征空间的点投影到一条直线上,就能把特征空间压缩成一维,这个在数学上是很容易办到的。但是,在高维空间里很容易分开的样品,把它们投影到任意一根直线上,有可能不同类别的样品就混在一起,无法区分,如图1(a)所示投影到xl或x2轴无法区分。若把直线绕原点转动一下,就有可能找到一个方向,样品投影到这个方向的直线上,各类样品就能很好地分开,如图1(b)所示。因此直线方向的选择很重要。一般地,总能够找到一个最好的方向,使样品投影到这个方向的直线上很容易分开。如何找到这个最好的直线方向以及如何实现向最好方向投影的变换,这正是Fisher算法要解决的基本问题,这个投影变换恰是我们所寻求的解向量w*。



                                                                     图1 Fisher线性判别示意图

    样品训练集以及待测样品的特征总数目为n。为了找到最佳投影方向,需要计算出各类样品均值,样品类内离散度矩阵Si和总类间离散度矩阵Sw,样品类间离散度矩阵Sb,根据Fisher准则,找到最佳投影向量,将训练集内所有样品进行投影,投影到一维Y空间,由于Y空间是一维的,则需要求出Y空间的划分边界点,找到边界点后,就可以对待测样品进行一维Y空间的投影,判断它的投影点与分界点的关系,将其归类。具体方法如下。



    /******************************************************************
    *   函数名称:Fisher_2Classes(int Class0, int Class1)
    *   函数类型:int 
    *   参数说明:Class0,Class1:0~9中的任意两个类别
    *   函数功能:两类Fisher分类器,返回Class0,Class1中的一个
    ******************************************************************/
    int Classification::Fisher_2Classes(int Class0, int Class1)
    {
    	double Xmeans[2][25];//两类的均值
    	double S[2][25][25];//样品类内离散度矩阵
    	double Sw[25][25];//总类间离散度矩阵
    	double Sw_[25][25];//Sw的逆矩阵
    	double W[25];//解向量w*
    	double difXmeans[25];//均值差
    	double X[25];//未知样品
    	double m0,m1;//类样品均值
    	double y0;//阈值y0
    	int i,j,k;
    
    	for(i=0;i<2;i++)
    		for(j=0;j<25;j++)
    			Xmeans[i][j]=0;
    	int num0,num1;		//两类样品的个数
    	//两类样品特征
    	double mode0[200][25],mode1[200][25];
    	//两类样品个数
    	num0=40;//pattern[Class0].number;
    	num1=40;//pattern[Class1].number;
    	for(i=0;i<num0;i++)
    	{
    		for(j=0;j<25;j++)
    		{
    			Xmeans[0][j]+=pattern[Class0].feature[i][j];
    			mode0[i][j]=pattern[Class0].feature[i][j];
    		}
    	}
    
    	for(i=0;i<num1;i++)
    	{
    		for(j=0;j<25;j++)
    		{
    			Xmeans[1][j]+=pattern[Class1].feature[i][j];	
    			mode1[i][j]=pattern[Class1].feature[i][j];
    		}
    	}
    	//求得两个样品均值向量
    	for(i=0;i<25;i++)	
    	{
    		Xmeans[0][i]/=(double)num0;
    		Xmeans[1][i]/=(double)num1;
    	}
    	//求两类样品类内离散度矩阵
    	for(i=0;i<25;i++)
    	for(j=0;j<25;j++)
    	{
    		double s0=0.0,s1=0.0;
    		for(k=0;k<num0;k++)
    			s0=s0+(mode0[k][i]-Xmeans[0][i])*(mode0[k][j]-Xmeans[0][j]);
    		s0=s0/(double)(num0-1);
    		S[0][i][j]=s0;//第一类
    		for(k=0;k<num1;k++)
    			s1=s1+(mode1[k][i]-Xmeans[1][i])*(mode1[k][j]-Xmeans[1][j]);
    		s1=s1/(double)(num1-1);
    		S[1][i][j]=s1;//第二类		
    	}
    	//总类间离散度矩阵
    	for(i=0;i<25;i++)
    	for(j=0;j<25;j++)
    	{
    		Sw[i][j]=S[0][i][j]+S[1][i][j];
    	}
    	//Sw的逆矩阵
    	for(i=0;i<25;i++)
    		for(j=0;j<25;j++)
    			Sw_[i][j]=Sw[i][j];	
    	double(*p)[25]=Sw_;	
    	brinv(*p,25);		//Sw的逆矩阵Sw_
    	//计算w*  w*=Sw_×(Xmeans0-Xmeans1)
    	for(i=0;i<25;i++)
    		difXmeans[i]=Xmeans[0][i]-Xmeans[1][i];
    	for(i=0;i<25;i++)
    		W[i]=0.0;
    	brmul(Sw_,difXmeans,25,W);//计算出W*
    	
    	//各类样品均值
    	m0=0.0;
    	m1=0.0;
    	for(i=0;i<num0;i++)
    	{
    		m0+=brmul(W,mode0[i],25);
    	}
    	for(i=0;i<num1;i++)
    	{
    		m1+=brmul(W,mode1[i],25);
    	}
    	m0/=(double)num0;
    	m1/=(double)num1;
    	y0=(num0*m0+num1*m1)/(num0+num1);//阈值y0
    	
    	//对于任意的手写数字X
    	for(i=0;i<25;i++)
    		X[i]=testsample[i];
    	double y;//X在w*上的投影点
    	y=brmul(W,X,25);
    	if (y>=y0) 
    		return Class0;
    	else
    		return Class1;
    }
    
    /******************************************************************
    *   函数名称:Fisher()
    *   函数类型:int 
    *   函数功能:Fisher分类器,返回手写数字的类别
    ******************************************************************/
    int Classification::Fisher()
    {
    	int i,j,number,maxval,num[10];
    	for(i=0;i<10;i++)
    		num[i]=0;
    	for(i=0;i<10;i++)
    		for(j=0;j<i;j++)
    			num[Fisher_2Classes(i,j)]++;
    	maxval=num[0];
    	number=0;
    	for(i=1;i<10;i++)
    	{
    		if(num[i]>maxval)
    		{
    			maxval=num[i];
    			number=i;
    		}
    	}
    	return number;
    }

    /******************************************************************
    *函数名称:brmul(double a[],double b[][25],int n,double c[])
    *函数类型:void
    *参数说明:a-双精度实型数组,存放A的元素。
    *          b-双精度实型数组,存放B的元素。
    *          n-整型变量,矩阵A的列数,也是矩阵B的行数。
    *          c-双精度实型数组,存放乘积矩阵C=AB的元素。
    *函数功能:求矩阵A与B的乘积矩阵C=AB。
    ******************************************************************/
    void brmul(double a[],double b[][25],int n,double c[])//矩阵乘法,c=a*b
    { 
    	for(int i=0;i<n;i++)
    	{
    		for(int j=0;j<n;j++)
    			c[i]+=a[j]*b[j][i];
    	}
    	return;
    }
    

    /******************************************************************
    *函数名称:brinv(double a[],int n)
    *函数类型:void
    *参数说明:a--双精度实型数组,n--整型变量,方阵A的阶数
    *函数功能:用全选主元Gauss-Jordan消去法求n阶实矩阵A的逆矩阵
    ******************************************************************/
    void brinv(double a[],int n)
    { 
    	int *is,*js,i,j,k,l,u,v;
        double d,p;
        is=new int[n];
        js=new int[n];
        for (k=0; k<=n-1; k++)
    	{ 
    		d=0.0;
            for (i=k; i<=n-1; i++)
    			for (j=k; j<=n-1; j++)
    			{ 
    				l=i*n+j; p=fabs(a[l]);
    				if (p>d) 
    				{ 
    					d=p; is[k]=i; js[k]=j;
    				}
    			}
    			if (d+1.0==1.0)
    			{ 
    				free(is); free(js); printf("err**not inv
    ");
    				return;
    			}
    			if (is[k]!=k)
    				for (j=0; j<=n-1; j++)
    				{ 
    					u=k*n+j; v=is[k]*n+j;
    					p=a[u]; a[u]=a[v]; a[v]=p;
    				}
    				if (js[k]!=k)
    					for (i=0; i<=n-1; i++)
    					{ 
    						u=i*n+k; v=i*n+js[k];
    						p=a[u]; a[u]=a[v]; a[v]=p;
    					}
    					l=k*n+k;
    					a[l]=1.0/a[l];
    					for (j=0; j<=n-1; j++)
    						if (j!=k)
    						{
    							u=k*n+j; a[u]=a[u]*a[l];
    						}
    						for (i=0; i<=n-1; i++)
    							if (i!=k)
    								for (j=0; j<=n-1; j++)
    									if (j!=k)
    									{ 
    										u=i*n+j;
    										a[u]=a[u]-a[i*n+k]*a[k*n+j];
    									}
    									for (i=0; i<=n-1; i++)
    										if (i!=k)
    										{
    											u=i*n+k; a[u]=-a[u]*a[l];
    										}
    	}
        for (k=n-1; k>=0; k--)
    	{ 
    		if (js[k]!=k)
    			for (j=0; j<=n-1; j++)
                { 
    				u=k*n+j; v=js[k]*n+j;
    				p=a[u]; a[u]=a[v]; a[v]=p;
                }
    			if (is[k]!=k)
    				for (i=0; i<=n-1; i++)
    				{ 
    					u=i*n+k; v=i*n+is[k];
    					p=a[u]; a[u]=a[v]; a[v]=p;
    				}
    	}
        delete is; 
    	delete js;
    }


    版权声明:

  • 相关阅读:
    linux freopen函数
    进程的环境变量environ
    ls -l 和du 的区别
    Python时间,日期,时间戳之间转换
    Web 服务器压力测试实例详解
    装numpy 环境:python3.4+ windows7 +64位系统
    在Windows Python3.4 上安装NumPy、Matplotlib、SciPy和IPython
    apache 自带的ab.exe 测试网站的并发量(网站压力测试)
    成员如何关注微信企业号?
    微信企业号通讯录有什么用?
  • 原文地址:https://www.cnblogs.com/walccott/p/4957048.html
Copyright © 2011-2022 走看看