zoukankan      html  css  js  c++  java
  • 增强学习--策略迭代

    策略迭代

    实例代码

     1 class PolicyIteration:
     2     def __init__(self, env):
     3         self.env = env
     4         # 2-d list for the value function
     5         self.value_table = [[0.0] * env.width for _ in range(env.height)]#值函数表
     6         # list of random policy (same probability of up, down, left, right)
     7         self.policy_table = [[[0.25, 0.25, 0.25, 0.25]] * env.width
     8                                     for _ in range(env.height)]#每一状态的动作策略表,一开始向四方运动是相同概率的
     9         # setting terminal state
    10         self.policy_table[2][2] = []#吸收态,终止
    11         self.discount_factor = 0.9
    12 
    13     def policy_evaluation(self):#策略估计
    14         next_value_table = [[0.00] * self.env.width
    15                                     for _ in range(self.env.height)]
    16 
    17         # Bellman Expectation Equation for the every states
    18         for state in self.env.get_all_states():
    19             value = 0.0
    20             # keep the value function of terminal states as 0(吸收态赋0)
    21             if state == [2, 2]:
    22                 next_value_table[state[0]][state[1]] = value
    23                 continue
    24 
    25             for action in self.env.possible_actions:#计算所有可能动作
    26                 next_state = self.env.state_after_action(state, action)
    27                 reward = self.env.get_reward(state, action)
    28                 next_value = self.get_value(next_state)
    29                 value += (self.get_policy(state)[action] *
    30                           (reward + self.discount_factor * next_value))
    31 
    32             next_value_table[state[0]][state[1]] = round(value, 2)
    33 
    34         self.value_table = next_value_table
    35 
    36     def policy_improvement(self):#策略改进
    37         next_policy = self.policy_table
    38         for state in self.env.get_all_states():
    39             if state == [2, 2]:
    40                 continue
    41             value = -99999
    42             max_index = []
    43             result = [0.0, 0.0, 0.0, 0.0]  # initialize the policy
    44 
    45             # for every actions, calculate 计算所有可能动作,保留取得最大值函数的动作
    46             # [reward + (discount factor) * (next state value function)]
    47             for index, action in enumerate(self.env.possible_actions):
    48                 next_state = self.env.state_after_action(state, action)
    49                 reward = self.env.get_reward(state, action)
    50                 next_value = self.get_value(next_state)
    51                 temp = reward + self.discount_factor * next_value
    52 
    53                 # We normally can't pick multiple actions in greedy policy.
    54                 # but here we allow multiple actions with same max values 允许多个取最大值函数的动作存在
    55                 if temp == value:
    56                     max_index.append(index)
    57                 elif temp > value:
    58                     value = temp
    59                     max_index.clear()
    60                     max_index.append(index)
    61 
    62             # probability of action
    63             prob = 1 / len(max_index)
    64 
    65             for index in max_index:
    66                 result[index] = prob
    67 
    68             next_policy[state[0]][state[1]] = result#更新策略表
    69 
    70         self.policy_table = next_policy
    71 
    72     # get action according to the current policy
    73     def get_action(self, state):
    74         random_pick = random.randrange(100) / 100
    75 
    76         policy = self.get_policy(state)
    77         policy_sum = 0.0
    78         # return the action in the index
    79         for index, value in enumerate(policy):
    80             policy_sum += value
    81             if random_pick < policy_sum:
    82                 return index
    83 
    84     # get policy of specific state
    85     def get_policy(self, state):
    86         if state == [2, 2]:
    87             return 0.0
    88         return self.policy_table[state[0]][state[1]]
    89 
    90     def get_value(self, state):
    91         return round(self.value_table[state[0]][state[1]], 2)
  • 相关阅读:
    work_27_一次springBoot+orcal+Mabits PageHele的使用
    work_26_swagger2整合springBoot和使用
    work_25_docker--RabbitMq消息队列
    work_24_MYSQL从create table... 到分库分表
    work_23_常用的工具类
    work_22_MySQL分库分表的初识
    work_21_AtomicInteger API
    work_20_stream的使用
    MySQL 基础语句的练习2
    MySQL 基础语句的练习
  • 原文地址:https://www.cnblogs.com/buyizhiyou/p/10250082.html
Copyright © 2011-2022 走看看