zoukankan      html  css  js  c++  java
  • Fisher算法+两类问题



    一、Fisher算法

    在这里插入图片描述

    二、蠓的分类问题:

    两种蠓AfApf已由生物学家根据它们的触角翼长加以区分(Af是能传播花粉的益虫,Apf是会传播疾病的害虫),两个矩阵中分别给出了6只Apf 和9只Af蠓的触角长(对应于矩阵的第1列)和翼长(对应于矩阵的第2列)的数据(See next slide)。根据触角长和翼长这两个特征来识别一个样本是Af还是Apf是重要的。

    (1)试给出该问题的Fisher分类器;

    (2)有三个待识别的模式样本,它们分别是(1.24,1.80),(1.28,1.84),( 1.40,2.04),试问这三个样本属于哪一种蠓。

    数据集:

    APF = np.array([
        [1.14,1.78],[1.18,1.96],
        [1.20,1.86],[1.26,2.00],
        [1.30,2.00],[1.28,1.96]
    ])
    
    AF = np.array([
        [1.24,1.72],[1.36,1.74],
        [1.38,1.64],[1.38,1.82],
        [1.38,1.90],[1.40,1.70],
        [1.48,1.82],[1.54,2.08],
        [1.56,1.78]
    ])
    

    三、代码实现:

    Fisher算法关键在于求出权向量W_和阈值W*,然后求出待测数据的映射y_test,最后与W*阈值作比较。

    import numpy as np
    
    
    APF = np.array([
        [1.14,1.78],[1.18,1.96],
        [1.20,1.86],[1.26,2.00],
        [1.30,2.00],[1.28,1.96]
    ])
    
    AF = np.array([
        [1.24,1.72],[1.36,1.74],
        [1.38,1.64],[1.38,1.82],
        [1.38,1.90],[1.40,1.70],
        [1.48,1.82],[1.54,2.08],
        [1.56,1.78]
    ])
    
    #获取样本均值
    def getAvg(x):
        return np.mean(x, axis=0)
    
    #求两类样本类内离散度矩阵Si
    def getSi(x, x_mean):
        x_mean = x_mean.reshape(x.shape[1],1)
        Si = np.zeros((x.shape[1],x.shape[1]))
        for xi in x:
            temp_xi = xi.copy().reshape(x.shape[1],1)
            temp = (temp_xi-x_mean)
            Si = Si + np.dot(temp, temp.T)
        return Si
    
    # 求权向量W_
    def getW(x1_mean,x2_mean,Sw):
        return np.dot(np.linalg.inv(Sw),(x1_mean-x2_mean))
    
    # 获取分类阈值w0和权向量W_
    def get_w0(x1, x2):
        x1_mean = getAvg(x1)
        x2_mean = getAvg(x2)
        S1 = getSi(APF, x1_mean)
        S2 = getSi(AF, x2_mean)
        Sw = S1+S2
        W_ = getW(x1_mean,x2_mean,Sw)
    
        #获取投影点
        y1 = np.dot(x1, W_)
        y2 = np.dot(x2, W_)
    
        #求各类样本均值yi_mean
        y1_mean = np.mean(y1)
        y2_mean = np.mean(y2)
    
        #选取分类阈值w0
        w0 = (y1_mean + y2_mean) / 2
    
        return w0, W_
    
    
    def Fisher(x1, x1_label, x2, x2_label, x_test):
        w0, W_ = get_w0(x1,x2)
        y_test = np.dot(x_test, W_)
        if y_test > w0:
            print('测试样本属于', x1_label)
        elif y_test <w0:
            print('测试样本属于',x2_label)
        else:
            print('测试样本可能属于%s,也可嫩属于%s'%x1_label%x2_label)
    
    
    x_tests = np.array([
        [1.24,1.80],[1.28,1.84],[1.40,2.04]
    ])
    
    i = 1
    for x_test in x_tests:
        print('第%d个'%i,end='')
        i += 1
        Fisher(APF,'蠓APF',AF,'蠓AF',x_test)
    

    预测结果如下:

    第1个测试样本属于 蠓APF
    第2个测试样本属于 蠓APF
    第3个测试样本属于 蠓APF
    
  • 相关阅读:
    ajax面试题
    xgqfrms™, xgqfrms® : xgqfrms's offical website of GitHub!
    xgqfrms™, xgqfrms® : xgqfrms's offical website of GitHub!
    xgqfrms™, xgqfrms® : xgqfrms's offical website of GitHub!
    xgqfrms™, xgqfrms® : xgqfrms's offical website of GitHub!
    xgqfrms™, xgqfrms® : xgqfrms's offical website of GitHub!
    xgqfrms™, xgqfrms® : xgqfrms's offical website of GitHub!
    xgqfrms™, xgqfrms® : xgqfrms's offical website of GitHub!
    xgqfrms™, xgqfrms® : xgqfrms's offical website of GitHub!
    xgqfrms™, xgqfrms® : xgqfrms's offical website of GitHub!
  • 原文地址:https://www.cnblogs.com/theory/p/11884314.html
Copyright © 2011-2022 走看看