zoukankan      html  css  js  c++  java
  • [NLP] beam search的简单实现

    介绍

    本文基于该博客的内容改编。

    想象你是一位校长,手下有十个班级,每个班级5个学生,每个学生都坐在自己的座位上,每个人成绩都不一样。
    (如果你很难想象,那么就看看下面的代码实例中的data变量。)

    现在你的任务是,从一班到十班,根据特定规则,在每个班级寻找一位学生,然后搭配成为最好的学生组合(注意不是找最好的学生,而是学生组合),出道成为偶像,拯救学校的衰落(?)。

    由于有一套独特的评分规则,单纯的找出最好的学生并不是最佳的方案,而取决于不同学生组合而产生的最终分数,于是你采用了以下的策略:每次选择的时候,都会寻找前k个最佳组合。

    例如你进入了三班,那么手里也许已经有k个来自前一班和二班的不同组合。例如:[1,2],[2,3],[1,3],list 中的每个位置代表班级,数字代表学生,三个list,说明你有三套方案(k=3)。现在你要往这个表单中加入来自三班的新学生,5位同学都跃跃欲试,你把他们5个人,每个人都放在了三套方案的后面,那么也就是15套新方案(5个学生*三个备选方案)。此时计算每一个方案的分数,然后选择前k个方案(k=3)作为结果。

    最后你进入了四班,继续开始这套操作。

    最后就可以找出最优方案,以及它的备选。

    代码部分

    import numpy as np
    
    data = np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1],
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.5, 0.4, 0.3, 0.2, 0.1]])
    
    def beam_search_decoder(data, k=3):
        # 第一步:初始化 seq
        # seq 是一个大list,最终会包含k个list,每一个list里面带有两个东西:序列(list) & 分数(int)
        sequences = [[list(),1.0]] 
    
        # 第二步: load 每一行
        for row in data: 
            # 2.1 初始化,到该行为止所有的可能
            all_candidates = list() 
            # 2.2 获得上一轮的结果,为了加入这一次的值
            for i in range(len(sequences)):
                seq, score = sequences[i] 
    
                # 2.3 将 row 中的所有结果 与sequences每一个(k)个进行运算,获得scores
                # 如果 seqences中有k个结果,一个row中有5个概率(label 或 该位置的可能符号),那么共产生 5*k个备选
                for j in range(len(row)):
                    # candidate = [ list + 每一个备选的index j, 新分数 ]
                    candidate = [seq + [j], score * -np.log(row[j])]
                    # 将生成好的新备选加入候补席位
                    all_candidates.append(candidate) # 加入备选
    
    
            # 根据分数进行排序
            ordered = sorted(all_candidates, key=lambda tup:tup[1])
    
            # 选择前k个,动态调整 seqences 的数量
            sequences = ordered[:k]
            # 查看每次sequences 的输出: print(sequences)
    
        return sequences
    
    def greedy_decoder(data):
        # greedy decoder
        # 每一行最大概率词的索引
        return [np.argmax(s) for s in data]
    
    # 数据准备阶段
    data = np.array(data)
    # 贪婪
    greedy_result = greedy_decoder(data)
    print("- 贪婪算法的结果:
    {}
    ".format(greedy_result))
    # beam search
    beam_result = beam_search_decoder(data)
    print("- beam 搜索的结果:
    {}
    ".format(beam_result))
    

    代码结果

    • 贪婪算法的结果:
      [4, 0, 4, 0, 4, 0, 4, 0, 4, 0]

    • beam 搜索的结果:
      这里 k = 3,前面是结果,后面是分数
      [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108],
      [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397],
      [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]]

  • 相关阅读:
    Java的数组的作业11月06日
    e课表项目第二次冲刺周期第九天
    e课表项目第二次冲刺周期第八天
    e课表项目第二次冲刺周期第七天
    e课表项目第二次冲刺周期第六天
    e课表项目第二次冲刺周期第五天
    e课表项目第二次冲刺周期第四天
    e课表项目第二次冲刺周期第三天
    e课表项目第二次冲刺周期第二天
    e课表项目第二次冲刺周期第一天
  • 原文地址:https://www.cnblogs.com/kykai/p/14033782.html
Copyright © 2011-2022 走看看