  • Mixtures of Gaussians and the EM algorithm

    Acknowledgement to Stanford CS229.

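    Generative modeling is itself a kind of unsupervised learning task[1]. Given unlabelled data $x_1,\dots,x_n$, we model each sample as drawn from one of $m$ Gaussians, selected by a latent label $z_i$ with prior $\phi_j = p(z_i = j)$.

    To estimate the parameters $\phi$, $\mu$ and $\sigma$, we can write the log-likelihood as

    $$\ell(\phi,\mu,\sigma)=\sum_{i=1}^{n}\log p(x_i;\phi,\mu,\sigma)$$

    which is also

    $$\ell(\phi,\mu,\sigma)=\sum_{i=1}^{n}\log\sum_{j=1}^{m}p(x_i\mid z_i=j;\mu_j,\sigma_j)\,\phi_j$$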

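    The EM algorithm solves this density estimation problem iteratively: the E-step computes each Gaussian's posterior responsibility for every sample under the current parameters, and the M-step re-estimates the parameters from those responsibilities.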
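    For reference, the updates can be written out for the one-dimensional case. This is a sketch in assumed notation (not taken from the original post): $n$ samples $x_i$, $m$ components with priors $\phi_j$, means $\mu_j$ and standard deviations $\sigma_j$, and $w_{j,i}$ for the responsibility of component $j$ for sample $i$.

```latex
% E-step: responsibility of component j for sample i
w_{j,i} = \frac{\phi_j \, \mathcal{N}(x_i;\mu_j,\sigma_j^2)}
               {\sum_{l=1}^{m} \phi_l \, \mathcal{N}(x_i;\mu_l,\sigma_l^2)}

% M-step: re-estimate the parameters from the responsibilities
\phi_j = \frac{1}{n}\sum_{i=1}^{n} w_{j,i}, \qquad
\mu_j = \frac{\sum_{i=1}^{n} w_{j,i}\, x_i}{\sum_{i=1}^{n} w_{j,i}}, \qquad
\sigma_j^2 = \frac{\sum_{i=1}^{n} w_{j,i}\,(x_i-\mu_j)^2}{\sum_{i=1}^{n} w_{j,i}}
```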

    An example is provided here. The data points are drawn from two Gaussian distributions: 50 from N(0, 1) and 50 from N(2, 1).

    import numpy as np

    np.random.seed(0)
    # 100 unlabelled samples: 50 from N(0, 1) followed by 50 from N(2, 1)
    x0 = np.concatenate((np.random.normal(0, 1, 50),
                         np.random.normal(2, 1, 50)), axis=0)

    # initial guesses for the means and standard deviations of the two classes
    mus0 = np.array([0, 1])
    sigmas0 = np.array([2, 2])


    def gauss(x, mu, sigma):
        """
        Density of a univariate Gaussian.

        :param x: point(s) at which the density is evaluated
        :param mu: mean
        :param sigma: standard deviation
        :return: pdf(x)
        """
        numerator = np.exp(-0.5 * ((x - mu) / sigma) ** 2)
        denominator = np.sqrt(2 * np.pi * sigma ** 2)
        return numerator / denominator


    def e_step(mus=mus0, sigmas=sigmas0, x=x0, priors=np.ones(len(mus0)) / len(mus0)):
        """
        Compute the posterior weight of every class for every sample.

        :param mus: gaussian centers, an array of shape (m,)
        :param sigmas: gaussian standard deviations, an array of shape (m,)
        :param x: n samples with no labels
        :param priors: class priors, an array of shape (m,)
        :return: m by n array, where m is # classes
        """
        assert len(mus) == len(sigmas), "mus and sigmas don't have the same length"
        m = len(mus)
        n = len(x)
        w = np.zeros(shape=(m, n))
        for j in range(m):
            # unnormalized posterior p(x_i | z_i = j) * p(z_i = j) for all i at once
            w[j, :] = gauss(x=x, mu=mus[j], sigma=sigmas[j]) * priors[j]
        # normalize over classes (j is the row index) so every column sums to 1
        w /= np.sum(w, axis=0)
        return w


    def m_step(w, current_mus, x=x0):
        """
        Re-estimate the parameters from the weights computed in the E-step.

        :param w: m by n array, where m is # classes
        :param current_mus: means used in the preceding E-step, shape (m,)
        :param x: n samples with no labels
        :return: mus: gaussian centers, an array of shape (m,)
                 sigmas: gaussian standard deviations, an array of shape (m,)
                 priors: class priors, an array of shape (m,)
        """
        m, n = w.shape
        w_sum = np.sum(w, axis=1)
        mus = np.zeros(shape=(m))
        sigmas = np.zeros(shape=(m))
        for j in range(m):
            mus[j] = np.dot(w[j, :], x)
        mus /= w_sum
        for j in range(m):
            # the spread is measured around the previous means (current_mus);
            # the textbook M-step would use the updated means instead
            sigmas[j] = np.sqrt(np.dot(w[j, :], (x - current_mus[j]) ** 2))
        sigmas /= np.sqrt(w_sum)

        # priors: fraction of samples whose density is highest under class j.
        # This is a hard-assignment count; the textbook M-step would use
        # np.sum(w, axis=1) / n.
        priors = np.zeros(shape=(m))
        for i in range(n):
            densities = [gauss(x[i], mus[j], sigmas[j]) for j in range(m)]
            priors[int(np.argmax(densities))] += 1 / n
        return mus, sigmas, priors


    def solve(x=x0, priors=np.ones(len(mus0)) / len(mus0)):
        mus = mus0
        sigmas = sigmas0
        for k in range(100):
            w = e_step(mus=mus, sigmas=sigmas, x=x, priors=priors)
            mus, sigmas, priors = m_step(w, current_mus=mus, x=x)
            print("k={},mus={},sigmas={},priors={}".format(k, mus, sigmas, priors))
        return mus, sigmas, priors


    if __name__ == '__main__':
        solve()
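    The m_step above takes its priors from hard assignments and measures the spread around the previous means. For comparison, here is a minimal vectorized sketch of one textbook EM iteration, with soft counts for the priors and variances taken around the updated means. The helper name em_step and the use of np.random.default_rng are assumptions made for illustration, so the sample and the estimates will differ from the run in this post.

```python
import numpy as np


def gauss(x, mu, sigma):
    """Univariate Gaussian density; broadcasts over array inputs."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / np.sqrt(2 * np.pi * sigma ** 2)


def em_step(x, mus, sigmas, priors):
    """One textbook EM iteration for a 1-D Gaussian mixture (hypothetical helper)."""
    # E-step: responsibilities w[j, i] = p(z_i = j | x_i), shape (m, n)
    w = priors[:, None] * gauss(x[None, :], mus[:, None], sigmas[:, None])
    w /= w.sum(axis=0, keepdims=True)
    # M-step: soft counts replace the hard assignments
    counts = w.sum(axis=1)
    new_priors = counts / x.size
    new_mus = w @ x / counts
    # the variance is measured around the updated means
    sq_dev = (x[None, :] - new_mus[:, None]) ** 2
    new_sigmas = np.sqrt((w * sq_dev).sum(axis=1) / counts)
    return new_mus, new_sigmas, new_priors


# Assumption: a fresh sample with the same recipe (50 points per Gaussian).
rng = np.random.default_rng(0)
x = np.concatenate((rng.normal(0, 1, 50), rng.normal(2, 1, 50)))

mus = np.array([0.0, 1.0])
sigmas = np.array([2.0, 2.0])
priors = np.array([0.5, 0.5])
for _ in range(100):
    mus, sigmas, priors = em_step(x, mus, sigmas, priors)
print(mus, sigmas, priors)
```

    With soft counts the priors always sum to one, and each iteration cannot decrease the log-likelihood, which is the usual argument for preferring these updates.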

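    After 100 iterations, the estimates settle at means of about -0.016 and 1.874, standard deviations of about 1.091 and 0.904, and priors of 0.43 and 0.57. This is an approximation of the real model, which has means 0 and 2, unit standard deviations and equal priors.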

    /usr/local/bin/python3.5 /home/csdl/review/fulcrum/gmm/gmm.py
    k=0,mus=[ 0.81734343  1.27122747],sigmas=[ 1.60216343  1.33905931],priors=[ 0.32  0.68]
    k=1,mus=[ 0.73393663  1.21263431],sigmas=[ 1.48989073  1.27140643],priors=[ 0.35  0.65]
    k=2,mus=[ 0.72025041  1.24207148],sigmas=[ 1.47760392  1.25840835],priors=[ 0.36  0.64]
    k=3,mus=[ 0.69405554  1.2656654 ],sigmas=[ 1.47453155  1.2480128 ],priors=[ 0.36  0.64]
    k=4,mus=[ 0.65993336  1.28545741],sigmas=[ 1.47454151  1.238417  ],priors=[ 0.36  0.64]
    k=5,mus=[ 0.62512005  1.30527053],sigmas=[ 1.4739642   1.22830782],priors=[ 0.36  0.64]
    k=6,mus=[ 0.59009448  1.32522573],sigmas=[ 1.47230468  1.21788024],priors=[ 0.36  0.64]
    k=7,mus=[ 0.55504913  1.34523959],sigmas=[ 1.46932309  1.20732602],priors=[ 0.36  0.64]
    k=8,mus=[ 0.52016003  1.36521637],sigmas=[ 1.46489424  1.19678812],priors=[ 0.36  0.64]
    k=9,mus=[ 0.4855794   1.38507002],sigmas=[ 1.45897578  1.18636451],priors=[ 0.36  0.64]
    k=10,mus=[ 0.45142496  1.4047313 ],sigmas=[ 1.45158393  1.17611449],priors=[ 0.36  0.64]
    k=11,mus=[ 0.41777707  1.42414967],sigmas=[ 1.44277296  1.16606539],priors=[ 0.36  0.64]
    k=12,mus=[ 0.38468177  1.44329208],sigmas=[ 1.43261873  1.15621962],priors=[ 0.37  0.63]
    k=13,mus=[ 0.36595587  1.46867892],sigmas=[ 1.41990091  1.14409883],priors=[ 0.37  0.63]
    k=14,mus=[ 0.33654056  1.48870368],sigmas=[ 1.4082571   1.13343039],priors=[ 0.37  0.63]
    k=15,mus=[ 0.30597566  1.50763142],sigmas=[ 1.39543174  1.12335036],priors=[ 0.37  0.63]
    k=16,mus=[ 0.27568252  1.52609593],sigmas=[ 1.38137316  1.11360905],priors=[ 0.37  0.63]
    k=17,mus=[ 0.24588996  1.54419117],sigmas=[ 1.36625212  1.10407072],priors=[ 0.37  0.63]
    k=18,mus=[ 0.21664299  1.56192385],sigmas=[ 1.35022455  1.09464684],priors=[ 0.37  0.63]
    k=19,mus=[ 0.18796432  1.57928065],sigmas=[ 1.33342219  1.0852798 ],priors=[ 0.37  0.63]
    k=20,mus=[ 0.1598861   1.59623644],sigmas=[ 1.31596643  1.07593506],priors=[ 0.37  0.63]
    k=21,mus=[ 0.13245872  1.61275414],sigmas=[ 1.2979812   1.06659735],priors=[ 0.39  0.61]
    k=22,mus=[ 0.13549936  1.64420575],sigmas=[ 1.28163006  1.05238461],priors=[ 0.4  0.6]
    k=23,mus=[ 0.13362832  1.67212388],sigmas=[ 1.26763647  1.03761078],priors=[ 0.4  0.6]
    k=24,mus=[ 0.11750175  1.69125511],sigmas=[ 1.25330002  1.02525138],priors=[ 0.41  0.59]
    k=25,mus=[ 0.11224826  1.71494289],sigmas=[ 1.23950504  1.01230038],priors=[ 0.42  0.58]
    k=26,mus=[ 0.11153847  1.7395662 ],sigmas=[ 1.22728752  0.99889453],priors=[ 0.42  0.58]
    k=27,mus=[ 0.0999276   1.75644918],sigmas=[ 1.21474604  0.98770556],priors=[ 0.43  0.57]
    k=28,mus=[ 0.09911993  1.77770261],sigmas=[ 1.20375615  0.97601043],priors=[ 0.43  0.57]
    k=29,mus=[ 0.08991339  1.79234269],sigmas=[ 1.19274093  0.96620904],priors=[ 0.43  0.57]
    k=30,mus=[ 0.07854133  1.80401995],sigmas=[ 1.18163507  0.95803992],priors=[ 0.43  0.57]
    k=31,mus=[ 0.06708472  1.81391145],sigmas=[ 1.1708709  0.9510143],priors=[ 0.43  0.57]
    k=32,mus=[ 0.05629168  1.82248392],sigmas=[ 1.16077864  0.94483468],priors=[ 0.43  0.57]
    k=33,mus=[ 0.04644144  1.8299628 ],sigmas=[ 1.15153082  0.93934709],priors=[ 0.43  0.57]
    k=34,mus=[ 0.03761987  1.83648519],sigmas=[ 1.14319449  0.93447086],priors=[ 0.43  0.57]
    k=35,mus=[ 0.02982246  1.84215374],sigmas=[ 1.13577403  0.93015559],priors=[ 0.43  0.57]
    k=36,mus=[ 0.02299928  1.84705693],sigmas=[ 1.12923679  0.92636035],priors=[ 0.43  0.57]
    k=37,mus=[ 0.01707735  1.85127645],sigmas=[ 1.12352817  0.92304533],priors=[ 0.43  0.57]
    k=38,mus=[ 0.01197298  1.85488949],sigmas=[ 1.11858109  0.92016939],priors=[ 0.43  0.57]
    k=39,mus=[ 0.00759925  1.85796875],sigmas=[ 1.11432244  0.91769023],priors=[ 0.43  0.57]
    k=40,mus=[ 0.00387068  1.86058202],sigmas=[ 1.11067765  0.91556544],priors=[ 0.43  0.57]
    k=41,mus=[  7.06082311e-04   1.86279150e+00],sigmas=[ 1.10757393  0.91375369],priors=[ 0.43  0.57]
    k=42,mus=[-0.00196965  1.86465346],sigmas=[ 1.10494246  0.91221583],priors=[ 0.43  0.57]
    k=43,mus=[-0.00422464  1.86621814],sigmas=[ 1.10271971  0.91091554],priors=[ 0.43  0.57]
    k=44,mus=[-0.00611974  1.86752982],sigmas=[ 1.10084819  0.9098198 ],priors=[ 0.43  0.57]
    k=45,mus=[-0.00770859  1.86862714],sigmas=[ 1.09927671  0.90889909],priors=[ 0.43  0.57]
    k=46,mus=[-0.00903796  1.86954354],sigmas=[ 1.09796019  0.90812732],priors=[ 0.43  0.57]
    k=47,mus=[-0.01014832  1.87030773],sigmas=[ 1.09685943  0.90748172],priors=[ 0.43  0.57]
    k=48,mus=[-0.01107441  1.87094421],sigmas=[ 1.09594057  0.90694261],priors=[ 0.43  0.57]
    k=49,mus=[-0.01184586  1.87147378],sigmas=[ 1.09517461  0.90649307],priors=[ 0.43  0.57]
    k=50,mus=[-0.01248783  1.87191401],sigmas=[ 1.09453685  0.90611867],priors=[ 0.43  0.57]
    k=51,mus=[-0.01302159  1.87227973],sigmas=[ 1.09400634  0.90580718],priors=[ 0.43  0.57]
    k=52,mus=[-0.01346505  1.87258336],sigmas=[ 1.09356541  0.90554823],priors=[ 0.43  0.57]
    k=53,mus=[-0.01383328  1.87283531],sigmas=[ 1.09319917  0.90533313],priors=[ 0.43  0.57]
    k=54,mus=[-0.01413888  1.87304431],sigmas=[ 1.09289515  0.90515454],priors=[ 0.43  0.57]
    k=55,mus=[-0.0143924   1.87321761],sigmas=[ 1.09264288  0.90500635],priors=[ 0.43  0.57]
    k=56,mus=[-0.01460264  1.87336127],sigmas=[ 1.09243365  0.90488343],priors=[ 0.43  0.57]
    k=57,mus=[-0.01477693  1.87348033],sigmas=[ 1.09226016  0.9047815 ],priors=[ 0.43  0.57]
    k=58,mus=[-0.01492139  1.87357899],sigmas=[ 1.09211635  0.90469701],priors=[ 0.43  0.57]
    k=59,mus=[-0.0150411   1.87366073],sigmas=[ 1.09199717  0.90462698],priors=[ 0.43  0.57]
    k=60,mus=[-0.01514028  1.87372844],sigmas=[ 1.09189842  0.90456896],priors=[ 0.43  0.57]
    k=61,mus=[-0.01522245  1.87378452],sigmas=[ 1.09181661  0.90452088],priors=[ 0.43  0.57]
    k=62,mus=[-0.01529051  1.87383097],sigmas=[ 1.09174884  0.90448106],priors=[ 0.43  0.57]
    k=63,mus=[-0.01534687  1.87386944],sigmas=[ 1.0916927   0.90444807],priors=[ 0.43  0.57]
    k=64,mus=[-0.01539356  1.87390129],sigmas=[ 1.09164621  0.90442075],priors=[ 0.43  0.57]
    k=65,mus=[-0.01543222  1.87392767],sigmas=[ 1.09160771  0.90439813],priors=[ 0.43  0.57]
    k=66,mus=[-0.01546423  1.87394951],sigmas=[ 1.09157583  0.90437939],priors=[ 0.43  0.57]
    k=67,mus=[-0.01549074  1.87396759],sigmas=[ 1.09154943  0.90436388],priors=[ 0.43  0.57]
    k=68,mus=[-0.01551269  1.87398257],sigmas=[ 1.09152757  0.90435103],priors=[ 0.43  0.57]
    k=69,mus=[-0.01553086  1.87399496],sigmas=[ 1.09150947  0.9043404 ],priors=[ 0.43  0.57]
    k=70,mus=[-0.0155459   1.87400523],sigmas=[ 1.09149449  0.90433159],priors=[ 0.43  0.57]
    k=71,mus=[-0.01555836  1.87401373],sigmas=[ 1.09148208  0.9043243 ],priors=[ 0.43  0.57]
    k=72,mus=[-0.01556868  1.87402076],sigmas=[ 1.09147181  0.90431826],priors=[ 0.43  0.57]
    k=73,mus=[-0.01557722  1.87402659],sigmas=[ 1.0914633   0.90431327],priors=[ 0.43  0.57]
    k=74,mus=[-0.01558428  1.87403141],sigmas=[ 1.09145626  0.90430913],priors=[ 0.43  0.57]
    k=75,mus=[-0.01559014  1.8740354 ],sigmas=[ 1.09145043  0.9043057 ],priors=[ 0.43  0.57]
    k=76,mus=[-0.01559498  1.87403871],sigmas=[ 1.09144561  0.90430287],priors=[ 0.43  0.57]
    k=77,mus=[-0.01559899  1.87404144],sigmas=[ 1.09144161  0.90430052],priors=[ 0.43  0.57]
    k=78,mus=[-0.01560232  1.87404371],sigmas=[ 1.0914383   0.90429857],priors=[ 0.43  0.57]
    k=79,mus=[-0.01560506  1.87404558],sigmas=[ 1.09143556  0.90429696],priors=[ 0.43  0.57]
    k=80,mus=[-0.01560734  1.87404714],sigmas=[ 1.0914333   0.90429563],priors=[ 0.43  0.57]
    k=81,mus=[-0.01560923  1.87404842],sigmas=[ 1.09143142  0.90429453],priors=[ 0.43  0.57]
    k=82,mus=[-0.01561079  1.87404948],sigmas=[ 1.09142987  0.90429362],priors=[ 0.43  0.57]
    k=83,mus=[-0.01561208  1.87405037],sigmas=[ 1.09142858  0.90429286],priors=[ 0.43  0.57]
    k=84,mus=[-0.01561315  1.8740511 ],sigmas=[ 1.09142751  0.90429223],priors=[ 0.43  0.57]
    k=85,mus=[-0.01561403  1.8740517 ],sigmas=[ 1.09142663  0.90429172],priors=[ 0.43  0.57]
    k=86,mus=[-0.01561476  1.8740522 ],sigmas=[ 1.0914259   0.90429129],priors=[ 0.43  0.57]
    k=87,mus=[-0.01561537  1.87405261],sigmas=[ 1.0914253   0.90429093],priors=[ 0.43  0.57]
    k=88,mus=[-0.01561587  1.87405295],sigmas=[ 1.0914248   0.90429064],priors=[ 0.43  0.57]
    k=89,mus=[-0.01561629  1.87405324],sigmas=[ 1.09142438  0.90429039],priors=[ 0.43  0.57]
    k=90,mus=[-0.01561663  1.87405347],sigmas=[ 1.09142404  0.90429019],priors=[ 0.43  0.57]
    k=91,mus=[-0.01561692  1.87405367],sigmas=[ 1.09142376  0.90429003],priors=[ 0.43  0.57]
    k=92,mus=[-0.01561715  1.87405383],sigmas=[ 1.09142352  0.90428989],priors=[ 0.43  0.57]
    k=93,mus=[-0.01561735  1.87405396],sigmas=[ 1.09142333  0.90428977],priors=[ 0.43  0.57]
    k=94,mus=[-0.01561751  1.87405407],sigmas=[ 1.09142317  0.90428968],priors=[ 0.43  0.57]
    k=95,mus=[-0.01561764  1.87405416],sigmas=[ 1.09142303  0.9042896 ],priors=[ 0.43  0.57]
    k=96,mus=[-0.01561775  1.87405424],sigmas=[ 1.09142292  0.90428954],priors=[ 0.43  0.57]
    k=97,mus=[-0.01561785  1.8740543 ],sigmas=[ 1.09142283  0.90428948],priors=[ 0.43  0.57]
    k=98,mus=[-0.01561792  1.87405435],sigmas=[ 1.09142276  0.90428944],priors=[ 0.43  0.57]
    k=99,mus=[-0.01561799  1.8740544 ],sigmas=[ 1.09142269  0.9042894 ],priors=[ 0.43  0.57]

    Process finished with exit code 0
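    One way to check that the fit improved is to compare the log-likelihood of the data under the initial guess and under the final estimate. The sketch below assumes the same sampling recipe as the example (seed 0, 50 points per Gaussian) and copies rounded parameters from the k=99 line of the log; log_likelihood is a hypothetical helper, not part of the original script.

```python
import numpy as np


def log_likelihood(x, mus, sigmas, priors):
    """Sum over samples of log sum_j priors[j] * N(x_i; mus[j], sigmas[j]**2)."""
    x = np.asarray(x, dtype=float)[None, :]            # shape (1, n)
    mus = np.asarray(mus, dtype=float)[:, None]        # shape (m, 1)
    sigmas = np.asarray(sigmas, dtype=float)[:, None]
    priors = np.asarray(priors, dtype=float)[:, None]
    dens = np.exp(-0.5 * ((x - mus) / sigmas) ** 2) / np.sqrt(2 * np.pi * sigmas ** 2)
    # mix the component densities, then take the log per sample and sum
    return float(np.sum(np.log(np.sum(priors * dens, axis=0))))


# Same sampling recipe as the example: 50 points from N(0, 1), 50 from N(2, 1).
np.random.seed(0)
x = np.concatenate((np.random.normal(0, 1, 50), np.random.normal(2, 1, 50)))

# Initial guess used by the script.
ll_init = log_likelihood(x, [0, 1], [2, 2], [0.5, 0.5])
# Rounded values copied from the k=99 line of the log.
ll_final = log_likelihood(x, [-0.0156, 1.8741], [1.0914, 0.9043], [0.43, 0.57])
print("initial: {:.2f}, final: {:.2f}".format(ll_init, ll_final))
```

    A higher final value indicates that the estimated mixture explains the sample better than the starting guess did.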

    In addition, a scikit-learn example can be found at http://scikit-learn.org/stable/modules/mixture.html
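    As a rough illustration of that route, the sketch below fits sklearn.mixture.GaussianMixture to data generated with the same recipe as above. It assumes scikit-learn is installed; the estimates it prints will not match the hand-written script exactly, since the initialization and the update rules differ.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Same sampling recipe as the example: 50 points from N(0, 1), 50 from N(2, 1).
np.random.seed(0)
x = np.concatenate((np.random.normal(0, 1, 50), np.random.normal(2, 1, 50)))

# scikit-learn expects a 2-D array of shape (n_samples, n_features).
gm = GaussianMixture(n_components=2, random_state=0).fit(x.reshape(-1, 1))

print("weights:", gm.weights_)
print("means:", gm.means_.ravel())
# covariance_type defaults to "full", so covariances_ has shape (2, 1, 1).
print("std devs:", np.sqrt(gm.covariances_.ravel()))
```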

    [1] Ian Goodfellow. https://www.quora.com/Why-could-generative-models-help-with-unsupervised-learning/answer/Ian-Goodfellow?srid=hTUVm

  • Original post: https://www.cnblogs.com/cxxszz/p/8313163.html