zoukankan      html  css  js  c++  java
  • LDA(线性判别分析)【python实现】

    原理

    求解最佳投影方向,使得同类投影点尽可能的进,异类投影点尽可能的远
    同类投影点距离用同类样本协方差矩阵表示

    [omega^T Sigma_i omega quad {第i类样本协方差} ]

    异类投影点距离

    [||omega^Tmu_0 - omega^Tmu_1||_2^2 ]

    (mu_i \, {第i类样本均值})
    优化函数

    [J(omega) = frac{||omega^Tmu_0 - omega^T mu_1||_2^2}{omega^T (Sigma_0 - Sigma_1) omega} ]

    求上述函数极大值,解出(omega)

    定义类间、类内散度矩阵

    1.类间散度矩阵

    [S_w = Sigma_0 + Sigma_1 = sum_{x in X0}(x - omega)(x - omega)^T + sum_{x in X1}(x - omega)(x - omega)^T ]

    2.类内散度矩阵

    [S_b = (mu_0 - mu_1)(mu_0 - mu_1)^T ]

    [J(omega) = frac{omega^T S_b omega}{omega^T S_w omega} ]

    利用拉格朗日乘数法,可得

    [omega = S_w^{-1}(mu_0 - mu_1) ]

    python程序

    import numpy as np 
    import matplotlib.pyplot as plt 
    M = 2 #属性个数
    N = 50#二分类。每类样本N个
    #随机生成两个属性的N个第一类样本
    feature11 = np.random.randint(0, 7, size = N)
    feature12 = np.random.randint(0, 7, size= N)
    temp_X1 = np.row_stack((feature11, feature12))
    X1 = np.mat(temp_X1)
    #随机生成两个属性的N个第二类样本
    feature21 = np.random.randint(5,11, size= N)
    feature22 = np.random.randint(7, 14, size= N)
    temp_X2 = np.row_stack((feature21, feature22))
    X2 = np.mat(temp_X2)
    #求投影向量omega
    mu1 = np.mat(np.zeros((2,1)))
    mu2 = np.mat(np.zeros((2,1)))
    X_1t = np.array(X1)
    X_2t = np.array(X2)
    for i in range(M):
        mu1[i, 0] = sum([j for j in X_1t[i,:]])/N
    for i in range(M):
        mu2[i, 0] = sum([j for j in X_2t[i,:]])/N
    #print(mu1, mu2)
    s_w1 = np.mat(np.zeros(M))
    s_w2 = np.mat(np.zeros(M))
    for i in range(N):
        s_w1 = s_w1 + (X1[:, i] - mu1)*(X1[:, i] - mu1).T 
    for i in range(N):
        s_w2 = s_w2 + (X2[:, i] - mu2)*(X2[:, i] - mu2).T 
    s_w = s_w1 + s_w2
    Omega = np.linalg.pinv(s_w)*(mu1 - mu2)
    #print(Omega)
    #画出散点图、投影面
    fig = plt.figure(1)
    plt.scatter(feature11, feature12, marker='+')
    plt.scatter(feature21, feature22, marker='*')
    xx_1 = np.linspace(0,10,num=50)
    yy_1 = Omega[1,0]/Omega[0,0]*xx_1
    plt.plot(xx_1,yy_1,color='r')
    plt.show()
    

    效果

    参考资料

    《机器学习》    周志华老师
    坚持
  • 相关阅读:
    OC
    OC
    核心动画
    核心动画
    核心动画
    数据存储1
    plsql语句基础
    Oracle3连接&子查询&联合查询&分析函数
    oracle2约束添加&表复制&拼接
    Oracle表空间创建及表创建
  • 原文地址:https://www.cnblogs.com/liudianfengmang/p/12822990.html
Copyright © 2011-2022 走看看