zoukankan      html  css  js  c++  java
  • 强化学习-策略迭代代码实现

    1. 前言

    今天要重代码的角度给大家详细介绍下策略迭代的原理和实现方式。本节完整代码GitHub

    我们开始介绍策略迭代前,先介绍一个蛇棋的游戏

    image

    它是我们后面学习的环境,介绍下它的规则:

    1. 玩家每人拥有一个棋子,出发点在图中标为“1”的格子处。
    2. 依次掷骰子,根据骰子的点数将自己的棋子向前行进相应的步数。假设笔者的棋子在“1”处,并且投掷出“4”,则笔者的棋子就可以到达“5”的位置。
    3. 棋盘上有一些梯子,它的两边与棋盘上的两个格子相连。如果棋子落在其中一个格子上,就会自动走到梯子对应的另一个格子中。以图5-5所示的棋盘为例,如果笔者的棋子在“1”处,并且投掷出“2”,那么棋子将到达“3”处,由于此处有梯子,棋子将直接前进到梯子的另一段——“20”的位置。
    4. 最终的目标是到达“100”处,如果在到达时投掷的数字加上当前的位置超过了100,那么棋子将首先到达100,剩余的步数将反向前进。

    2. 蛇棋实现

    我们实现蛇棋的逻辑,应该集成gym的env,然后分别重写env下面的几个重要的接口,这样使用起来就可以和gym里面封装的小游戏一样了。

    class SnakeEnv(gym.Env):
        SIZE = 100
    
        def __init__(self, ladder_num, actions):
            """
            :param int ladder_num: 梯子的个数
            :param list actions: 可选择的行为
            """
            self.ladder_num = ladder_num
            self.actions = actions
            # 在整个范围内,随机生成梯子
            self.ladders = dict(np.random.randint(1, self.SIZE, size=(self.ladder_num, 2)))
            self.observation_space = Discrete(self.SIZE + 1)
            self.action_space = Discrete(len(actions))
    
            # 因为梯子是两个方向的,所以添加反方向的梯子
            new_ladders = {}
            for k, v in self.ladders.items():
                new_ladders[k] = v
                new_ladders[v] = k
            self.ladders = new_ladders
            self.pos = 1
    
        # 重置初始状态
        def reset(self):
            self.pos = 1
            return self.pos
    
        def step(self, action):
            """
            :param int action: 选择的行动
            :return: 下一个状态,奖励值,是否结束,其它内容
            """
            step = np.random.randint(1, self.actions[action] + 1)
            self.pos += step
            if self.pos == 100:
                return 100, 100, 1, {}
            elif self.pos > 100:
                self.pos = 200 - self.pos
    
            if self.pos in self.ladders:
                self.pos = self.ladders[self.pos]
            return self.pos, -1, 0, {}
    
        # 返回状态s的奖励值
        def reward(self, s):
            if s == 100:
                return 100
            else:
                return -1
    

    然后再实现一个我们自己的智能体agent,里面包含的东西有状态的奖励、策略、行动状态转移矩阵、状态值函数、状态行动值函数等。

    为了简单,我们用表格,或者矩阵的形式来表示各种变量。

    class TableAgent(object):
        def __init__(self, env):
            # 状态个数
            self.s_len = env.observation_space.n
            # 行动个数
            self.a_len = env.action_space.n
            # 每个状态的奖励,shape=[1,self.s_len]
            self.r = [env.reward(s) for s in range(0, self.s_len)]
            # 每个状态的行动策略,默认为0,shape=[1,self.s_len]
            self.pi = np.array([0 for s in range(0, self.s_len)])
            # 行动状态转移矩阵,shape=[self.a_len, self.s_len, self.s_len]
            self.p = np.zeros([self.a_len, self.s_len, self.s_len], dtype=np.float)
            # 梯子
            ladder_move = np.vectorize(lambda x: env.ladders[x] if x in env.ladders else x)
    
            # 计算状态s和行动a确定,下一个状态s'的概率
            for i, action in enumerate(env.actions):
                prob = 1.0 / action
                for src in range(1, 100):
                    step = np.arange(action)
                    step += src
                    step = np.piecewise(step, [step > 100, step <= 100],
                                        [lambda x: 200 - x, lambda x: x])
                    step = ladder_move(step)
                    for dst in step:
                        self.p[i, src, dst] += prob
    
            self.p[:, 100, 100] = 1
            # 状态值函数
            self.value_pi = np.zeros((self.s_len))
            # 状态行动值函数
            self.value_q = np.zeros((self.s_len, self.a_len))
            # 衰减因子
            self.gamma = 0.8
    

    3. 策略迭代实现

    前面我们已经介绍过了,策略迭代的过程可以分为2个步骤

    • 策略评估:策略评估时计算当前策略下,收敛的数据状态值函数。

    [v^T_{pi}(s_t)=sum_{a_t}pi^{T-1}(a_t|s_t)sum_{s_{t+1}}p(s_{t+1}|s_t,a_t)[r_{a_t}^{s_{t+1}} + gamma * v^{T-1}_{pi}(s_{t+1})];;;;;;(1) ]

    实现如下:

    # 策略评估
    def policy_evaluation(self, agent, max_iter=-1):
        """
        :param obj agent: 智能体
        :param int max_iter: 最大迭代数
        """
        iteration = 0
    
        while True:
            iteration += 1
            new_value_pi = agent.value_pi.copy()
            # 对每个state计算v(s)
            for i in range(1, agent.s_len):
                ac = agent.pi[i]
                transition = agent.p[ac, i, :]
                value_sa = np.dot(transition, agent.r + agent.gamma * agent.value_pi)
                new_value_pi[i] = value_sa
    
            # 前后2次值函数的变化小于一个阈值,结束
            diff = np.sqrt(np.sum(np.power(agent.value_pi - new_value_pi, 2)))
            if diff < 1e-6:
                break
            else:
                agent.value_pi = new_value_pi
            if iteration == max_iter:
                break
    
    • 策略提升:在计算出了收敛的状态值函数,再计算状态-行动值函数,再找出最好的策略。

    [v_{pi}(s_t)=sum_{a_t}pi(a_t|s_t)q_{pi}(s_t,a_t) ]

    [q_{pi}(s_t,a_t)=sum_{s_{t+1}}p(s_{t+1}|s_t,a_t)[r_{a_t}^{s_{t+1}} + gamma * v_{pi}(s_{t+1})] ]

    实现如下:

    # 策略提升
    def policy_improvement(self, agent):
        """
        :param obj agent: 智能体
        """
    
        # 初始化新策略
        new_policy = np.zeros_like(agent.pi)
        for i in range(1, agent.s_len):
            for j in range(0, agent.a_len):
                # 计算每一个状态行动值函数
                agent.value_q[i, j] = np.dot(agent.p[j, i, :], agent.r + agent.gamma * agent.value_pi)
    
            # 选出每个状态下的最优行动
            max_act = np.argmax(agent.value_q[i, :])
            new_policy[i] = max_act
        if np.all(np.equal(new_policy, agent.pi)):
            return False
        else:
            agent.pi = new_policy
            return True
    
    # 策略迭代
    def policy_iteration(self, agent):
        """
        :param obj agent: 智能体
        """
        iteration = 0
        while True:
            iteration += 1
            self.policy_evaluation(agent)
            ret = self.policy_improvement(agent)
            if not ret:
                break
        print('Iter {} rounds converge'.format(iteration))
    

    4. 总结

    我们通过学习了策略迭代的实现,能够比较清楚的看出强化学习的过程,策略迭代也是后面算法优化的一个基础。

  • 相关阅读:
    SpringBoot使用过滤器、拦截器、切面(AOP),及其之间的区别和执行顺序
    发送POST请求,包含文件MultipartFile参数,普通字符串参数,请求头参数
    Linux安装Mongodb(附带SpringBoot整合MongoDB项目Demo)
    博客目录
    Ubuntu+Hexo+Github搭建个人博客
    Hexo+Github搭建个人博客
    Linux设备驱动程序学习----3.模块的编译和装载
    Linux设备驱动程序学习----2.内核模块与应用程序的对比
    Linux设备驱动程序学习----1.设备驱动程序简介
    Linux设备驱动程序学习----目录
  • 原文地址:https://www.cnblogs.com/huangyc/p/10386466.html
Copyright © 2011-2022 走看看