zoukankan      html  css  js  c++  java
  • 奇异值分解(SVD和TruncatedSVD)

    1.两者概念理解

    2.SVD的使用

    np.linalg.svd(a, full_matrices=True, compute_uv=True)

    参数:

    • a : 是一个形如(M,N)矩阵

    • full_matrices:的取值是为0或者1,默认值为1,这时u的大小为(M,M),v的大小为(N,N) 。否则u的大小为(M,K),v的大小为(K,N) ,K=min(M,N)。

    • compute_uv:取值是为0或者1,默认值为1,表示计算u,s,v。为0的时候只计算s。

    返回值:

    • 总共有三个返回值u,s,v,其中s是对矩阵a的奇异值分解。s除了对角元素不为0,其他元素都为0,并且对角元素从大到小排列。s中有n个奇异值,一般排在后面的比较接近0,所以仅保留比较大的r个奇异值。

    Corpus(window_size=1):

    I like learning.

    I like NLP.

    I enjoy flying.

     1 import numpy as np
     2 
     3 import matplotlib.pyplot as plt
     4 plt.style.use('ggplot')
     5 
     6 
     7 words =['I', 'like', 'enjoy',
     8         'deep', 'learning', 'NLP', 'flying','.']
     9 
    10 X = np.array([[0,2,1,0,0,0,0,0],        #X是共现矩阵
    11               [2,0,0,1,0,1,0,0],
    12               [1,0,0,0,0,0,1,0],
    13               [0,1,0,0,1,0,0,0],
    14               [0,0,0,1,0,0,0,1],
    15               [0,1,0,0,0,0,0,1],
    16               [0,0,1,0,0,0,0,1],
    17               [0,0,0,0,1,1,1,0]])
    18 U, s, Vh = np.linalg.svd(X, full_matrices=False)
    19 print(U.shape)                              #(8, 8)
    20 print(s.shape)                              #(8,)
    21 print(Vh.shape)                             #(8, 8)
    22 print(np.allclose(X, np.dot(U * s, Vh)))    #True,allclose比较两个array是不是每一元素都相等,默认在1e-05的误差范围内
    23 
    24 plt.xlim([-0.8, 0.2])
    25 plt.ylim([-0.8, 0.8])
    26 for i in range(len(words)):
    27     plt.text(U[i,0], U[i,1], words[i])

    3.TruncatedSVD的使用

    TruncatedSVD 的创建必须指定所需的特征数或所要选择的成分数,比如 2。一旦创建完成,你就可以通过调用 fit() 函数来拟合该变换,然后再通过调用 transform() 函数将其应用于原始矩阵。

    1 from sklearn.decomposition import TruncatedSVD
    2 svd = TruncatedSVD(n_components=2)
    3 X_reduced = svd.fit_transform(X)               #X是上面的共现矩阵,两步合一步
    4 print(X_reduced)

    [[ 1.44515015 -1.53425886]
    [ 1.63902195 1.68761941]
    [ 0.70661477 0.73388691]
    [ 0.78757738 -0.66397017]
    [ 0.53253583 0.09065737]
    [ 0.8413365 -0.78737543]
    [ 0.50317243 -0.4312723 ]
    [ 0.68076383 0.42116725]]

    手动计算得到的值与上面调用函数的结果一致,但某些值的符号不一样。由于所涉及的计算的性质以及所用的基础库和方法的差异,可以预见在符号方面会存在一些不稳定性:

    1 Vh = Vh[:2, :]
    2 print(X.dot(Vh.T))

    [[-1.44515015 -1.53425886]
    [-1.63902195 1.68761941]
    [-0.70661477 0.73388691]
    [-0.78757738 -0.66397017]
    [-0.53253583 0.09065737]
    [-0.8413365 -0.78737543]
    [-0.50317243 -0.4312723 ]
    [-0.68076383 0.42116725]]

    参考:https://zhuanlan.zhihu.com/p/134512367

  • 相关阅读:
    企业微信的部门长度问题
    MVC中view与controller传json数据
    jQuery.extend()、jQuery.fn.extend()扩展方法示例详解
    程序员成长思维:把自己当做产品来发展
    发展你的兴趣,而不是跟随你的兴趣
    领导力:不要做个“好人”
    Nginx性能优化
    【.NET与树莓派】上手前的一些准备工作
    php curl时遇到Can't load the certificate "..." and its private key: OSStatus -25299的问题
    ASCII码字符对照表
  • 原文地址:https://www.cnblogs.com/cxq1126/p/13407279.html
Copyright © 2011-2022 走看看