  • [RL Study Notes][#3] Automatically learning the optimal policy for grid_mdp

    This post modifies the policy_iteration.py program so that it can run against the environment from [#1] and find the optimal action for each state.
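
    The program is standard policy iteration, which alternates two steps (the notation here is mine, written for the deterministic transitions of this grid world, and is not from the original post):

        policy evaluation :  v(s) <- r(s, pi(s)) + gamma * v(s')             # sweep all non-terminal states until v converges
        policy improvement:  pi(s) <- argmax over a of [ r(s, a) + gamma * v(s') ]

    where s' is the successor state reached from s by the given action. The code below performs exactly these two sweeps, using setAction() / _step() to query the environment's transitions.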

    #!/usr/bin/env python
    import gym
    #from grid_mdp import Grid_Mdp


    class Policy_Value:
        def __init__(self, grid_mdp):
            # state-value function v, indexed by state id (states are numbered from 1)
            self.v = [0.0 for i in range(len(grid_mdp.env.states) + 1)]

            # initial policy pi: the first available action for every non-terminal state
            self.pi = dict()
            for state in grid_mdp.env.states:
                if state in grid_mdp.env.terminate_states: continue
                self.pi[state] = grid_mdp.env.action_s[0]

        def policy_improve(self, grid_mdp):
            # greedy policy improvement: for each state keep the action with the
            # largest one-step backup value r + gamma * v(s')
            for state in grid_mdp.env.states:
                grid_mdp.env.setAction(state)  # reset the environment to this state
                if state in grid_mdp.env.terminate_states: continue

                a1 = grid_mdp.env.action_s[0]
                s, r, t, z = grid_mdp.env._step(a1)
                v1 = r + grid_mdp.env.gamma * self.v[s]

                for action in grid_mdp.env.action_s:
                    grid_mdp.env.setAction(state)  # _step() moves the environment, so restore the state first
                    s, r, t, z = grid_mdp.env._step(action)
                    if v1 < r + grid_mdp.env.gamma * self.v[s]:  # this action has a better value, so keep it
                        a1 = action
                        v1 = r + grid_mdp.env.gamma * self.v[s]

                self.pi[state] = a1  # record the best (greedy) action for this state

        def policy_evaluate(self, grid_mdp):
            # iterative policy evaluation: sweep every state and apply the Bellman
            # backup for the current policy until the value function stops changing
            for i in range(1000):
                delta = 0.0
                for state in grid_mdp.env.states:
                    grid_mdp.env.setAction(state)  # reset the environment to this state
                    if state in grid_mdp.env.terminate_states: continue
                    action = self.pi[state]

                    s, r, t, z = grid_mdp.env.step(action)
                    new_v = r + grid_mdp.env.gamma * self.v[s]
                    delta += abs(self.v[state] - new_v)
                    self.v[state] = new_v

                if delta < 1e-6:
                    break

        def policy_iterate(self, grid_mdp):
            # alternate evaluation and improvement until the policy is stable
            for i in range(100):
                self.policy_evaluate(grid_mdp)
                self.policy_improve(grid_mdp)


    if __name__ == "__main__":
        #grid_mdp = Grid_Mdp()
        env = gym.make('GridWorld-v0')

        policy_value = Policy_Value(env)
        policy_value.policy_iterate(env)

        print("value:")
        for i in range(1, 6):
            print("%d:%f\t" % (i, policy_value.v[i]))
        print("")

        print("policy:")
        for i in range(1, 6):
            print("%d->%s\t" % (i, policy_value.pi[i]))
        print("")
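
    Note that GridWorld-v0 is not one of gym's built-in environments, so it has to be registered before gym.make() can find it. A minimal registration sketch is shown below; the id matches the gym.make('GridWorld-v0') call above and the entry point matches the class path printed in the warning message further down, but where exactly the registration lives (gym/envs/__init__.py versus your own script) is an assumption on my part:

        # Hypothetical registration for the custom grid-world environment; run this
        # (or add it to gym/envs/__init__.py) before calling gym.make('GridWorld-v0').
        from gym.envs.registration import register

        register(
            id='GridWorld-v0',
            entry_point='gym.envs.classic_control.grid_mdp:GridEnv',
        )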

    The execution result is as follows:

    -----------------------------------------------------------------------------------------------------------------------------------------------------

    /home/lsa-dla/anaconda3/envs/tensorflow/bin/python /home/lsa-dla/PycharmProjects/grid_mdp/lsa_test2.py
    WARN: Environment '<class 'gym.envs.classic_control.grid_mdp.GridEnv'>' has deprecated methods. Compatibility code invoked.
    value:
    1:0.640000
    2:0.800000
    3:1.000000
    4:0.800000
    5:0.640000

    policy:
    1->e
    2->e
    3->s
    4->w
    5->w


    Process finished with exit code 0

     ------------------------------------------------------------------------------------------------------------------------------------------------------
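
    As a quick sanity check (assuming the usual grid_mdp setup: gamma = 0.8, reward +1 for moving south from state 3 into the goal, reward 0 for moves between states 1-5, and value 0 for terminal states), the printed values follow directly from the Bellman equation under the printed policy:

        v(3) = 1 + 0.8 * 0    = 1.00    # 3 -> s, straight into the goal
        v(2) = 0 + 0.8 * v(3) = 0.80    # 2 -> e
        v(1) = 0 + 0.8 * v(2) = 0.64    # 1 -> e

    and, by symmetry, v(4) = 0.80 and v(5) = 0.64 for moving west towards state 3.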

    reference:

    [1]  Reinforcement_Learning_Blog/2.强化学习系列之二:模型相关的强化学习/
