zoukankan      html  css  js  c++  java
  • 增强学习--Sarsa算法

    Sarsa算法

    实例代码

     1 import numpy as np
     2 import random
     3 from collections import defaultdict
     4 from environment import Env
     5 
     6 
     7 # SARSA agent learns every time step from the sample <s, a, r, s', a'>
     class SARSAgent:
         """Tabular SARSA agent with an epsilon-greedy behaviour policy.

         The agent keeps an action-value table (Q-table) keyed by state and
         updates it at every time step from the sample <s, a, r, s', a'>.
         Unlike a Monte Carlo method that updates a state-value (V) table,
         the Q-table is the quantity being learned here.
         """

         def __init__(self, actions, learning_rate=0.01, discount_factor=0.9,
                      epsilon=0.1):
             """Create an agent for the given action set.

             Args:
                 actions: sequence of action identifiers; Q-table rows are
                     indexed by action, so these are expected to be the
                     integers 0..len(actions)-1.
                 learning_rate: step size applied to the TD error.
                 discount_factor: weight of the next state-action value.
                 epsilon: probability of taking a random exploratory action.
             """
             self.actions = actions
             self.learning_rate = learning_rate
             self.discount_factor = discount_factor
             self.epsilon = epsilon
             # One zero-initialised row per unseen state, sized to the actual
             # number of actions instead of a hard-coded four.
             n_actions = len(actions)
             self.q_table = defaultdict(lambda: [0.0] * n_actions)

         def learn(self, state, action, reward, next_state, next_action):
             """Update Q(state, action) from one <s, a, r, s', a'> sample.

             Applies the SARSA rule
             Q(s, a) <- Q(s, a) + lr * (r + gamma * Q(s', a') - Q(s, a)).
             """
             current_q = self.q_table[state][action]
             next_state_q = self.q_table[next_state][next_action]
             td_target = reward + self.discount_factor * next_state_q
             self.q_table[state][action] = (
                 current_q + self.learning_rate * (td_target - current_q)
             )

         def get_action(self, state):
             """Return the next action for ``state`` (epsilon-greedy).

             With probability ``epsilon`` a random action is drawn from
             ``self.actions``; otherwise a greedy action with respect to the
             Q-table is returned, ties broken at random.
             """
             if np.random.rand() < self.epsilon:
                 # Exploration: take a random action.
                 return np.random.choice(self.actions)
             # Exploitation: act greedily on the current Q estimates.
             return self.arg_max(self.q_table[state])

         @staticmethod
         def arg_max(state_action):
             """Return the index of a maximal value, chosen at random on ties.

             Args:
                 state_action: non-empty sequence of action values.

             Raises:
                 ValueError: if ``state_action`` is empty.
             """
             max_value = max(state_action)
             max_index_list = [
                 index for index, value in enumerate(state_action)
                 if value == max_value
             ]
             return random.choice(max_index_list)
    49 
    if __name__ == "__main__":
        # Training driver: run 1000 episodes of on-policy SARSA in Env.
        env = Env()
        agent = SARSAgent(actions=list(range(env.n_actions)))

        for _ in range(1000):
            # Start a fresh episode and pick the first action for its
            # initial state. States are stringified for use as Q-table keys.
            state = env.reset()
            action = agent.get_action(str(state))

            done = False
            while not done:
                env.render()

                # Advance the environment by one step with the chosen action,
                # then choose the follow-up action a' for the resulting state.
                next_state, reward, done = env.step(action)
                next_action = agent.get_action(str(next_state))

                # Learn from the complete sample <s, a, r, s', a'>.
                agent.learn(str(state), action, reward,
                            str(next_state), next_action)

                # Shift the (state, action) pair forward for the next step.
                state, action = next_state, next_action

                # Hand the current Q-table to the environment's
                # print_value_all after every step, including the last one.
                env.print_value_all(agent.q_table)
  • 相关阅读:
    数组中的趣味题二
    数组中的趣味题一
    归并排序
    堆内存与栈内存
    c++中的继承和组合
    直接插入排序
    NYOJ 1067 Compress String(区间dp)
    C++ Primer 学习笔记与思考_7 void和void*指针的使用方法
    ucgui界面设计演示样例2
    手机无法连接电脑的手机助手
  • 原文地址:https://www.cnblogs.com/buyizhiyou/p/10250114.html
Copyright © 2011-2022 走看看