zoukankan      html  css  js  c++  java
  • beam_search 和 viterbi算法的区别

    1 beam search

    beam search 在每次预测的时候是选择概率最高的top_k个路径。

    要点:

    • 是基于贪心算法的思想,当k = 1时就是贪心算法
    • 常用于搜索空间非常大的情况,如语言生成任务,每一步选择一个词,而词表非常大,beam search可以大大减少计算量
    • beam search 将概率较低的分支删除,大大减少了搜索空间,其得到的解是一个近似解,不是全局最优解。
    • 时间复杂度约为O(TKN),其中T为序列长度,K为beam size,N为词表大小(若每一步对全部K·N个候选做排序,还需再乘上log(KN)因子)

    python实现一个简单的beam search(注意:下文展示的运行输出中,score是由各步的 -log(p) 连乘得到的;严格的做法是对各步的 -log(p) 求和,这等价于对概率连乘后取负对数,score越小表示路径概率越大)

    示例数据:序列长度为10,词典大小为5(每一行是一个时间步上5个单词各自的得分)

    # Toy input: 10 time steps (rows) over a vocabulary of 5 words (columns).
    # The rows alternate between two fixed patterns. Each row sums to 1.5, so
    # the values are relative scores rather than normalized probabilities.
    data = [[0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1]]
    
    from numpy import *
    def beam_search(data, k, verbose=True):
        """Return the ``k`` best index sequences found by beam search.

        Each row of ``data`` holds the probability of every vocabulary entry at
        one time step. A path is scored by its negative log-likelihood, i.e.
        the *sum* of ``-log(p)`` over its steps, so a lower score is better and
        the ranking agrees with the product of the probabilities.

        Args:
            data: 2-D sequence (list of lists or array); one row per time
                step, one column per vocabulary entry.
            k: Beam size, the number of partial paths kept after each step.
                ``k == 1`` reduces to greedy search.
            verbose: When true (the default), print the step-by-step trace.

        Returns:
            Up to ``k`` ``[sequence, score]`` pairs, best (lowest score) first.
            For empty ``data`` the single empty path with score 0.0 is returned.

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        # Local import so the function does not rely on names pulled in by a
        # star import elsewhere in the file.
        import math

        if k < 1:
            raise ValueError("k must be at least 1")

        # The empty path has probability 1, whose negative log is 0.
        sequences = [[[], 0.0]]
        for row in data:  # each row corresponds to one time step t
            if verbose:
                print("\n")
                print("row的值为:", row)
                print(sequences, len(sequences))

            # Per-token cost for this step, computed once and shared by every
            # beam. Zero-probability tokens get an infinite cost so they sort
            # last instead of raising a math domain error.
            costs = [-math.log(p) if p > 0 else math.inf for p in row]

            all_candidates = []
            for seq, score in sequences:
                if verbose:
                    print("seq的值为{},score的值为{}".format(seq, score))
                for j, cost in enumerate(costs):
                    # Negative log-probabilities add up along a path.
                    all_candidates.append([seq + [j], score + cost])
                if verbose:
                    print("all_candidates的值为:", all_candidates)

            # Scores are negative logs, so ascending order puts the most
            # probable path first. sorted() is stable, so ties keep the order
            # in which the candidates were generated.
            ordered = sorted(all_candidates, key=lambda cand: cand[1])
            if verbose:
                print("排序之后的顺序为:", ordered)

            sequences = ordered[:k]
        return sequences
    
    beam_search(array(data),3)
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[], 1.0]] 1
    seq的值为[],score的值为1.0
    all_candidates的值为: [[[0], 2.3025850929940455], [[1], 1.6094379124341003], [[2], 1.2039728043259361], [[3], 0.916290731874155], [[4], 0.6931471805599453]]
    排序之后的顺序为: [[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361], [[1], 1.6094379124341003], [[0], 2.3025850929940455]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361]] 3
    seq的值为[4],score的值为0.6931471805599453
    all_candidates的值为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[4, 3], 1.1155773512899807], [[4, 4], 1.596030365208182]]
    seq的值为[3],score的值为0.916290731874155
    all_candidates的值为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[4, 3], 1.1155773512899807], [[4, 4], 1.596030365208182], [[3, 0], 0.6351243373717793], [[3, 1], 0.8395887053184746], [[3, 2], 1.1031891220323908], [[3, 3], 1.474713042690254], [[3, 4], 2.109837380062033]]
    seq的值为[2],score的值为1.2039728043259361
    all_candidates的值为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[4, 3], 1.1155773512899807], [[4, 4], 1.596030365208182], [[3, 0], 0.6351243373717793], [[3, 1], 0.8395887053184746], [[3, 2], 1.1031891220323908], [[3, 3], 1.474713042690254], [[3, 4], 2.109837380062033], [[2, 0], 0.8345303547893733], [[2, 1], 1.1031891220323908], [[2, 2], 1.4495505135564588], [[2, 3], 1.937719476821764], [[2, 4], 2.7722498316111372]]
    排序之后的顺序为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[2, 0], 0.8345303547893733], [[3, 1], 0.8395887053184746], [[3, 2], 1.1031891220323908], [[2, 1], 1.1031891220323908], [[4, 3], 1.1155773512899807], [[2, 2], 1.4495505135564588], [[3, 3], 1.474713042690254], [[4, 4], 1.596030365208182], [[2, 3], 1.937719476821764], [[3, 4], 2.109837380062033], [[2, 4], 2.7722498316111372]]
    
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793]] 3
    seq的值为[4, 0],score的值为0.4804530139182014
    all_candidates的值为: [[[4, 0, 0], 1.1062839477321111], [[4, 0, 1], 0.7732592957431818], [[4, 0, 2], 0.5784523625139449], [[4, 0, 3], 0.4402346437542523], [[4, 0, 4], 0.33302465198892944]]
    seq的值为[4, 1],score的值为0.6351243373717793
    all_candidates的值为: [[[4, 0, 0], 1.1062839477321111], [[4, 0, 1], 0.7732592957431818], [[4, 0, 2], 0.5784523625139449], [[4, 0, 3], 0.4402346437542523], [[4, 0, 4], 0.33302465198892944], [[4, 1, 0], 1.46242783142998], [[4, 1, 1], 1.0221931876757278], [[4, 1, 2], 0.7646724295611531], [[4, 1, 3], 0.5819585439214754], [[4, 1, 4], 0.4402346437542523]]
    seq的值为[3, 0],score的值为0.6351243373717793
    all_candidates的值为: [[[4, 0, 0], 1.1062839477321111], [[4, 0, 1], 0.7732592957431818], [[4, 0, 2], 0.5784523625139449], [[4, 0, 3], 0.4402346437542523], [[4, 0, 4], 0.33302465198892944], [[4, 1, 0], 1.46242783142998], [[4, 1, 1], 1.0221931876757278], [[4, 1, 2], 0.7646724295611531], [[4, 1, 3], 0.5819585439214754], [[4, 1, 4], 0.4402346437542523], [[3, 0, 0], 1.46242783142998], [[3, 0, 1], 1.0221931876757278], [[3, 0, 2], 0.7646724295611531], [[3, 0, 3], 0.5819585439214754], [[3, 0, 4], 0.4402346437542523]]
    排序之后的顺序为: [[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523], [[3, 0, 4], 0.4402346437542523], [[4, 0, 2], 0.5784523625139449], [[4, 1, 3], 0.5819585439214754], [[3, 0, 3], 0.5819585439214754], [[4, 1, 2], 0.7646724295611531], [[3, 0, 2], 0.7646724295611531], [[4, 0, 1], 0.7732592957431818], [[4, 1, 1], 1.0221931876757278], [[3, 0, 1], 1.0221931876757278], [[4, 0, 0], 1.1062839477321111], [[4, 1, 0], 1.46242783142998], [[3, 0, 0], 1.46242783142998]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523]] 3
    seq的值为[4, 0, 4],score的值为0.33302465198892944
    all_candidates的值为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 4, 4], 0.7668175992692388]]
    seq的值为[4, 0, 3],score的值为0.4402346437542523
    all_candidates的值为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 4, 4], 0.7668175992692388], [[4, 0, 3, 0], 0.3051474021030719], [[4, 0, 3, 1], 0.40338292392194175], [[4, 0, 3, 2], 0.5300305386022366], [[4, 0, 3, 3], 0.7085303260250136], [[4, 0, 3, 4], 1.0136777281280855]]
    seq的值为[4, 1, 4],score的值为0.4402346437542523
    all_candidates的值为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 4, 4], 0.7668175992692388], [[4, 0, 3, 0], 0.3051474021030719], [[4, 0, 3, 1], 0.40338292392194175], [[4, 0, 3, 2], 0.5300305386022366], [[4, 0, 3, 3], 0.7085303260250136], [[4, 0, 3, 4], 1.0136777281280855], [[4, 1, 4, 0], 0.3051474021030719], [[4, 1, 4, 1], 0.40338292392194175], [[4, 1, 4, 2], 0.5300305386022366], [[4, 1, 4, 3], 0.7085303260250136], [[4, 1, 4, 4], 1.0136777281280855]]
    排序之后的顺序为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 3, 0], 0.3051474021030719], [[4, 1, 4, 0], 0.3051474021030719], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 3, 1], 0.40338292392194175], [[4, 1, 4, 1], 0.40338292392194175], [[4, 0, 3, 2], 0.5300305386022366], [[4, 1, 4, 2], 0.5300305386022366], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 3, 3], 0.7085303260250136], [[4, 1, 4, 3], 0.7085303260250136], [[4, 0, 4, 4], 0.7668175992692388], [[4, 0, 3, 4], 1.0136777281280855], [[4, 1, 4, 4], 1.0136777281280855]]
    
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 3, 0], 0.3051474021030719], [[4, 1, 4, 0], 0.3051474021030719]] 3
    seq的值为[4, 0, 4, 0],score的值为0.23083509858308343
    all_candidates的值为: [[[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 4], 0.1600026977571413]]
    seq的值为[4, 0, 3, 0],score的值为0.3051474021030719
    all_candidates的值为: [[[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 0], 0.7026278592483932], [[4, 0, 3, 0, 1], 0.491115797825457], [[4, 0, 3, 0, 2], 0.3673891734428095], [[4, 0, 3, 0, 3], 0.2796037364025208], [[4, 0, 3, 0, 4], 0.21151206142293622]]
    seq的值为[4, 1, 4, 0],score的值为0.3051474021030719
    all_candidates的值为: [[[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 0], 0.7026278592483932], [[4, 0, 3, 0, 1], 0.491115797825457], [[4, 0, 3, 0, 2], 0.3673891734428095], [[4, 0, 3, 0, 3], 0.2796037364025208], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 0], 0.7026278592483932], [[4, 1, 4, 0, 1], 0.491115797825457], [[4, 1, 4, 0, 2], 0.3673891734428095], [[4, 1, 4, 0, 3], 0.2796037364025208], [[4, 1, 4, 0, 4], 0.21151206142293622]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 4], 0.21151206142293622], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 3, 0, 3], 0.2796037364025208], [[4, 1, 4, 0, 3], 0.2796037364025208], [[4, 0, 3, 0, 2], 0.3673891734428095], [[4, 1, 4, 0, 2], 0.3673891734428095], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 3, 0, 1], 0.491115797825457], [[4, 1, 4, 0, 1], 0.491115797825457], [[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 3, 0, 0], 0.7026278592483932], [[4, 1, 4, 0, 0], 0.7026278592483932]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 4], 0.21151206142293622]] 3
    seq的值为[4, 0, 4, 0, 4],score的值为0.1600026977571413
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 4, 0, 4, 4], 0.36841982669442536]]
    seq的值为[4, 0, 3, 0, 4],score的值为0.21151206142293622
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 4, 0, 4, 4], 0.36841982669442536], [[4, 0, 3, 0, 4, 0], 0.1466089890297302], [[4, 0, 3, 0, 4, 1], 0.19380654156143345], [[4, 0, 3, 0, 4, 2], 0.2546547697401322], [[4, 0, 3, 0, 4, 3], 0.34041553059116364], [[4, 0, 3, 0, 4, 4], 0.4870245196208938]]
    seq的值为[4, 1, 4, 0, 4],score的值为0.21151206142293622
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 4, 0, 4, 4], 0.36841982669442536], [[4, 0, 3, 0, 4, 0], 0.1466089890297302], [[4, 0, 3, 0, 4, 1], 0.19380654156143345], [[4, 0, 3, 0, 4, 2], 0.2546547697401322], [[4, 0, 3, 0, 4, 3], 0.34041553059116364], [[4, 0, 3, 0, 4, 4], 0.4870245196208938], [[4, 1, 4, 0, 4, 0], 0.1466089890297302], [[4, 1, 4, 0, 4, 1], 0.19380654156143345], [[4, 1, 4, 0, 4, 2], 0.2546547697401322], [[4, 1, 4, 0, 4, 3], 0.34041553059116364], [[4, 1, 4, 0, 4, 4], 0.4870245196208938]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 3, 0, 4, 0], 0.1466089890297302], [[4, 1, 4, 0, 4, 0], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 3, 0, 4, 1], 0.19380654156143345], [[4, 1, 4, 0, 4, 1], 0.19380654156143345], [[4, 0, 3, 0, 4, 2], 0.2546547697401322], [[4, 1, 4, 0, 4, 2], 0.2546547697401322], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 3, 0, 4, 3], 0.34041553059116364], [[4, 1, 4, 0, 4, 3], 0.34041553059116364], [[4, 0, 4, 0, 4, 4], 0.36841982669442536], [[4, 0, 3, 0, 4, 4], 0.4870245196208938], [[4, 1, 4, 0, 4, 4], 0.4870245196208938]]
    
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 3, 0, 4, 0], 0.1466089890297302]] 3
    seq的值为[4, 0, 4, 0, 4, 0],score的值为0.11090541883234757
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158]]
    seq的值为[4, 0, 4, 0, 4, 1],score的值为0.1466089890297302
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 1, 0], 0.33757967263878436], [[4, 0, 4, 0, 4, 1, 1], 0.2359580652480829], [[4, 0, 4, 0, 4, 1, 2], 0.1765132356615147], [[4, 0, 4, 0, 4, 1, 3], 0.13433645785738146], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145]]
    seq的值为[4, 0, 3, 0, 4, 0],score的值为0.1466089890297302
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 1, 0], 0.33757967263878436], [[4, 0, 4, 0, 4, 1, 1], 0.2359580652480829], [[4, 0, 4, 0, 4, 1, 2], 0.1765132356615147], [[4, 0, 4, 0, 4, 1, 3], 0.13433645785738146], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145], [[4, 0, 3, 0, 4, 0, 0], 0.33757967263878436], [[4, 0, 3, 0, 4, 0, 1], 0.2359580652480829], [[4, 0, 3, 0, 4, 0, 2], 0.1765132356615147], [[4, 0, 3, 0, 4, 0, 3], 0.13433645785738146], [[4, 0, 3, 0, 4, 0, 4], 0.10162160739070145]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145], [[4, 0, 3, 0, 4, 0, 4], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 1, 3], 0.13433645785738146], [[4, 0, 3, 0, 4, 0, 3], 0.13433645785738146], [[4, 0, 4, 0, 4, 1, 2], 0.1765132356615147], [[4, 0, 3, 0, 4, 0, 2], 0.1765132356615147], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 1, 1], 0.2359580652480829], [[4, 0, 3, 0, 4, 0, 1], 0.2359580652480829], [[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 1, 0], 0.33757967263878436], [[4, 0, 3, 0, 4, 0, 0], 0.33757967263878436]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145]] 3
    seq的值为[4, 0, 4, 0, 4, 0, 4],score的值为0.07687377837246158
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808]]
    seq的值为[4, 0, 4, 0, 4, 0, 3],score的值为0.10162160739070145
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 0, 3, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 0, 3, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 0, 3, 4], 0.23399239830392266]]
    seq的值为[4, 0, 4, 0, 4, 1, 4],score的值为0.10162160739070145
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 0, 3, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 0, 3, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 0, 3, 4], 0.23399239830392266], [[4, 0, 4, 0, 4, 1, 4, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 1, 4, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 1, 4, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 1, 4, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 1, 4, 4], 0.23399239830392266]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 1, 4, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 3, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 1, 4, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 0, 3, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 1, 4, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 3, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 1, 4, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808], [[4, 0, 4, 0, 4, 0, 3, 4], 0.23399239830392266], [[4, 0, 4, 0, 4, 1, 4, 4], 0.23399239830392266]]
    
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441]] 3
    seq的值为[4, 0, 4, 0, 4, 0, 4, 0],score的值为0.05328484273786184
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901]]
    seq的值为[4, 0, 4, 0, 4, 0, 4, 1],score的值为0.07043873064683441
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 1, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 4, 1, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 4, 1, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 4, 1, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468]]
    seq的值为[4, 0, 4, 0, 4, 0, 3, 0],score的值为0.07043873064683441
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 1, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 4, 1, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 4, 1, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 4, 1, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 3, 0, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 3, 0, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 3, 0, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 3, 0, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 3, 0, 4], 0.04882440755007468]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 3, 0, 4], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 1, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 3, 0, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 4, 1, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 3, 0, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 1, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 3, 0, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 1, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 3, 0, 0], 0.16219117115682374]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468]] 3
    seq的值为[4, 0, 4, 0, 4, 0, 4, 0, 4],score的值为0.03693423851032901
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018]]
    seq的值为[4, 0, 4, 0, 4, 0, 4, 0, 3],score的值为0.04882440755007468
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 4], 0.11242235299906789]]
    seq的值为[4, 0, 4, 0, 4, 0, 4, 1, 4],score的值为0.04882440755007468
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 4], 0.11242235299906789], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 4], 0.11242235299906789]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 4], 0.11242235299906789], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 4], 0.11242235299906789]]
    
    [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108],
     [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397],
     [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]]
    

    1.1 贪心算法(k=1)

    def greedy_search(data):
        """Return, for each time step (row) of ``data``, the index of its largest value."""
        picks = []
        for step_scores in data:
            picks.append(argmax(step_scores))
        return picks
    
    greedy_search(data)
    
    [4, 0, 4, 0, 4, 0, 4, 0, 4, 0]
    

    2 viterbi算法

    要点:

    • 基于动态规划的思想
    • 每一步是根据上一步全部可能选择的最高概率推测当前所有选择的最高概率,保证有全局最优解
    • 适合搜索宽度较小的图,即每一步选择较少的时候
    • 时间复杂度为O(TN²),其中T为序列长度,N为每一步的候选状态数
    import numpy as np
    def viterbi(data, verbose=True):
        """Return the highest-scoring state path through ``data`` by dynamic programming.

        ``data[t][k]`` is the score (probability) of state ``k`` at time step
        ``t``. No transition matrix is modeled here, so the score of a path is
        the product of the entries it visits.

        Args:
            data: 2-D array-like (list of lists or ``numpy`` array) of shape
                (time steps, states).
            verbose: When true (the default), print the table of best path
                scores before backtracking.

        Returns:
            A list with one state index (``int``) per time step; an empty list
            for empty input.

        Raises:
            ValueError: If non-empty ``data`` is not 2-dimensional.
        """
        # Accept plain nested lists as well as arrays.
        data = np.asarray(data, dtype=float)
        if data.size == 0:
            return []
        if data.ndim != 2:
            raise ValueError("data must be 2-dimensional (time steps x states)")
        n_steps, n_states = data.shape

        sigma = np.zeros((n_steps, n_states))           # best score of any path ending in each state
        phi = np.zeros((n_steps, n_states), dtype=int)  # predecessor state achieving that score
        sigma[0] = data[0]

        for t in range(1, n_steps):
            # scores[j, k]: score of the best path that is in state j at t-1
            # and then moves to state k at t.
            scores = sigma[t - 1][:, None] * data[t][None, :]
            # argmax returns the first maximum, matching a strict ">" scan
            # over the predecessors in index order.
            phi[t] = scores.argmax(axis=0)
            sigma[t] = scores.max(axis=0)

        if verbose:
            print(sigma)

        # Backtrack from the best final state through the stored predecessors.
        path = [0] * n_steps
        path[-1] = int(np.argmax(sigma[-1]))
        for t in range(n_steps - 2, -1, -1):
            path[t] = int(phi[t + 1][path[t + 1]])
        return path
            
    
    # `data` is a plain nested list; convert it so that array attributes such
    # as `.shape` are available inside viterbi.
    viterbi(np.array(data))
    
    [[1.000000e-01 2.000000e-01 3.000000e-01 4.000000e-01 5.000000e-01]
     [2.500000e-01 2.000000e-01 1.500000e-01 1.000000e-01 5.000000e-02]
     [2.500000e-02 5.000000e-02 7.500000e-02 1.000000e-01 1.250000e-01]
     [6.250000e-02 5.000000e-02 3.750000e-02 2.500000e-02 1.250000e-02]
     [6.250000e-03 1.250000e-02 1.875000e-02 2.500000e-02 3.125000e-02]
     [1.562500e-02 1.250000e-02 9.375000e-03 6.250000e-03 3.125000e-03]
     [1.562500e-03 3.125000e-03 4.687500e-03 6.250000e-03 7.812500e-03]
     [3.906250e-03 3.125000e-03 2.343750e-03 1.562500e-03 7.812500e-04]
     [3.906250e-04 7.812500e-04 1.171875e-03 1.562500e-03 1.953125e-03]
     [9.765625e-04 7.812500e-04 5.859375e-04 3.906250e-04 1.953125e-04]]
    
    [4, 0, 4, 0, 4, 0, 4, 0, 4, 0]
    
  • 相关阅读:
    JavaScript的执行
    关于k阶裴波那契序列的两种解法
    科普 eclipse中的Java build
    [BZOJ 1037] 生日聚会Party
    [POJ 1185] 炮兵阵地
    [POJ 1935] Journey
    [POJ 2397] Spiderman
    [POJ 2373][BZOJ 1986] Dividing the Path
    [POJ 3378] Crazy Thairs
    [POJ 2329] Nearest number-2
  • 原文地址:https://www.cnblogs.com/zhou-lin/p/15016429.html
Copyright © 2011-2022 走看看