zoukankan      html  css  js  c++  java
  • beam_search 和 viterbi算法的区别

    1 beam search

    beam search 在每次预测的时候是选择概率最高的top_k个路径。

    要点:

    • 是基于贪心算法的思想,当k = 1时就是贪心算法
    • 常用于搜索空间非常大的情况,如语言生成任务,每一步选择一个词,而词表非常大,beam search可以大大减少计算量
    • beam search 将概率较低的分支删除,大大减少了搜索空间,其得到的解是一个近似解,不是全局最优解。
    • 时间复杂度为O(TKN)

    python实现一个简单的beam search

    序列长度为10,词典大小为5的单词

    # define a sequence of 10 words over a vocab of 5 words
    data = [[0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1]]
    
    from numpy import *
    def beam_search(data,k):
        """
        k:beam_size的大小
        """
        sequences = [[list(),1.0]]
        for row in data:   # 这个row_data可以看作时间t
            print("
    ")
            print("row的值为:",row)
            print(sequences,len(sequences))
            all_candidates = list()
            for i in range(len(sequences)):
                seq,score = sequences[i]
                print("seq的值为{},score的值为{}".format(seq,score))
                for j in range(len(row)):
                    candidates = [seq +[j],score*-log(row[j])]
                    all_candidates.append(candidates)
                print("all_candidates的值为:",all_candidates)
            ordered = sorted(all_candidates,key = lambda tup:tup[1]) # 之前取了副对数,所以这里为升序排列的
            print("排序之后的顺序为:",ordered)
            
            sequences = ordered[:k]
        return sequences
    
    beam_search(array(data),3)
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[], 1.0]] 1
    seq的值为[],score的值为1.0
    all_candidates的值为: [[[0], 2.3025850929940455], [[1], 1.6094379124341003], [[2], 1.2039728043259361], [[3], 0.916290731874155], [[4], 0.6931471805599453]]
    排序之后的顺序为: [[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361], [[1], 1.6094379124341003], [[0], 2.3025850929940455]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361]] 3
    seq的值为[4],score的值为0.6931471805599453
    all_candidates的值为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[4, 3], 1.1155773512899807], [[4, 4], 1.596030365208182]]
    seq的值为[3],score的值为0.916290731874155
    all_candidates的值为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[4, 3], 1.1155773512899807], [[4, 4], 1.596030365208182], [[3, 0], 0.6351243373717793], [[3, 1], 0.8395887053184746], [[3, 2], 1.1031891220323908], [[3, 3], 1.474713042690254], [[3, 4], 2.109837380062033]]
    seq的值为[2],score的值为1.2039728043259361
    all_candidates的值为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[4, 3], 1.1155773512899807], [[4, 4], 1.596030365208182], [[3, 0], 0.6351243373717793], [[3, 1], 0.8395887053184746], [[3, 2], 1.1031891220323908], [[3, 3], 1.474713042690254], [[3, 4], 2.109837380062033], [[2, 0], 0.8345303547893733], [[2, 1], 1.1031891220323908], [[2, 2], 1.4495505135564588], [[2, 3], 1.937719476821764], [[2, 4], 2.7722498316111372]]
    排序之后的顺序为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[2, 0], 0.8345303547893733], [[3, 1], 0.8395887053184746], [[3, 2], 1.1031891220323908], [[2, 1], 1.1031891220323908], [[4, 3], 1.1155773512899807], [[2, 2], 1.4495505135564588], [[3, 3], 1.474713042690254], [[4, 4], 1.596030365208182], [[2, 3], 1.937719476821764], [[3, 4], 2.109837380062033], [[2, 4], 2.7722498316111372]]
    
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793]] 3
    seq的值为[4, 0],score的值为0.4804530139182014
    all_candidates的值为: [[[4, 0, 0], 1.1062839477321111], [[4, 0, 1], 0.7732592957431818], [[4, 0, 2], 0.5784523625139449], [[4, 0, 3], 0.4402346437542523], [[4, 0, 4], 0.33302465198892944]]
    seq的值为[4, 1],score的值为0.6351243373717793
    all_candidates的值为: [[[4, 0, 0], 1.1062839477321111], [[4, 0, 1], 0.7732592957431818], [[4, 0, 2], 0.5784523625139449], [[4, 0, 3], 0.4402346437542523], [[4, 0, 4], 0.33302465198892944], [[4, 1, 0], 1.46242783142998], [[4, 1, 1], 1.0221931876757278], [[4, 1, 2], 0.7646724295611531], [[4, 1, 3], 0.5819585439214754], [[4, 1, 4], 0.4402346437542523]]
    seq的值为[3, 0],score的值为0.6351243373717793
    all_candidates的值为: [[[4, 0, 0], 1.1062839477321111], [[4, 0, 1], 0.7732592957431818], [[4, 0, 2], 0.5784523625139449], [[4, 0, 3], 0.4402346437542523], [[4, 0, 4], 0.33302465198892944], [[4, 1, 0], 1.46242783142998], [[4, 1, 1], 1.0221931876757278], [[4, 1, 2], 0.7646724295611531], [[4, 1, 3], 0.5819585439214754], [[4, 1, 4], 0.4402346437542523], [[3, 0, 0], 1.46242783142998], [[3, 0, 1], 1.0221931876757278], [[3, 0, 2], 0.7646724295611531], [[3, 0, 3], 0.5819585439214754], [[3, 0, 4], 0.4402346437542523]]
    排序之后的顺序为: [[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523], [[3, 0, 4], 0.4402346437542523], [[4, 0, 2], 0.5784523625139449], [[4, 1, 3], 0.5819585439214754], [[3, 0, 3], 0.5819585439214754], [[4, 1, 2], 0.7646724295611531], [[3, 0, 2], 0.7646724295611531], [[4, 0, 1], 0.7732592957431818], [[4, 1, 1], 1.0221931876757278], [[3, 0, 1], 1.0221931876757278], [[4, 0, 0], 1.1062839477321111], [[4, 1, 0], 1.46242783142998], [[3, 0, 0], 1.46242783142998]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523]] 3
    seq的值为[4, 0, 4],score的值为0.33302465198892944
    all_candidates的值为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 4, 4], 0.7668175992692388]]
    seq的值为[4, 0, 3],score的值为0.4402346437542523
    all_candidates的值为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 4, 4], 0.7668175992692388], [[4, 0, 3, 0], 0.3051474021030719], [[4, 0, 3, 1], 0.40338292392194175], [[4, 0, 3, 2], 0.5300305386022366], [[4, 0, 3, 3], 0.7085303260250136], [[4, 0, 3, 4], 1.0136777281280855]]
    seq的值为[4, 1, 4],score的值为0.4402346437542523
    all_candidates的值为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 4, 4], 0.7668175992692388], [[4, 0, 3, 0], 0.3051474021030719], [[4, 0, 3, 1], 0.40338292392194175], [[4, 0, 3, 2], 0.5300305386022366], [[4, 0, 3, 3], 0.7085303260250136], [[4, 0, 3, 4], 1.0136777281280855], [[4, 1, 4, 0], 0.3051474021030719], [[4, 1, 4, 1], 0.40338292392194175], [[4, 1, 4, 2], 0.5300305386022366], [[4, 1, 4, 3], 0.7085303260250136], [[4, 1, 4, 4], 1.0136777281280855]]
    排序之后的顺序为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 3, 0], 0.3051474021030719], [[4, 1, 4, 0], 0.3051474021030719], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 3, 1], 0.40338292392194175], [[4, 1, 4, 1], 0.40338292392194175], [[4, 0, 3, 2], 0.5300305386022366], [[4, 1, 4, 2], 0.5300305386022366], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 3, 3], 0.7085303260250136], [[4, 1, 4, 3], 0.7085303260250136], [[4, 0, 4, 4], 0.7668175992692388], [[4, 0, 3, 4], 1.0136777281280855], [[4, 1, 4, 4], 1.0136777281280855]]
    
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 3, 0], 0.3051474021030719], [[4, 1, 4, 0], 0.3051474021030719]] 3
    seq的值为[4, 0, 4, 0],score的值为0.23083509858308343
    all_candidates的值为: [[[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 4], 0.1600026977571413]]
    seq的值为[4, 0, 3, 0],score的值为0.3051474021030719
    all_candidates的值为: [[[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 0], 0.7026278592483932], [[4, 0, 3, 0, 1], 0.491115797825457], [[4, 0, 3, 0, 2], 0.3673891734428095], [[4, 0, 3, 0, 3], 0.2796037364025208], [[4, 0, 3, 0, 4], 0.21151206142293622]]
    seq的值为[4, 1, 4, 0],score的值为0.3051474021030719
    all_candidates的值为: [[[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 0], 0.7026278592483932], [[4, 0, 3, 0, 1], 0.491115797825457], [[4, 0, 3, 0, 2], 0.3673891734428095], [[4, 0, 3, 0, 3], 0.2796037364025208], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 0], 0.7026278592483932], [[4, 1, 4, 0, 1], 0.491115797825457], [[4, 1, 4, 0, 2], 0.3673891734428095], [[4, 1, 4, 0, 3], 0.2796037364025208], [[4, 1, 4, 0, 4], 0.21151206142293622]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 4], 0.21151206142293622], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 3, 0, 3], 0.2796037364025208], [[4, 1, 4, 0, 3], 0.2796037364025208], [[4, 0, 3, 0, 2], 0.3673891734428095], [[4, 1, 4, 0, 2], 0.3673891734428095], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 3, 0, 1], 0.491115797825457], [[4, 1, 4, 0, 1], 0.491115797825457], [[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 3, 0, 0], 0.7026278592483932], [[4, 1, 4, 0, 0], 0.7026278592483932]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 4], 0.21151206142293622]] 3
    seq的值为[4, 0, 4, 0, 4],score的值为0.1600026977571413
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 4, 0, 4, 4], 0.36841982669442536]]
    seq的值为[4, 0, 3, 0, 4],score的值为0.21151206142293622
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 4, 0, 4, 4], 0.36841982669442536], [[4, 0, 3, 0, 4, 0], 0.1466089890297302], [[4, 0, 3, 0, 4, 1], 0.19380654156143345], [[4, 0, 3, 0, 4, 2], 0.2546547697401322], [[4, 0, 3, 0, 4, 3], 0.34041553059116364], [[4, 0, 3, 0, 4, 4], 0.4870245196208938]]
    seq的值为[4, 1, 4, 0, 4],score的值为0.21151206142293622
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 4, 0, 4, 4], 0.36841982669442536], [[4, 0, 3, 0, 4, 0], 0.1466089890297302], [[4, 0, 3, 0, 4, 1], 0.19380654156143345], [[4, 0, 3, 0, 4, 2], 0.2546547697401322], [[4, 0, 3, 0, 4, 3], 0.34041553059116364], [[4, 0, 3, 0, 4, 4], 0.4870245196208938], [[4, 1, 4, 0, 4, 0], 0.1466089890297302], [[4, 1, 4, 0, 4, 1], 0.19380654156143345], [[4, 1, 4, 0, 4, 2], 0.2546547697401322], [[4, 1, 4, 0, 4, 3], 0.34041553059116364], [[4, 1, 4, 0, 4, 4], 0.4870245196208938]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 3, 0, 4, 0], 0.1466089890297302], [[4, 1, 4, 0, 4, 0], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 3, 0, 4, 1], 0.19380654156143345], [[4, 1, 4, 0, 4, 1], 0.19380654156143345], [[4, 0, 3, 0, 4, 2], 0.2546547697401322], [[4, 1, 4, 0, 4, 2], 0.2546547697401322], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 3, 0, 4, 3], 0.34041553059116364], [[4, 1, 4, 0, 4, 3], 0.34041553059116364], [[4, 0, 4, 0, 4, 4], 0.36841982669442536], [[4, 0, 3, 0, 4, 4], 0.4870245196208938], [[4, 1, 4, 0, 4, 4], 0.4870245196208938]]
    
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 3, 0, 4, 0], 0.1466089890297302]] 3
    seq的值为[4, 0, 4, 0, 4, 0],score的值为0.11090541883234757
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158]]
    seq的值为[4, 0, 4, 0, 4, 1],score的值为0.1466089890297302
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 1, 0], 0.33757967263878436], [[4, 0, 4, 0, 4, 1, 1], 0.2359580652480829], [[4, 0, 4, 0, 4, 1, 2], 0.1765132356615147], [[4, 0, 4, 0, 4, 1, 3], 0.13433645785738146], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145]]
    seq的值为[4, 0, 3, 0, 4, 0],score的值为0.1466089890297302
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 1, 0], 0.33757967263878436], [[4, 0, 4, 0, 4, 1, 1], 0.2359580652480829], [[4, 0, 4, 0, 4, 1, 2], 0.1765132356615147], [[4, 0, 4, 0, 4, 1, 3], 0.13433645785738146], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145], [[4, 0, 3, 0, 4, 0, 0], 0.33757967263878436], [[4, 0, 3, 0, 4, 0, 1], 0.2359580652480829], [[4, 0, 3, 0, 4, 0, 2], 0.1765132356615147], [[4, 0, 3, 0, 4, 0, 3], 0.13433645785738146], [[4, 0, 3, 0, 4, 0, 4], 0.10162160739070145]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145], [[4, 0, 3, 0, 4, 0, 4], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 1, 3], 0.13433645785738146], [[4, 0, 3, 0, 4, 0, 3], 0.13433645785738146], [[4, 0, 4, 0, 4, 1, 2], 0.1765132356615147], [[4, 0, 3, 0, 4, 0, 2], 0.1765132356615147], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 1, 1], 0.2359580652480829], [[4, 0, 3, 0, 4, 0, 1], 0.2359580652480829], [[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 1, 0], 0.33757967263878436], [[4, 0, 3, 0, 4, 0, 0], 0.33757967263878436]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145]] 3
    seq的值为[4, 0, 4, 0, 4, 0, 4],score的值为0.07687377837246158
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808]]
    seq的值为[4, 0, 4, 0, 4, 0, 3],score的值为0.10162160739070145
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 0, 3, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 0, 3, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 0, 3, 4], 0.23399239830392266]]
    seq的值为[4, 0, 4, 0, 4, 1, 4],score的值为0.10162160739070145
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 0, 3, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 0, 3, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 0, 3, 4], 0.23399239830392266], [[4, 0, 4, 0, 4, 1, 4, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 1, 4, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 1, 4, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 1, 4, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 1, 4, 4], 0.23399239830392266]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 1, 4, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 3, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 1, 4, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 0, 3, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 1, 4, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 3, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 1, 4, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808], [[4, 0, 4, 0, 4, 0, 3, 4], 0.23399239830392266], [[4, 0, 4, 0, 4, 1, 4, 4], 0.23399239830392266]]
    
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441]] 3
    seq的值为[4, 0, 4, 0, 4, 0, 4, 0],score的值为0.05328484273786184
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901]]
    seq的值为[4, 0, 4, 0, 4, 0, 4, 1],score的值为0.07043873064683441
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 1, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 4, 1, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 4, 1, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 4, 1, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468]]
    seq的值为[4, 0, 4, 0, 4, 0, 3, 0],score的值为0.07043873064683441
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 1, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 4, 1, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 4, 1, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 4, 1, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 3, 0, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 3, 0, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 3, 0, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 3, 0, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 3, 0, 4], 0.04882440755007468]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 3, 0, 4], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 1, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 3, 0, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 4, 1, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 3, 0, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 1, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 3, 0, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 1, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 3, 0, 0], 0.16219117115682374]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468]] 3
    seq的值为[4, 0, 4, 0, 4, 0, 4, 0, 4],score的值为0.03693423851032901
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018]]
    seq的值为[4, 0, 4, 0, 4, 0, 4, 0, 3],score的值为0.04882440755007468
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 4], 0.11242235299906789]]
    seq的值为[4, 0, 4, 0, 4, 0, 4, 1, 4],score的值为0.04882440755007468
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 4], 0.11242235299906789], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 4], 0.11242235299906789]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 4], 0.11242235299906789], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 4], 0.11242235299906789]]
    
    [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108],
     [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397],
     [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]]
    

    1.1 贪心算法(k=1)

    def greedy_search(data):
        return [argmax(row) for row in data]
    
    greedy_search(data)
    
    [4, 0, 4, 0, 4, 0, 4, 0, 4, 0]
    

    2 viterbi算法

    要点:

    • 基于动态规划的思想
    • 每一步是根据上一步全部可能选择的最高概率推测当前所有选择的最高概率,保证有全局最优解
    • 适合搜索宽度较小的图,即每一步选择较少的时候
    • 时间复杂度为O(TNN)
    import numpy as np
    def viterbi(data):
        row,col = data.shape
        sigma = np.zeros((row,col))  # 存储最大的概率
        phi = np.zeros((row,col))    # 存储概率最大的索引
        sigma[0] = data[0]
        
        for i in range(1,row):
            for k in range(len(sigma[i-1])):
                tmp = float("-inf")
                for j in range(col):
                    if sigma[i-1][j]*data[i][k]>tmp:
                        tmp = sigma[i-1][j]*data[i][k]
                        index = j
    #                 tmp = max(tmp,sigma[i-1][k]*data[i][j])
                sigma[i][k] = tmp
                phi[i][k] = index
        # 回溯
        print(sigma)
        ans = [0]*row
        i_T = argmax(sigma[-1])
        i_t = i_T
        for t in range(row-2,-1,-1):
            i_t = int(phi[t+1][i_t])
            ans[t] = i_t
        ans[-1] = i_T
        return ans
            
    
    viterbi(data)
    
    [[1.000000e-01 2.000000e-01 3.000000e-01 4.000000e-01 5.000000e-01]
     [2.500000e-01 2.000000e-01 1.500000e-01 1.000000e-01 5.000000e-02]
     [2.500000e-02 5.000000e-02 7.500000e-02 1.000000e-01 1.250000e-01]
     [6.250000e-02 5.000000e-02 3.750000e-02 2.500000e-02 1.250000e-02]
     [6.250000e-03 1.250000e-02 1.875000e-02 2.500000e-02 3.125000e-02]
     [1.562500e-02 1.250000e-02 9.375000e-03 6.250000e-03 3.125000e-03]
     [1.562500e-03 3.125000e-03 4.687500e-03 6.250000e-03 7.812500e-03]
     [3.906250e-03 3.125000e-03 2.343750e-03 1.562500e-03 7.812500e-04]
     [3.906250e-04 7.812500e-04 1.171875e-03 1.562500e-03 1.953125e-03]
     [9.765625e-04 7.812500e-04 5.859375e-04 3.906250e-04 1.953125e-04]]
    
    [4, 0, 4, 0, 4, 0, 4, 0, 4, 0]
    
  • 相关阅读:
    Linux下sed,awk,grep,cut,find学习笔记
    Python文件处理(1)
    KMP详解
    Java引用详解
    解决安卓中页脚被输入法顶起的问题
    解决swfupload上传控件文件名中文乱码问题 三种方法 flash及最新版本11.8.800.168
    null id in entry (don't flush the Session after an exception occurs)
    HQL中的Like查询需要注意的地方
    spring mvc controller间跳转 重定向 传参
    node to traverse cannot be null!
  • 原文地址:https://www.cnblogs.com/zhou-lin/p/15016429.html
Copyright © 2011-2022 走看看