Sarsa算法
实例代码
1 import numpy as np
2 import random
3 from collections import defaultdict
4 from environment import Env
5
6
7 # SARSA agent learns every time step from the sample <s, a, r, s', a'>
class SARSAgent:
    """Tabular SARSA agent that learns from samples <s, a, r, s', a'>.

    The agent keeps an action-value table (Q-table) mapping a hashable state
    key to a list of Q-values, one entry per action. This is the table that
    gets updated on every step (in contrast to Monte Carlo prediction, which
    updates a state-value table). Actions are used directly as indices into
    a state's Q-value list, so they are expected to be ``0 .. len(actions)-1``.
    """

    def __init__(self, actions, *, learning_rate=0.01, discount_factor=0.9,
                 epsilon=0.1):
        """Create an agent with an all-zero Q-table.

        Args:
            actions: Sequence of available actions; each action doubles as an
                index into a state's Q-value list.
            learning_rate: Step size applied to the TD error in ``learn``.
            discount_factor: Weight of the next state-action value.
            epsilon: Probability of picking a random action in ``get_action``.
        """
        self.actions = actions
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        # Size each default row from the action list instead of hard-coding
        # four entries, so agents with a different number of actions work.
        n_actions = len(actions)
        self.q_table = defaultdict(lambda: [0.0] * n_actions)

    def learn(self, state, action, reward, next_state, next_action):
        """Update Q(state, action) from one sample <s, a, r, s', a'>.

        Applies the SARSA rule
        ``Q(s,a) <- Q(s,a) + lr * (r + gamma * Q(s',a') - Q(s,a))``.
        Unseen states are created in the table with all-zero values.
        """
        current_q = self.q_table[state][action]
        next_state_q = self.q_table[next_state][next_action]
        td_target = reward + self.discount_factor * next_state_q
        self.q_table[state][action] = (
            current_q + self.learning_rate * (td_target - current_q)
        )

    def get_action(self, state):
        """Return an action for ``state`` using an epsilon-greedy policy.

        With probability ``epsilon`` a random action is drawn from
        ``self.actions`` (exploration); otherwise one of the actions with the
        highest Q-value for ``state`` is returned (exploitation).
        """
        if np.random.rand() < self.epsilon:
            # Explore: take a random action.
            return np.random.choice(self.actions)
        # Exploit: act greedily with respect to the Q-table.
        return self.arg_max(self.q_table[state])

    @staticmethod
    def arg_max(state_action):
        """Return the index of a maximal value, breaking ties at random.

        Args:
            state_action: Non-empty sequence of Q-values for one state.
        """
        max_value = max(state_action)
        max_index_list = [
            index for index, value in enumerate(state_action)
            if value == max_value
        ]
        return random.choice(max_index_list)
49
if __name__ == "__main__":
    # Train a SARSA agent on the environment for a fixed number of episodes.
    env = Env()
    agent = SARSAgent(actions=list(range(env.n_actions)))

    for episode in range(1000):
        # Start a new episode and pick the first action for the initial state.
        state = env.reset()
        action = agent.get_action(str(state))

        done = False
        while not done:
            env.render()

            # Advance the environment one step, then choose the follow-up
            # action so the full sample <s, a, r, s', a'> is available.
            next_state, reward, done = env.step(action)
            next_action = agent.get_action(str(next_state))

            # Update the Q-table; states are keyed by their string form.
            agent.learn(str(state), action, reward, str(next_state), next_action)

            state, action = next_state, next_action

            # Show the current Q-values; this also runs on the final step of
            # an episode, before the loop condition ends it.
            env.print_value_all(agent.q_table)