Value Iteration
Example Code
![](https://img2018.cnblogs.com/blog/756329/201901/756329-20190110161323908-829565007.png)
![](https://img2018.cnblogs.com/blog/756329/201905/756329-20190523173606099-439896387.png)
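For reference, the update the code below performs is the Bellman optimality backup, written here for the deterministic transitions of this grid world:

$$v_{k+1}(s) = \max_{a \in A}\left[\, r(s, a) + \gamma\, v_k(s') \,\right]$$

where $s'$ is the state reached from $s$ by taking action $a$ and $\gamma = 0.9$ is the discount factor. Taking the max over actions on every sweep is what distinguishes value iteration from policy evaluation, which averages over the current policy's actions instead.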
```python
class ValueIteration:
    def __init__(self, env):
        self.env = env
        # 2-d list for the value function
        self.value_table = [[0.0] * env.width for _ in range(env.height)]
        self.discount_factor = 0.9

    # get the next value function table from the current value function table
    def value_iteration(self):
        next_value_table = [[0.0] * self.env.width
                            for _ in range(self.env.height)]
        for state in self.env.get_all_states():
            # the terminal state [2, 2] keeps a value of 0
            if state == [2, 2]:
                next_value_table[state[0]][state[1]] = 0.0
                continue
            value_list = []

            for action in self.env.possible_actions:
                next_state = self.env.state_after_action(state, action)
                reward = self.env.get_reward(state, action)
                next_value = self.get_value(next_state)
                value_list.append(reward + self.discount_factor * next_value)
            # take the maximum over actions (the Bellman optimality
            # equation!): each sweep updates the value table with the
            # return of the best action
            next_value_table[state[0]][state[1]] = round(max(value_list), 2)
        self.value_table = next_value_table

    # get the greedy action(s) according to the current value function table
    def get_action(self, state):
        action_list = []
        max_value = float('-inf')

        if state == [2, 2]:
            return []

        # compute the q-value of every action and keep the action(s)
        # whose q-value is maximal (ties yield several actions)
        for action in self.env.possible_actions:
            next_state = self.env.state_after_action(state, action)
            reward = self.env.get_reward(state, action)
            next_value = self.get_value(next_state)
            value = reward + self.discount_factor * next_value

            if value > max_value:
                action_list.clear()
                action_list.append(action)
                max_value = value
            elif value == max_value:
                action_list.append(action)

        return action_list

    def get_value(self, state):
        return round(self.value_table[state[0]][state[1]], 2)
```
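A minimal sketch of how the class might be driven. `GridWorldEnv` below is a hypothetical stand-in, not the original environment: it only implements the interface the calls above assume (`width`, `height`, `get_all_states()`, `possible_actions`, `state_after_action()`, `get_reward()`), with an assumed reward of 1 for stepping onto the goal at `[2, 2]` and 0 otherwise.

```python
# Hypothetical environment matching the interface ValueIteration expects.
class GridWorldEnv:
    def __init__(self, width=5, height=5):
        self.width = width
        self.height = height
        self.possible_actions = [0, 1, 2, 3]            # up, down, left, right
        self._moves = [(-1, 0), (1, 0), (0, -1), (0, 1)]

    def get_all_states(self):
        return [[r, c] for r in range(self.height) for c in range(self.width)]

    def state_after_action(self, state, action):
        dr, dc = self._moves[action]
        r = min(max(state[0] + dr, 0), self.height - 1)  # clamp to the grid
        c = min(max(state[1] + dc, 0), self.width - 1)
        return [r, c]

    def get_reward(self, state, action):
        # assumed reward scheme: 1 for reaching the goal at [2, 2], else 0
        return 1.0 if self.state_after_action(state, action) == [2, 2] else 0.0


if __name__ == "__main__":
    env = GridWorldEnv()
    agent = ValueIteration(env)
    for _ in range(50):              # sweep until the table stops changing
        agent.value_iteration()
    print(agent.get_action([2, 1]))  # greedy action(s) one step from the goal
```

Because the table entries are rounded to two decimals, convergence is easy to spot: after enough sweeps `value_iteration()` leaves the table unchanged, and `get_action` then reads the greedy policy directly off the converged values.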