zoukankan      html  css  js  c++  java
  • 强化学习-策略迭代代码实现

    1. 前言

    今天要重代码的角度给大家详细介绍下策略迭代的原理和实现方式。本节完整代码GitHub

    我们开始介绍策略迭代前,先介绍一个蛇棋的游戏

    image

    它是我们后面学习的环境,介绍下它的规则:

    1. 玩家每人拥有一个棋子,出发点在图中标为“1”的格子处。
    2. 依次掷骰子,根据骰子的点数将自己的棋子向前行进相应的步数。假设笔者的棋子在“1”处,并且投掷出“4”,则笔者的棋子就可以到达“5”的位置。
    3. 棋盘上有一些梯子,它的两边与棋盘上的两个格子相连。如果棋子落在其中一个格子上,就会自动走到梯子对应的另一个格子中。以图5-5所示的棋盘为例,如果笔者的棋子在“1”处,并且投掷出“2”,那么棋子将到达“3”处,由于此处有梯子,棋子将直接前进到梯子的另一段——“20”的位置。
    4. 最终的目标是到达“100”处,如果在到达时投掷的数字加上当前的位置超过了100,那么棋子将首先到达100,剩余的步数将反向前进。

    2. 蛇棋实现

    我们实现蛇棋的逻辑,应该集成gym的env,然后分别重写env下面的几个重要的接口,这样使用起来就可以和gym里面封装的小游戏一样了。

    class SnakeEnv(gym.Env):
        SIZE = 100
    
        def __init__(self, ladder_num, actions):
            """
            :param int ladder_num: 梯子的个数
            :param list actions: 可选择的行为
            """
            self.ladder_num = ladder_num
            self.actions = actions
            # 在整个范围内,随机生成梯子
            self.ladders = dict(np.random.randint(1, self.SIZE, size=(self.ladder_num, 2)))
            self.observation_space = Discrete(self.SIZE + 1)
            self.action_space = Discrete(len(actions))
    
            # 因为梯子是两个方向的,所以添加反方向的梯子
            new_ladders = {}
            for k, v in self.ladders.items():
                new_ladders[k] = v
                new_ladders[v] = k
            self.ladders = new_ladders
            self.pos = 1
    
        # 重置初始状态
        def reset(self):
            self.pos = 1
            return self.pos
    
        def step(self, action):
            """
            :param int action: 选择的行动
            :return: 下一个状态,奖励值,是否结束,其它内容
            """
            step = np.random.randint(1, self.actions[action] + 1)
            self.pos += step
            if self.pos == 100:
                return 100, 100, 1, {}
            elif self.pos > 100:
                self.pos = 200 - self.pos
    
            if self.pos in self.ladders:
                self.pos = self.ladders[self.pos]
            return self.pos, -1, 0, {}
    
        # 返回状态s的奖励值
        def reward(self, s):
            if s == 100:
                return 100
            else:
                return -1
    

    然后再实现一个我们自己的智能体agent,里面包含的东西有状态的奖励、策略、行动状态转移矩阵、状态值函数、状态行动值函数等。

    为了简单,我们用表格,或者矩阵的形式来表示各种变量。

    class TableAgent(object):
        def __init__(self, env):
            # 状态个数
            self.s_len = env.observation_space.n
            # 行动个数
            self.a_len = env.action_space.n
            # 每个状态的奖励,shape=[1,self.s_len]
            self.r = [env.reward(s) for s in range(0, self.s_len)]
            # 每个状态的行动策略,默认为0,shape=[1,self.s_len]
            self.pi = np.array([0 for s in range(0, self.s_len)])
            # 行动状态转移矩阵,shape=[self.a_len, self.s_len, self.s_len]
            self.p = np.zeros([self.a_len, self.s_len, self.s_len], dtype=np.float)
            # 梯子
            ladder_move = np.vectorize(lambda x: env.ladders[x] if x in env.ladders else x)
    
            # 计算状态s和行动a确定,下一个状态s'的概率
            for i, action in enumerate(env.actions):
                prob = 1.0 / action
                for src in range(1, 100):
                    step = np.arange(action)
                    step += src
                    step = np.piecewise(step, [step > 100, step <= 100],
                                        [lambda x: 200 - x, lambda x: x])
                    step = ladder_move(step)
                    for dst in step:
                        self.p[i, src, dst] += prob
    
            self.p[:, 100, 100] = 1
            # 状态值函数
            self.value_pi = np.zeros((self.s_len))
            # 状态行动值函数
            self.value_q = np.zeros((self.s_len, self.a_len))
            # 衰减因子
            self.gamma = 0.8
    

    3. 策略迭代实现

    前面我们已经介绍过了,策略迭代的过程可以分为2个步骤

    • 策略评估:策略评估时计算当前策略下,收敛的数据状态值函数。

    [v^T_{pi}(s_t)=sum_{a_t}pi^{T-1}(a_t|s_t)sum_{s_{t+1}}p(s_{t+1}|s_t,a_t)[r_{a_t}^{s_{t+1}} + gamma * v^{T-1}_{pi}(s_{t+1})];;;;;;(1) ]

    实现如下:

    # 策略评估
    def policy_evaluation(self, agent, max_iter=-1):
        """
        :param obj agent: 智能体
        :param int max_iter: 最大迭代数
        """
        iteration = 0
    
        while True:
            iteration += 1
            new_value_pi = agent.value_pi.copy()
            # 对每个state计算v(s)
            for i in range(1, agent.s_len):
                ac = agent.pi[i]
                transition = agent.p[ac, i, :]
                value_sa = np.dot(transition, agent.r + agent.gamma * agent.value_pi)
                new_value_pi[i] = value_sa
    
            # 前后2次值函数的变化小于一个阈值,结束
            diff = np.sqrt(np.sum(np.power(agent.value_pi - new_value_pi, 2)))
            if diff < 1e-6:
                break
            else:
                agent.value_pi = new_value_pi
            if iteration == max_iter:
                break
    
    • 策略提升:在计算出了收敛的状态值函数,再计算状态-行动值函数,再找出最好的策略。

    [v_{pi}(s_t)=sum_{a_t}pi(a_t|s_t)q_{pi}(s_t,a_t) ]

    [q_{pi}(s_t,a_t)=sum_{s_{t+1}}p(s_{t+1}|s_t,a_t)[r_{a_t}^{s_{t+1}} + gamma * v_{pi}(s_{t+1})] ]

    实现如下:

    # 策略提升
    def policy_improvement(self, agent):
        """
        :param obj agent: 智能体
        """
    
        # 初始化新策略
        new_policy = np.zeros_like(agent.pi)
        for i in range(1, agent.s_len):
            for j in range(0, agent.a_len):
                # 计算每一个状态行动值函数
                agent.value_q[i, j] = np.dot(agent.p[j, i, :], agent.r + agent.gamma * agent.value_pi)
    
            # 选出每个状态下的最优行动
            max_act = np.argmax(agent.value_q[i, :])
            new_policy[i] = max_act
        if np.all(np.equal(new_policy, agent.pi)):
            return False
        else:
            agent.pi = new_policy
            return True
    
    # 策略迭代
    def policy_iteration(self, agent):
        """
        :param obj agent: 智能体
        """
        iteration = 0
        while True:
            iteration += 1
            self.policy_evaluation(agent)
            ret = self.policy_improvement(agent)
            if not ret:
                break
        print('Iter {} rounds converge'.format(iteration))
    

    4. 总结

    我们通过学习了策略迭代的实现,能够比较清楚的看出强化学习的过程,策略迭代也是后面算法优化的一个基础。

  • 相关阅读:
    在Ubuntu中通过update-alternatives切换软件版本
    SCons: 替代 make 和 makefile 及 javac 的极好用的c、c++、java 构建工具
    mongodb 的使用
    利用grub从ubuntu找回windows启动项
    How to Repair GRUB2 When Ubuntu Won’t Boot
    Redis vs Mongo vs mysql
    java script 的工具
    python 的弹框
    how to use greendao in android studio
    python yield的终极解释
  • 原文地址:https://www.cnblogs.com/huangyc/p/10386466.html
Copyright © 2011-2022 走看看