  • Mixtures of Gaussians and the EM algorithm

    Acknowledgement to Stanford CS229.

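    Generative modeling is itself a kind of unsupervised learning task [1]. Given unlabelled data $x_1, \dots, x_n$, a mixture of Gaussians assumes that each sample $x_i$ is generated by first drawing a hidden component label $z_i \in \{1, \dots, m\}$ with prior probability $\phi_j = p(z_i = j)$, and then drawing $x_i$ from the Gaussian $\mathcal{N}(\mu_j, \sigma_j^2)$ of that component.

    To estimate the parameters, we can write the log-likelihood as

    $$\ell(\phi, \mu, \sigma) = \sum_{i=1}^{n} \log p(x_i; \phi, \mu, \sigma)$$

    which is also

    $$\ell(\phi, \mu, \sigma) = \sum_{i=1}^{n} \log \sum_{j=1}^{m} p(x_i \mid z_i = j; \mu_j, \sigma_j) \, p(z_i = j; \phi)$$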

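    Because the component labels are unobserved, this objective has no closed-form maximizer. The EM algorithm solves this density estimation problem iteratively, alternating between computing each component's posterior responsibility for each sample (E-step) and re-estimating the parameters from those responsibilities (M-step).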
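
    For reference, a standard form of the two steps for a one-dimensional mixture with m components and n samples can be written as follows. The symbols $w_{ji}$, $\phi_j$, $\mu_j$ and $\sigma_j$ are a notational assumption chosen here to line up with the arrays w, priors, mus and sigmas in the code; they are not taken from the original post.

```latex
% E-step: responsibility of component j for sample i
w_{ji} = \frac{\phi_j \, \mathcal{N}(x_i; \mu_j, \sigma_j^2)}
              {\sum_{l=1}^{m} \phi_l \, \mathcal{N}(x_i; \mu_l, \sigma_l^2)}

% M-step: weighted maximum-likelihood updates
\phi_j = \frac{1}{n} \sum_{i=1}^{n} w_{ji}, \qquad
\mu_j = \frac{\sum_{i=1}^{n} w_{ji} \, x_i}{\sum_{i=1}^{n} w_{ji}}, \qquad
\sigma_j^2 = \frac{\sum_{i=1}^{n} w_{ji} \, (x_i - \mu_j)^2}{\sum_{i=1}^{n} w_{ji}}
```

    Note that the variance update uses the newly computed $\mu_j$, and the prior update averages the soft responsibilities rather than counting hard assignments.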

    An example is provided here. The data points are drawn from two Gaussian distributions, N(0, 1) and N(2, 1), with 50 points from each.

    import numpy as np

    np.random.seed(0)
    # 50 samples from N(0, 1) and 50 from N(2, 1); the labels are discarded.
    x0 = np.concatenate((np.random.normal(0, 1, 50), np.random.normal(2, 1, 50)), axis=0)

    # Initial guesses for the component means and standard deviations.
    mus0 = np.array([0, 1])
    sigmas0 = np.array([2, 2])


    def gauss(x, mu, sigma):
        """
        Density of a 1-D Gaussian.

        :param x: point(s) at which to evaluate the density
        :param mu: mean
        :param sigma: standard deviation
        :return: pdf(x)
        """
        numerator = np.exp(-0.5 * ((x - mu) / sigma) ** 2)
        denominator = np.sqrt(2 * np.pi * sigma ** 2)
        return numerator / denominator


    def e_step(mus=mus0, sigmas=sigmas0, x=x0, priors=np.ones(len(mus0)) / len(mus0)):
        """
        :param mus: gaussian centers, an array of shape (m,)
        :param sigmas: gaussian standard deviations, an array of shape (m,)
        :param x: n samples with no labels
        :param priors: mixing proportions, an array of shape (m,)
        :return: m by n array of responsibilities, where m is # classes
        """
        assert len(mus) == len(sigmas), "mus and sigmas don't have the same length"
        m, n = len(mus), len(x)
        w = np.zeros(shape=(m, n))
        for j in range(m):
            # unnormalized posterior of class j for every sample
            w[j, :] = gauss(x=x, mu=mus[j], sigma=sigmas[j]) * priors[j]
        w /= np.sum(w, axis=0)  # normalize over classes (j is the row index)
        return w


    def m_step(w, current_mus, x=x0):
        """
        :param w: m by n array of responsibilities, where m is # classes
        :param current_mus: means used in the preceding E-step, shape (m,)
        :param x: n samples with no labels
        :return: mus: gaussian centers, an array of shape (m,)
                 sigmas: gaussian standard deviations, an array of shape (m,)
                 priors: mixing proportions, an array of shape (m,)
        """
        m, n = w.shape
        w_sum = np.sum(w, axis=1)
        mus = np.zeros(shape=m)
        sigmas = np.zeros(shape=m)
        for j in range(m):
            mus[j] = np.dot(w[j, :], x) / w_sum[j]
            # NOTE: deviations are taken from the previous mean; textbook EM
            # would use the freshly updated mus[j] here.
            sigmas[j] = np.sqrt(np.dot(w[j, :], (x - current_mus[j]) ** 2) / w_sum[j])

        # NOTE: priors come from a hard assignment of each sample to its most
        # likely class; textbook EM would use w_sum / n instead.
        densities = np.array([gauss(x, mus[j], sigmas[j]) for j in range(m)])
        priors = np.bincount(np.argmax(densities, axis=0), minlength=m) / n
        return mus, sigmas, priors


    def solve(x=x0, priors=np.ones(len(mus0)) / len(mus0), n_iter=100):
        mus, sigmas = mus0, sigmas0
        for k in range(n_iter):
            w = e_step(mus=mus, sigmas=sigmas, x=x, priors=priors)
            mus, sigmas, priors = m_step(w, current_mus=mus, x=x)
            print("k={},mus={},sigmas={},priors={}".format(k, mus, sigmas, priors))
        return mus, sigmas, priors


    if __name__ == '__main__':
        solve()

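    After 100 iterations, the estimates settle at means of about -0.016 and 1.874, standard deviations of about 1.091 and 0.904, and priors of 0.43 and 0.57. This is a reasonable approximation of the real model (means 0 and 2, unit standard deviations, equal priors).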

    /usr/local/bin/python3.5 /home/csdl/review/fulcrum/gmm/gmm.py
    k=0,mus=[ 0.81734343  1.27122747],sigmas=[ 1.60216343  1.33905931],priors=[ 0.32  0.68]
    k=1,mus=[ 0.73393663  1.21263431],sigmas=[ 1.48989073  1.27140643],priors=[ 0.35  0.65]
    k=2,mus=[ 0.72025041  1.24207148],sigmas=[ 1.47760392  1.25840835],priors=[ 0.36  0.64]
    k=3,mus=[ 0.69405554  1.2656654 ],sigmas=[ 1.47453155  1.2480128 ],priors=[ 0.36  0.64]
    k=4,mus=[ 0.65993336  1.28545741],sigmas=[ 1.47454151  1.238417  ],priors=[ 0.36  0.64]
    k=5,mus=[ 0.62512005  1.30527053],sigmas=[ 1.4739642   1.22830782],priors=[ 0.36  0.64]
    k=6,mus=[ 0.59009448  1.32522573],sigmas=[ 1.47230468  1.21788024],priors=[ 0.36  0.64]
    k=7,mus=[ 0.55504913  1.34523959],sigmas=[ 1.46932309  1.20732602],priors=[ 0.36  0.64]
    k=8,mus=[ 0.52016003  1.36521637],sigmas=[ 1.46489424  1.19678812],priors=[ 0.36  0.64]
    k=9,mus=[ 0.4855794   1.38507002],sigmas=[ 1.45897578  1.18636451],priors=[ 0.36  0.64]
    k=10,mus=[ 0.45142496  1.4047313 ],sigmas=[ 1.45158393  1.17611449],priors=[ 0.36  0.64]
    k=11,mus=[ 0.41777707  1.42414967],sigmas=[ 1.44277296  1.16606539],priors=[ 0.36  0.64]
    k=12,mus=[ 0.38468177  1.44329208],sigmas=[ 1.43261873  1.15621962],priors=[ 0.37  0.63]
    k=13,mus=[ 0.36595587  1.46867892],sigmas=[ 1.41990091  1.14409883],priors=[ 0.37  0.63]
    k=14,mus=[ 0.33654056  1.48870368],sigmas=[ 1.4082571   1.13343039],priors=[ 0.37  0.63]
    k=15,mus=[ 0.30597566  1.50763142],sigmas=[ 1.39543174  1.12335036],priors=[ 0.37  0.63]
    k=16,mus=[ 0.27568252  1.52609593],sigmas=[ 1.38137316  1.11360905],priors=[ 0.37  0.63]
    k=17,mus=[ 0.24588996  1.54419117],sigmas=[ 1.36625212  1.10407072],priors=[ 0.37  0.63]
    k=18,mus=[ 0.21664299  1.56192385],sigmas=[ 1.35022455  1.09464684],priors=[ 0.37  0.63]
    k=19,mus=[ 0.18796432  1.57928065],sigmas=[ 1.33342219  1.0852798 ],priors=[ 0.37  0.63]
    k=20,mus=[ 0.1598861   1.59623644],sigmas=[ 1.31596643  1.07593506],priors=[ 0.37  0.63]
    k=21,mus=[ 0.13245872  1.61275414],sigmas=[ 1.2979812   1.06659735],priors=[ 0.39  0.61]
    k=22,mus=[ 0.13549936  1.64420575],sigmas=[ 1.28163006  1.05238461],priors=[ 0.4  0.6]
    k=23,mus=[ 0.13362832  1.67212388],sigmas=[ 1.26763647  1.03761078],priors=[ 0.4  0.6]
    k=24,mus=[ 0.11750175  1.69125511],sigmas=[ 1.25330002  1.02525138],priors=[ 0.41  0.59]
    k=25,mus=[ 0.11224826  1.71494289],sigmas=[ 1.23950504  1.01230038],priors=[ 0.42  0.58]
    k=26,mus=[ 0.11153847  1.7395662 ],sigmas=[ 1.22728752  0.99889453],priors=[ 0.42  0.58]
    k=27,mus=[ 0.0999276   1.75644918],sigmas=[ 1.21474604  0.98770556],priors=[ 0.43  0.57]
    k=28,mus=[ 0.09911993  1.77770261],sigmas=[ 1.20375615  0.97601043],priors=[ 0.43  0.57]
    k=29,mus=[ 0.08991339  1.79234269],sigmas=[ 1.19274093  0.96620904],priors=[ 0.43  0.57]
    k=30,mus=[ 0.07854133  1.80401995],sigmas=[ 1.18163507  0.95803992],priors=[ 0.43  0.57]
    k=31,mus=[ 0.06708472  1.81391145],sigmas=[ 1.1708709  0.9510143],priors=[ 0.43  0.57]
    k=32,mus=[ 0.05629168  1.82248392],sigmas=[ 1.16077864  0.94483468],priors=[ 0.43  0.57]
    k=33,mus=[ 0.04644144  1.8299628 ],sigmas=[ 1.15153082  0.93934709],priors=[ 0.43  0.57]
    k=34,mus=[ 0.03761987  1.83648519],sigmas=[ 1.14319449  0.93447086],priors=[ 0.43  0.57]
    k=35,mus=[ 0.02982246  1.84215374],sigmas=[ 1.13577403  0.93015559],priors=[ 0.43  0.57]
    k=36,mus=[ 0.02299928  1.84705693],sigmas=[ 1.12923679  0.92636035],priors=[ 0.43  0.57]
    k=37,mus=[ 0.01707735  1.85127645],sigmas=[ 1.12352817  0.92304533],priors=[ 0.43  0.57]
    k=38,mus=[ 0.01197298  1.85488949],sigmas=[ 1.11858109  0.92016939],priors=[ 0.43  0.57]
    k=39,mus=[ 0.00759925  1.85796875],sigmas=[ 1.11432244  0.91769023],priors=[ 0.43  0.57]
    k=40,mus=[ 0.00387068  1.86058202],sigmas=[ 1.11067765  0.91556544],priors=[ 0.43  0.57]
    k=41,mus=[  7.06082311e-04   1.86279150e+00],sigmas=[ 1.10757393  0.91375369],priors=[ 0.43  0.57]
    k=42,mus=[-0.00196965  1.86465346],sigmas=[ 1.10494246  0.91221583],priors=[ 0.43  0.57]
    k=43,mus=[-0.00422464  1.86621814],sigmas=[ 1.10271971  0.91091554],priors=[ 0.43  0.57]
    k=44,mus=[-0.00611974  1.86752982],sigmas=[ 1.10084819  0.9098198 ],priors=[ 0.43  0.57]
    k=45,mus=[-0.00770859  1.86862714],sigmas=[ 1.09927671  0.90889909],priors=[ 0.43  0.57]
    k=46,mus=[-0.00903796  1.86954354],sigmas=[ 1.09796019  0.90812732],priors=[ 0.43  0.57]
    k=47,mus=[-0.01014832  1.87030773],sigmas=[ 1.09685943  0.90748172],priors=[ 0.43  0.57]
    k=48,mus=[-0.01107441  1.87094421],sigmas=[ 1.09594057  0.90694261],priors=[ 0.43  0.57]
    k=49,mus=[-0.01184586  1.87147378],sigmas=[ 1.09517461  0.90649307],priors=[ 0.43  0.57]
    k=50,mus=[-0.01248783  1.87191401],sigmas=[ 1.09453685  0.90611867],priors=[ 0.43  0.57]
    k=51,mus=[-0.01302159  1.87227973],sigmas=[ 1.09400634  0.90580718],priors=[ 0.43  0.57]
    k=52,mus=[-0.01346505  1.87258336],sigmas=[ 1.09356541  0.90554823],priors=[ 0.43  0.57]
    k=53,mus=[-0.01383328  1.87283531],sigmas=[ 1.09319917  0.90533313],priors=[ 0.43  0.57]
    k=54,mus=[-0.01413888  1.87304431],sigmas=[ 1.09289515  0.90515454],priors=[ 0.43  0.57]
    k=55,mus=[-0.0143924   1.87321761],sigmas=[ 1.09264288  0.90500635],priors=[ 0.43  0.57]
    k=56,mus=[-0.01460264  1.87336127],sigmas=[ 1.09243365  0.90488343],priors=[ 0.43  0.57]
    k=57,mus=[-0.01477693  1.87348033],sigmas=[ 1.09226016  0.9047815 ],priors=[ 0.43  0.57]
    k=58,mus=[-0.01492139  1.87357899],sigmas=[ 1.09211635  0.90469701],priors=[ 0.43  0.57]
    k=59,mus=[-0.0150411   1.87366073],sigmas=[ 1.09199717  0.90462698],priors=[ 0.43  0.57]
    k=60,mus=[-0.01514028  1.87372844],sigmas=[ 1.09189842  0.90456896],priors=[ 0.43  0.57]
    k=61,mus=[-0.01522245  1.87378452],sigmas=[ 1.09181661  0.90452088],priors=[ 0.43  0.57]
    k=62,mus=[-0.01529051  1.87383097],sigmas=[ 1.09174884  0.90448106],priors=[ 0.43  0.57]
    k=63,mus=[-0.01534687  1.87386944],sigmas=[ 1.0916927   0.90444807],priors=[ 0.43  0.57]
    k=64,mus=[-0.01539356  1.87390129],sigmas=[ 1.09164621  0.90442075],priors=[ 0.43  0.57]
    k=65,mus=[-0.01543222  1.87392767],sigmas=[ 1.09160771  0.90439813],priors=[ 0.43  0.57]
    k=66,mus=[-0.01546423  1.87394951],sigmas=[ 1.09157583  0.90437939],priors=[ 0.43  0.57]
    k=67,mus=[-0.01549074  1.87396759],sigmas=[ 1.09154943  0.90436388],priors=[ 0.43  0.57]
    k=68,mus=[-0.01551269  1.87398257],sigmas=[ 1.09152757  0.90435103],priors=[ 0.43  0.57]
    k=69,mus=[-0.01553086  1.87399496],sigmas=[ 1.09150947  0.9043404 ],priors=[ 0.43  0.57]
    k=70,mus=[-0.0155459   1.87400523],sigmas=[ 1.09149449  0.90433159],priors=[ 0.43  0.57]
    k=71,mus=[-0.01555836  1.87401373],sigmas=[ 1.09148208  0.9043243 ],priors=[ 0.43  0.57]
    k=72,mus=[-0.01556868  1.87402076],sigmas=[ 1.09147181  0.90431826],priors=[ 0.43  0.57]
    k=73,mus=[-0.01557722  1.87402659],sigmas=[ 1.0914633   0.90431327],priors=[ 0.43  0.57]
    k=74,mus=[-0.01558428  1.87403141],sigmas=[ 1.09145626  0.90430913],priors=[ 0.43  0.57]
    k=75,mus=[-0.01559014  1.8740354 ],sigmas=[ 1.09145043  0.9043057 ],priors=[ 0.43  0.57]
    k=76,mus=[-0.01559498  1.87403871],sigmas=[ 1.09144561  0.90430287],priors=[ 0.43  0.57]
    k=77,mus=[-0.01559899  1.87404144],sigmas=[ 1.09144161  0.90430052],priors=[ 0.43  0.57]
    k=78,mus=[-0.01560232  1.87404371],sigmas=[ 1.0914383   0.90429857],priors=[ 0.43  0.57]
    k=79,mus=[-0.01560506  1.87404558],sigmas=[ 1.09143556  0.90429696],priors=[ 0.43  0.57]
    k=80,mus=[-0.01560734  1.87404714],sigmas=[ 1.0914333   0.90429563],priors=[ 0.43  0.57]
    k=81,mus=[-0.01560923  1.87404842],sigmas=[ 1.09143142  0.90429453],priors=[ 0.43  0.57]
    k=82,mus=[-0.01561079  1.87404948],sigmas=[ 1.09142987  0.90429362],priors=[ 0.43  0.57]
    k=83,mus=[-0.01561208  1.87405037],sigmas=[ 1.09142858  0.90429286],priors=[ 0.43  0.57]
    k=84,mus=[-0.01561315  1.8740511 ],sigmas=[ 1.09142751  0.90429223],priors=[ 0.43  0.57]
    k=85,mus=[-0.01561403  1.8740517 ],sigmas=[ 1.09142663  0.90429172],priors=[ 0.43  0.57]
    k=86,mus=[-0.01561476  1.8740522 ],sigmas=[ 1.0914259   0.90429129],priors=[ 0.43  0.57]
    k=87,mus=[-0.01561537  1.87405261],sigmas=[ 1.0914253   0.90429093],priors=[ 0.43  0.57]
    k=88,mus=[-0.01561587  1.87405295],sigmas=[ 1.0914248   0.90429064],priors=[ 0.43  0.57]
    k=89,mus=[-0.01561629  1.87405324],sigmas=[ 1.09142438  0.90429039],priors=[ 0.43  0.57]
    k=90,mus=[-0.01561663  1.87405347],sigmas=[ 1.09142404  0.90429019],priors=[ 0.43  0.57]
    k=91,mus=[-0.01561692  1.87405367],sigmas=[ 1.09142376  0.90429003],priors=[ 0.43  0.57]
    k=92,mus=[-0.01561715  1.87405383],sigmas=[ 1.09142352  0.90428989],priors=[ 0.43  0.57]
    k=93,mus=[-0.01561735  1.87405396],sigmas=[ 1.09142333  0.90428977],priors=[ 0.43  0.57]
    k=94,mus=[-0.01561751  1.87405407],sigmas=[ 1.09142317  0.90428968],priors=[ 0.43  0.57]
    k=95,mus=[-0.01561764  1.87405416],sigmas=[ 1.09142303  0.9042896 ],priors=[ 0.43  0.57]
    k=96,mus=[-0.01561775  1.87405424],sigmas=[ 1.09142292  0.90428954],priors=[ 0.43  0.57]
    k=97,mus=[-0.01561785  1.8740543 ],sigmas=[ 1.09142283  0.90428948],priors=[ 0.43  0.57]
    k=98,mus=[-0.01561792  1.87405435],sigmas=[ 1.09142276  0.90428944],priors=[ 0.43  0.57]
    k=99,mus=[-0.01561799  1.8740544 ],sigmas=[ 1.09142269  0.9042894 ],priors=[ 0.43  0.57]

    Process finished with exit code 0
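
    The listing above updates the priors by hard assignment and measures the spread around the previous means, so it is a variant of EM rather than the textbook form. As a point of comparison, here is a minimal sketch of the textbook updates on the same kind of data, tracking the log-likelihood at every iteration. The names gauss_pdf, log_likelihood and em_textbook are hypothetical helpers introduced for this illustration, and the "reported" parameters are copied from the k=99 line of the log above.

```python
import numpy as np


def gauss_pdf(x, mu, sigma):
    """Density of N(mu, sigma^2) evaluated at x (broadcasts over arrays)."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / np.sqrt(2 * np.pi * sigma ** 2)


def log_likelihood(x, mus, sigmas, priors):
    """Sum over samples of log sum_j priors[j] * N(x_i; mus[j], sigmas[j]^2)."""
    dens = gauss_pdf(x[np.newaxis, :], mus[:, np.newaxis], sigmas[:, np.newaxis])
    return float(np.sum(np.log(priors @ dens)))


def em_textbook(x, mus, sigmas, priors, n_iter=100):
    """Textbook EM for a 1-D Gaussian mixture; returns parameters and a log-likelihood trace."""
    mus, sigmas, priors = (np.asarray(a, dtype=float) for a in (mus, sigmas, priors))
    trace = [log_likelihood(x, mus, sigmas, priors)]
    for _ in range(n_iter):
        # E-step: responsibilities w[j, i] = p(z_i = j | x_i)
        w = priors[:, np.newaxis] * gauss_pdf(
            x[np.newaxis, :], mus[:, np.newaxis], sigmas[:, np.newaxis])
        w /= w.sum(axis=0)
        # M-step: soft priors, and deviations measured from the *updated* means
        w_sum = w.sum(axis=1)
        priors = w_sum / x.size
        mus = (w @ x) / w_sum
        sq_dev = (x[np.newaxis, :] - mus[:, np.newaxis]) ** 2
        sigmas = np.sqrt(np.sum(w * sq_dev, axis=1) / w_sum)
        trace.append(log_likelihood(x, mus, sigmas, priors))
    return mus, sigmas, priors, trace


# Same seed and sampling order as the listing above (assumed to reproduce its data).
rng = np.random.RandomState(0)
x = np.concatenate((rng.normal(0, 1, 50), rng.normal(2, 1, 50)))

mus, sigmas, priors, trace = em_textbook(x, [0, 1], [2, 2], [0.5, 0.5])
print("textbook EM:", mus, sigmas, priors, trace[-1])

# Log-likelihood of the parameters reported at k=99 in the log above.
ll_reported = log_likelihood(
    x,
    np.array([-0.01561799, 1.8740544]),
    np.array([1.09142269, 0.9042894]),
    np.array([0.43, 0.57]),
)
print("initial guess:", trace[0], "reported fit:", ll_reported)
```

    With the textbook updates the log-likelihood trace should never decrease from one iteration to the next, which makes it a convenient sanity check for any EM implementation.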

    In addition, a scikit-learn example can be found at http://scikit-learn.org/stable/modules/mixture.html
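
    As a rough illustration of that route, the sketch below fits the same kind of two-component data with scikit-learn's GaussianMixture. It is not taken from the linked page; the choice of covariance_type and random_state is an assumption made here, and the fitted values will generally differ from the log above because the library uses its own initialization and the textbook updates.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Same seed and sampling order as the hand-written example (assumed to reproduce its data).
rng = np.random.RandomState(0)
x = np.concatenate((rng.normal(0, 1, 50), rng.normal(2, 1, 50)))

# scikit-learn expects a 2-D array of shape (n_samples, n_features).
gmm = GaussianMixture(n_components=2, covariance_type="full", random_state=0)
gmm.fit(x.reshape(-1, 1))

print("weights:", gmm.weights_)
print("means:", gmm.means_.ravel())
# For 1-D data each "full" covariance is a 1x1 matrix, i.e. a variance.
print("std devs:", np.sqrt(gmm.covariances_.ravel()))
```

    Here weights_, means_ and covariances_ play the roles of priors, mus and the squared sigmas in the hand-written version.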

    [1] Ian Goodfellow. https://www.quora.com/Why-could-generative-models-help-with-unsupervised-learning/answer/Ian-Goodfellow?srid=hTUVm

  • Original article: https://www.cnblogs.com/cxxszz/p/8313163.html