zoukankan      html  css  js  c++  java
  • 强化学习实战(1):gridworld

    参考:https://orzyt.cn/posts/gridworld/

    Reinforcement Learning: An Introduction》在第三章中给出了一个简单的例子:Gridworld, 以帮助我们理解finite MDPs,

    同时也求解了该问题的贝尔曼期望方程贝尔曼最优方程. 本文简要说明如何进行编程求解.

    问题

    下图用一个矩形网格展示了一个简单finite MDP - Gridworld.
    网格中的每一格对应于environment的一个state.
    在每一格, 有四种可能的actions:上/下/左/右, 对应于agent往相应的方向移动一个单元格.
    使agent离开网格的actions会使得agent留在原来的位置, 但是会有一个值为-1的reward.
    除了那些使得agent离开state Astate B的action, 其他的actions对应的reward都是0.
    处在state A时, 所有的actions会有值为+10的reward, 并且agent会移动到state A'.
    处在state B时, 所有的actions会有值为+5的reward, 并且agent会移动到state B'.

    元素

    • 状态(State): 网格的坐标, 共 $5 imes 5 = 25$ 个状态;
    • 动作(Action): 上/下/左/右四种动作;
    • 策略(Policy): $pi(a | s) = frac14 ; ext{for} ; forall s in S, ext{and} ; forall ; a in {↑,↓,←,→}$;
    • 奖励(Reward): 如题所述;
    • 折扣因子(Discount rate): $gamma in [0, 1]$, 本文采用 $gamma=0.9$。

    目标

    • 使用贝尔曼期望方程, 求解给定随机策略 $pi(a | s) = frac14$ 下的状态值函数.
    • 使用贝尔曼最优方程, 求解最优状态值函数.

    实现

      1 import numpy as np
      2 
      3 %matplotlib inline
      4 import matplotlib
      5 import matplotlib.pyplot as plt
      6 from matplotlib.table import Table
      7 
      8 #定义grid问题中常用的变量
      9 # 格子尺寸
     10 WORLD_SIZE = 5
     11 # 状态A的位置(下标从0开始,下同)
     12 A_POS = [0, 1]
     13 # 状态A'的位置
     14 A_PRIME_POS = [4, 1]
     15 # 状态B的位置
     16 B_POS = [0, 3]
     17 # 状态B'的位置
     18 B_PRIME_POS = [2, 3]
     19 # 折扣因子
     20 DISCOUNT = 0.9
     21 
     22 # 动作集={上,下,左,右}
     23 ACTIONS = [np.array([-1, 0]),
     24            np.array([1, 0]),
     25            np.array([0, 1]),
     26            np.array([0, -1]),
     27 ]
     28 # 策略,每个动作等概率
     29 ACTION_PROB = 0.25
     30 
     31 
     32 #绘图相关函数
     33 def draw_image(image):
     34     fig, ax = plt.subplots()
     35     ax.set_axis_off()
     36     tb = Table(ax, bbox=[0, 0, 1, 1])
     37 
     38     nrows, ncols = image.shape
     39     width, height = 1.0 / ncols, 1.0 / nrows
     40 
     41     # 添加表格
     42     for (i,j), val in np.ndenumerate(image):
     43         tb.add_cell(i, j, width, height, text=val, 
     44                     loc='center', facecolor='white')
     45 
     46     # 行标签
     47     for i, label in enumerate(range(len(image))):
     48         tb.add_cell(i, -1, width, height, text=label+1, loc='right', 
     49                     edgecolor='none', facecolor='none')
     50     # 列标签
     51     for j, label in enumerate(range(len(image))):
     52         tb.add_cell(WORLD_SIZE, j, width, height/2, text=label+1, loc='center', 
     53                            edgecolor='none', facecolor='none')
     54     ax.add_table(tb)
     55 
     56 
     57 
     58 def step(state, action):
     59     '''给定当前状态以及采取的动作,返回后继状态及其立即奖励
     60     
     61     Parameters
     62     ----------
     63     state : list
     64         当前状态
     65     action : list
     66         采取的动作
     67     
     68     Returns
     69     -------
     70     tuple
     71         后继状态,立即奖励
     72         
     73     '''
     74     # 如果当前位置为状态A,则直接跳到状态A',奖励为+10
     75     if state == A_POS:
     76         return A_PRIME_POS, 10
     77     # 如果当前位置为状态B,则直接跳到状态B',奖励为+5
     78     if state == B_POS:
     79         return B_PRIME_POS, 5
     80 
     81     state = np.array(state)
     82     # 通过坐标运算得到后继状态
     83     next_state = (state + action).tolist()
     84     x, y = next_state
     85     # 根据后继状态的坐标判断是否出界
     86     if x < 0 or x >= WORLD_SIZE or y < 0 or y >= WORLD_SIZE:
     87         # 出界则待在原地,奖励为-1
     88         reward = -1.0
     89         next_state = state
     90     else:
     91         # 未出界则奖励为0
     92         reward = 0
     93     return next_state, reward
     94 
     95 
     96 a
     97 π(a|s)[r+γ
     98 v
     99 π
    100 (
    101 s
    102 103 )]
    104 vπ=∑aπ(a|s)[r+γvπ(s′)]
    105 In [5]:
    106 def bellman_equation():
    107     ''' 求解贝尔曼(期望)方程
    108     '''
    109     # 初始值函数
    110     value = np.zeros((WORLD_SIZE, WORLD_SIZE))
    111     while True:
    112         new_value = np.zeros(value.shape)
    113         # 遍历所有状态
    114         for i in range(0, WORLD_SIZE):
    115             for j in range(0, WORLD_SIZE):
    116                 # 遍历所有动作
    117                 for action in ACTIONS:
    118                     # 执行动作,转移到后继状态,并获得立即奖励
    119                     (next_i, next_j), reward = step([i, j], action)
    120                     # 贝尔曼期望方程
    121                     new_value[i, j] += ACTION_PROB * 
    122                     (reward + DISCOUNT * value[next_i, next_j])
    123         # 迭代终止条件: 误差小于1e-4
    124         if np.sum(np.abs(value - new_value)) < 1e-4:
    125             draw_image(np.round(new_value, decimals=2))
    126             plt.title('$v_{pi}$')
    127             plt.show()
    128             plt.close()
    129             break
    130         value = new_value
    131 
    132 def bellman_optimal_equation():
    133     '''求解贝尔曼最优方程
    134     '''
    135     # 初始值函数
    136     value = np.zeros((WORLD_SIZE, WORLD_SIZE))
    137     while True:
    138         new_value = np.zeros(value.shape)
    139         # 遍历所有状态
    140         for i in range(0, WORLD_SIZE):
    141             for j in range(0, WORLD_SIZE):
    142                 values = []
    143                 # 遍历所有动作
    144                 for action in ACTIONS:
    145                     # 执行动作,转移到后继状态,并获得立即奖励
    146                     (next_i, next_j), reward = step([i, j], action)
    147                     # 缓存动作值函数 q(s,a) = r + γ*v(s')
    148                     values.append(reward + DISCOUNT * value[next_i, next_j])
    149                 # 根据贝尔曼最优方程,找出最大的动作值函数 q(s,a) 进行更新
    150                 new_value[i, j] = np.max(values)
    151         # 迭代终止条件: 误差小于1e-4
    152         if np.sum(np.abs(new_value - value)) < 1e-4:
    153             draw_image(np.round(new_value, decimals=2))
    154             plt.title('$v_{*}$')
    155             plt.show()
    156             plt.close()
    157             break
    158         value = new_value
    159 
    160 
    161 bellman_equation()
    162 
    163 bellman_optimal_equation()
  • 相关阅读:
    基于Enterprise Library 6 的AOP实现
    命行下的查询与替换字符串
    软件架构中质量特性
    【redis】突然流量增大,不定时挂死排障记录
    Heritrix 3.1.0 源码解析(二)
    Apache Jackrabbit源码研究(四)
    Heritrix 3.1.0 源码解析(三)
    Apache Jackrabbit源码研究(五)
    Heritrix 3.1.0 源码解析(一)
    JVM 自定义的类加载器的实现和使用
  • 原文地址:https://www.cnblogs.com/feifanrensheng/p/13423006.html
Copyright © 2011-2022 走看看