zoukankan      html  css  js  c++  java
  • beam_search 和 viterbi算法的区别

    1 beam search

    beam search 在每次预测的时候是选择概率最高的top_k个路径。

    要点:

    • 是基于贪心算法的思想,当k = 1时就是贪心算法
    • 常用于搜索空间非常大的情况,如语言生成任务,每一步选择一个词,而词表非常大,beam search可以大大减少计算量
    • beam search 将概率较低的分支删除,大大减少了搜索空间,其得到的解是一个近似解,不是全局最优解。
    • 时间复杂度为O(TKN)

    python实现一个简单的beam search

    序列长度为10,词典大小为5的单词

    # define a sequence of 10 words over a vocab of 5 words
    data = [[0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1]]
    
    from numpy import *
    def beam_search(data,k):
        """
        k:beam_size的大小
        """
        sequences = [[list(),1.0]]
        for row in data:   # 这个row_data可以看作时间t
            print("
    ")
            print("row的值为:",row)
            print(sequences,len(sequences))
            all_candidates = list()
            for i in range(len(sequences)):
                seq,score = sequences[i]
                print("seq的值为{},score的值为{}".format(seq,score))
                for j in range(len(row)):
                    candidates = [seq +[j],score*-log(row[j])]
                    all_candidates.append(candidates)
                print("all_candidates的值为:",all_candidates)
            ordered = sorted(all_candidates,key = lambda tup:tup[1]) # 之前取了副对数,所以这里为升序排列的
            print("排序之后的顺序为:",ordered)
            
            sequences = ordered[:k]
        return sequences
    
    beam_search(array(data),3)
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[], 1.0]] 1
    seq的值为[],score的值为1.0
    all_candidates的值为: [[[0], 2.3025850929940455], [[1], 1.6094379124341003], [[2], 1.2039728043259361], [[3], 0.916290731874155], [[4], 0.6931471805599453]]
    排序之后的顺序为: [[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361], [[1], 1.6094379124341003], [[0], 2.3025850929940455]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361]] 3
    seq的值为[4],score的值为0.6931471805599453
    all_candidates的值为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[4, 3], 1.1155773512899807], [[4, 4], 1.596030365208182]]
    seq的值为[3],score的值为0.916290731874155
    all_candidates的值为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[4, 3], 1.1155773512899807], [[4, 4], 1.596030365208182], [[3, 0], 0.6351243373717793], [[3, 1], 0.8395887053184746], [[3, 2], 1.1031891220323908], [[3, 3], 1.474713042690254], [[3, 4], 2.109837380062033]]
    seq的值为[2],score的值为1.2039728043259361
    all_candidates的值为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[4, 3], 1.1155773512899807], [[4, 4], 1.596030365208182], [[3, 0], 0.6351243373717793], [[3, 1], 0.8395887053184746], [[3, 2], 1.1031891220323908], [[3, 3], 1.474713042690254], [[3, 4], 2.109837380062033], [[2, 0], 0.8345303547893733], [[2, 1], 1.1031891220323908], [[2, 2], 1.4495505135564588], [[2, 3], 1.937719476821764], [[2, 4], 2.7722498316111372]]
    排序之后的顺序为: [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793], [[4, 2], 0.8345303547893733], [[2, 0], 0.8345303547893733], [[3, 1], 0.8395887053184746], [[3, 2], 1.1031891220323908], [[2, 1], 1.1031891220323908], [[4, 3], 1.1155773512899807], [[2, 2], 1.4495505135564588], [[3, 3], 1.474713042690254], [[4, 4], 1.596030365208182], [[2, 3], 1.937719476821764], [[3, 4], 2.109837380062033], [[2, 4], 2.7722498316111372]]
    
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793]] 3
    seq的值为[4, 0],score的值为0.4804530139182014
    all_candidates的值为: [[[4, 0, 0], 1.1062839477321111], [[4, 0, 1], 0.7732592957431818], [[4, 0, 2], 0.5784523625139449], [[4, 0, 3], 0.4402346437542523], [[4, 0, 4], 0.33302465198892944]]
    seq的值为[4, 1],score的值为0.6351243373717793
    all_candidates的值为: [[[4, 0, 0], 1.1062839477321111], [[4, 0, 1], 0.7732592957431818], [[4, 0, 2], 0.5784523625139449], [[4, 0, 3], 0.4402346437542523], [[4, 0, 4], 0.33302465198892944], [[4, 1, 0], 1.46242783142998], [[4, 1, 1], 1.0221931876757278], [[4, 1, 2], 0.7646724295611531], [[4, 1, 3], 0.5819585439214754], [[4, 1, 4], 0.4402346437542523]]
    seq的值为[3, 0],score的值为0.6351243373717793
    all_candidates的值为: [[[4, 0, 0], 1.1062839477321111], [[4, 0, 1], 0.7732592957431818], [[4, 0, 2], 0.5784523625139449], [[4, 0, 3], 0.4402346437542523], [[4, 0, 4], 0.33302465198892944], [[4, 1, 0], 1.46242783142998], [[4, 1, 1], 1.0221931876757278], [[4, 1, 2], 0.7646724295611531], [[4, 1, 3], 0.5819585439214754], [[4, 1, 4], 0.4402346437542523], [[3, 0, 0], 1.46242783142998], [[3, 0, 1], 1.0221931876757278], [[3, 0, 2], 0.7646724295611531], [[3, 0, 3], 0.5819585439214754], [[3, 0, 4], 0.4402346437542523]]
    排序之后的顺序为: [[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523], [[3, 0, 4], 0.4402346437542523], [[4, 0, 2], 0.5784523625139449], [[4, 1, 3], 0.5819585439214754], [[3, 0, 3], 0.5819585439214754], [[4, 1, 2], 0.7646724295611531], [[3, 0, 2], 0.7646724295611531], [[4, 0, 1], 0.7732592957431818], [[4, 1, 1], 1.0221931876757278], [[3, 0, 1], 1.0221931876757278], [[4, 0, 0], 1.1062839477321111], [[4, 1, 0], 1.46242783142998], [[3, 0, 0], 1.46242783142998]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523]] 3
    seq的值为[4, 0, 4],score的值为0.33302465198892944
    all_candidates的值为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 4, 4], 0.7668175992692388]]
    seq的值为[4, 0, 3],score的值为0.4402346437542523
    all_candidates的值为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 4, 4], 0.7668175992692388], [[4, 0, 3, 0], 0.3051474021030719], [[4, 0, 3, 1], 0.40338292392194175], [[4, 0, 3, 2], 0.5300305386022366], [[4, 0, 3, 3], 0.7085303260250136], [[4, 0, 3, 4], 1.0136777281280855]]
    seq的值为[4, 1, 4],score的值为0.4402346437542523
    all_candidates的值为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 4, 4], 0.7668175992692388], [[4, 0, 3, 0], 0.3051474021030719], [[4, 0, 3, 1], 0.40338292392194175], [[4, 0, 3, 2], 0.5300305386022366], [[4, 0, 3, 3], 0.7085303260250136], [[4, 0, 3, 4], 1.0136777281280855], [[4, 1, 4, 0], 0.3051474021030719], [[4, 1, 4, 1], 0.40338292392194175], [[4, 1, 4, 2], 0.5300305386022366], [[4, 1, 4, 3], 0.7085303260250136], [[4, 1, 4, 4], 1.0136777281280855]]
    排序之后的顺序为: [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 3, 0], 0.3051474021030719], [[4, 1, 4, 0], 0.3051474021030719], [[4, 0, 4, 1], 0.30514740210307195], [[4, 0, 4, 2], 0.40095262416478034], [[4, 0, 3, 1], 0.40338292392194175], [[4, 1, 4, 1], 0.40338292392194175], [[4, 0, 3, 2], 0.5300305386022366], [[4, 1, 4, 2], 0.5300305386022366], [[4, 0, 4, 3], 0.5359825006861554], [[4, 0, 3, 3], 0.7085303260250136], [[4, 1, 4, 3], 0.7085303260250136], [[4, 0, 4, 4], 0.7668175992692388], [[4, 0, 3, 4], 1.0136777281280855], [[4, 1, 4, 4], 1.0136777281280855]]
    
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 3, 0], 0.3051474021030719], [[4, 1, 4, 0], 0.3051474021030719]] 3
    seq的值为[4, 0, 4, 0],score的值为0.23083509858308343
    all_candidates的值为: [[[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 4], 0.1600026977571413]]
    seq的值为[4, 0, 3, 0],score的值为0.3051474021030719
    all_candidates的值为: [[[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 0], 0.7026278592483932], [[4, 0, 3, 0, 1], 0.491115797825457], [[4, 0, 3, 0, 2], 0.3673891734428095], [[4, 0, 3, 0, 3], 0.2796037364025208], [[4, 0, 3, 0, 4], 0.21151206142293622]]
    seq的值为[4, 1, 4, 0],score的值为0.3051474021030719
    all_candidates的值为: [[[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 0], 0.7026278592483932], [[4, 0, 3, 0, 1], 0.491115797825457], [[4, 0, 3, 0, 2], 0.3673891734428095], [[4, 0, 3, 0, 3], 0.2796037364025208], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 0], 0.7026278592483932], [[4, 1, 4, 0, 1], 0.491115797825457], [[4, 1, 4, 0, 2], 0.3673891734428095], [[4, 1, 4, 0, 3], 0.2796037364025208], [[4, 1, 4, 0, 4], 0.21151206142293622]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 4], 0.21151206142293622], [[4, 0, 4, 0, 3], 0.21151206142293624], [[4, 0, 4, 0, 2], 0.2779191809779289], [[4, 0, 3, 0, 3], 0.2796037364025208], [[4, 1, 4, 0, 3], 0.2796037364025208], [[4, 0, 3, 0, 2], 0.3673891734428095], [[4, 1, 4, 0, 2], 0.3673891734428095], [[4, 0, 4, 0, 1], 0.37151475918007754], [[4, 0, 3, 0, 1], 0.491115797825457], [[4, 1, 4, 0, 1], 0.491115797825457], [[4, 0, 4, 0, 0], 0.5315174569372189], [[4, 0, 3, 0, 0], 0.7026278592483932], [[4, 1, 4, 0, 0], 0.7026278592483932]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 4], 0.21151206142293622]] 3
    seq的值为[4, 0, 4, 0, 4],score的值为0.1600026977571413
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 4, 0, 4, 4], 0.36841982669442536]]
    seq的值为[4, 0, 3, 0, 4],score的值为0.21151206142293622
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 4, 0, 4, 4], 0.36841982669442536], [[4, 0, 3, 0, 4, 0], 0.1466089890297302], [[4, 0, 3, 0, 4, 1], 0.19380654156143345], [[4, 0, 3, 0, 4, 2], 0.2546547697401322], [[4, 0, 3, 0, 4, 3], 0.34041553059116364], [[4, 0, 3, 0, 4, 4], 0.4870245196208938]]
    seq的值为[4, 1, 4, 0, 4],score的值为0.21151206142293622
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 4, 0, 4, 4], 0.36841982669442536], [[4, 0, 3, 0, 4, 0], 0.1466089890297302], [[4, 0, 3, 0, 4, 1], 0.19380654156143345], [[4, 0, 3, 0, 4, 2], 0.2546547697401322], [[4, 0, 3, 0, 4, 3], 0.34041553059116364], [[4, 0, 3, 0, 4, 4], 0.4870245196208938], [[4, 1, 4, 0, 4, 0], 0.1466089890297302], [[4, 1, 4, 0, 4, 1], 0.19380654156143345], [[4, 1, 4, 0, 4, 2], 0.2546547697401322], [[4, 1, 4, 0, 4, 3], 0.34041553059116364], [[4, 1, 4, 0, 4, 4], 0.4870245196208938]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 3, 0, 4, 0], 0.1466089890297302], [[4, 1, 4, 0, 4, 0], 0.1466089890297302], [[4, 0, 4, 0, 4, 2], 0.19263889671838058], [[4, 0, 3, 0, 4, 1], 0.19380654156143345], [[4, 1, 4, 0, 4, 1], 0.19380654156143345], [[4, 0, 3, 0, 4, 2], 0.2546547697401322], [[4, 1, 4, 0, 4, 2], 0.2546547697401322], [[4, 0, 4, 0, 4, 3], 0.2575144078620778], [[4, 0, 3, 0, 4, 3], 0.34041553059116364], [[4, 1, 4, 0, 4, 3], 0.34041553059116364], [[4, 0, 4, 0, 4, 4], 0.36841982669442536], [[4, 0, 3, 0, 4, 4], 0.4870245196208938], [[4, 1, 4, 0, 4, 4], 0.4870245196208938]]
    
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 3, 0, 4, 0], 0.1466089890297302]] 3
    seq的值为[4, 0, 4, 0, 4, 0],score的值为0.11090541883234757
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158]]
    seq的值为[4, 0, 4, 0, 4, 1],score的值为0.1466089890297302
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 1, 0], 0.33757967263878436], [[4, 0, 4, 0, 4, 1, 1], 0.2359580652480829], [[4, 0, 4, 0, 4, 1, 2], 0.1765132356615147], [[4, 0, 4, 0, 4, 1, 3], 0.13433645785738146], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145]]
    seq的值为[4, 0, 3, 0, 4, 0],score的值为0.1466089890297302
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 1, 0], 0.33757967263878436], [[4, 0, 4, 0, 4, 1, 1], 0.2359580652480829], [[4, 0, 4, 0, 4, 1, 2], 0.1765132356615147], [[4, 0, 4, 0, 4, 1, 3], 0.13433645785738146], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145], [[4, 0, 3, 0, 4, 0, 0], 0.33757967263878436], [[4, 0, 3, 0, 4, 0, 1], 0.2359580652480829], [[4, 0, 3, 0, 4, 0, 2], 0.1765132356615147], [[4, 0, 3, 0, 4, 0, 3], 0.13433645785738146], [[4, 0, 3, 0, 4, 0, 4], 0.10162160739070145]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145], [[4, 0, 3, 0, 4, 0, 4], 0.10162160739070145], [[4, 0, 4, 0, 4, 0, 2], 0.133527108126524], [[4, 0, 4, 0, 4, 1, 3], 0.13433645785738146], [[4, 0, 3, 0, 4, 0, 3], 0.13433645785738146], [[4, 0, 4, 0, 4, 1, 2], 0.1765132356615147], [[4, 0, 3, 0, 4, 0, 2], 0.1765132356615147], [[4, 0, 4, 0, 4, 0, 1], 0.17849538576316304], [[4, 0, 4, 0, 4, 1, 1], 0.2359580652480829], [[4, 0, 3, 0, 4, 0, 1], 0.2359580652480829], [[4, 0, 4, 0, 4, 0, 0], 0.2553691641356246], [[4, 0, 4, 0, 4, 1, 0], 0.33757967263878436], [[4, 0, 3, 0, 4, 0, 0], 0.33757967263878436]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145]] 3
    seq的值为[4, 0, 4, 0, 4, 0, 4],score的值为0.07687377837246158
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808]]
    seq的值为[4, 0, 4, 0, 4, 0, 3],score的值为0.10162160739070145
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 0, 3, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 0, 3, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 0, 3, 4], 0.23399239830392266]]
    seq的值为[4, 0, 4, 0, 4, 1, 4],score的值为0.10162160739070145
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 0, 3, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 0, 3, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 0, 3, 4], 0.23399239830392266], [[4, 0, 4, 0, 4, 1, 4, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 1, 4, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 1, 4, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 1, 4, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 1, 4, 4], 0.23399239830392266]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 1, 4, 0], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 4, 2], 0.09255393852622307], [[4, 0, 4, 0, 4, 0, 3, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 1, 4, 1], 0.09311493701025386], [[4, 0, 4, 0, 4, 0, 3, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 1, 4, 2], 0.1223496516302921], [[4, 0, 4, 0, 4, 0, 4, 3], 0.12372357338469625], [[4, 0, 4, 0, 4, 0, 3, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 1, 4, 3], 0.16355366765708826], [[4, 0, 4, 0, 4, 0, 4, 4], 0.17700841612255808], [[4, 0, 4, 0, 4, 0, 3, 4], 0.23399239830392266], [[4, 0, 4, 0, 4, 1, 4, 4], 0.23399239830392266]]
    
    
    row的值为: [0.1 0.2 0.3 0.4 0.5]
    [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441]] 3
    seq的值为[4, 0, 4, 0, 4, 0, 4, 0],score的值为0.05328484273786184
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901]]
    seq的值为[4, 0, 4, 0, 4, 0, 4, 1],score的值为0.07043873064683441
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 1, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 4, 1, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 4, 1, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 4, 1, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468]]
    seq的值为[4, 0, 4, 0, 4, 0, 3, 0],score的值为0.07043873064683441
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 1, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 4, 1, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 4, 1, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 4, 1, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 3, 0, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 3, 0, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 3, 0, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 3, 0, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 3, 0, 4], 0.04882440755007468]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 3, 0, 4], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 0.06415350153917002], [[4, 0, 4, 0, 4, 0, 4, 1, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 3, 0, 3], 0.06454235605667437], [[4, 0, 4, 0, 4, 0, 4, 1, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 3, 0, 2], 0.08480631607002849], [[4, 0, 4, 0, 4, 0, 4, 0, 1], 0.08575864606040369], [[4, 0, 4, 0, 4, 0, 4, 1, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 3, 0, 1], 0.11336676360674905], [[4, 0, 4, 0, 4, 0, 4, 0, 0], 0.1226928845707327], [[4, 0, 4, 0, 4, 0, 4, 1, 0], 0.16219117115682374], [[4, 0, 4, 0, 4, 0, 3, 0, 0], 0.16219117115682374]]
    
    
    row的值为: [0.5 0.4 0.3 0.2 0.1]
    [[[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468]] 3
    seq的值为[4, 0, 4, 0, 4, 0, 4, 0, 4],score的值为0.03693423851032901
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018]]
    seq的值为[4, 0, 4, 0, 4, 0, 4, 0, 3],score的值为0.04882440755007468
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 4], 0.11242235299906789]]
    seq的值为[4, 0, 4, 0, 4, 0, 4, 1, 4],score的值为0.04882440755007468
    all_candidates的值为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 4], 0.11242235299906789], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 4], 0.11242235299906789]]
    排序之后的顺序为: [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 0], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 2], 0.0444678187149238], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 1], 0.04473735212737995], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 2], 0.05878325887761582], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 3], 0.059443363725407074], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 3], 0.07857985256322392], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 4], 0.08504422701497018], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 4], 0.11242235299906789], [[4, 0, 4, 0, 4, 0, 4, 1, 4, 4], 0.11242235299906789]]
    
    [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108],
     [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397],
     [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]]
    

    1.1 贪心算法(k=1)

    def greedy_search(data):
        return [argmax(row) for row in data]
    
    greedy_search(data)
    
    [4, 0, 4, 0, 4, 0, 4, 0, 4, 0]
    

    2 viterbi算法

    要点:

    • 基于动态规划的思想
    • 每一步是根据上一步全部可能选择的最高概率推测当前所有选择的最高概率,保证有全局最优解
    • 适合搜索宽度较小的图,即每一步选择较少的时候
    • 时间复杂度为O(TNN)
    import numpy as np
    def viterbi(data):
        row,col = data.shape
        sigma = np.zeros((row,col))  # 存储最大的概率
        phi = np.zeros((row,col))    # 存储概率最大的索引
        sigma[0] = data[0]
        
        for i in range(1,row):
            for k in range(len(sigma[i-1])):
                tmp = float("-inf")
                for j in range(col):
                    if sigma[i-1][j]*data[i][k]>tmp:
                        tmp = sigma[i-1][j]*data[i][k]
                        index = j
    #                 tmp = max(tmp,sigma[i-1][k]*data[i][j])
                sigma[i][k] = tmp
                phi[i][k] = index
        # 回溯
        print(sigma)
        ans = [0]*row
        i_T = argmax(sigma[-1])
        i_t = i_T
        for t in range(row-2,-1,-1):
            i_t = int(phi[t+1][i_t])
            ans[t] = i_t
        ans[-1] = i_T
        return ans
            
    
    viterbi(data)
    
    [[1.000000e-01 2.000000e-01 3.000000e-01 4.000000e-01 5.000000e-01]
     [2.500000e-01 2.000000e-01 1.500000e-01 1.000000e-01 5.000000e-02]
     [2.500000e-02 5.000000e-02 7.500000e-02 1.000000e-01 1.250000e-01]
     [6.250000e-02 5.000000e-02 3.750000e-02 2.500000e-02 1.250000e-02]
     [6.250000e-03 1.250000e-02 1.875000e-02 2.500000e-02 3.125000e-02]
     [1.562500e-02 1.250000e-02 9.375000e-03 6.250000e-03 3.125000e-03]
     [1.562500e-03 3.125000e-03 4.687500e-03 6.250000e-03 7.812500e-03]
     [3.906250e-03 3.125000e-03 2.343750e-03 1.562500e-03 7.812500e-04]
     [3.906250e-04 7.812500e-04 1.171875e-03 1.562500e-03 1.953125e-03]
     [9.765625e-04 7.812500e-04 5.859375e-04 3.906250e-04 1.953125e-04]]
    
    [4, 0, 4, 0, 4, 0, 4, 0, 4, 0]
    
  • 相关阅读:
    override CreateParams events in delphi
    .NET下获取网页的内容的封装类
    WAYOS三天重启硬件版PCB和程序已设计完成,如果需要的人多就发去工厂统一制作
    WAYOS策略路由专用工具——进程端口自动扫描导入工具
    打造最专业的三天重启工具,本人再对WAYOS智能重启进行全面升级
    WAYOS BCM 机器+授权,全淘宝最低价,全新机器仅258还包邮费
    BCM路由全智能固件升级软件tftp,一键刷路由及常用固件下载
    WAYOS PPPOE用户数据定时备份并上传到FTP,保证数据不会因为掉配置、挂机等而丢失
    WAYOS内置免拉黑终于突破技术大关完美成功,以后再也不需要独立的电脑来运行免拉黑了
    巧用EasyRadius计费策略设置灵的计费费率,保证帐目一目了然
  • 原文地址:https://www.cnblogs.com/zhou-lin/p/15016429.html
Copyright © 2011-2022 走看看