zoukankan      html  css  js  c++  java
  • 增强学习--Q-learning

    Q-learning

    实例代码

     1 import numpy as np
     2 import random
     3 from environment import Env
     4 from collections import defaultdict
     5 
     6 class QLearningAgent:
     7     def __init__(self, actions):
     8         # actions = [0, 1, 2, 3]
     9         self.actions = actions
    10         self.learning_rate = 0.01
    11         self.discount_factor = 0.9
    12         self.epsilon = 0.1
    13         self.q_table = defaultdict(lambda: [0.0, 0.0, 0.0, 0.0])#待更新q表
    14 
    15     # update q function with sample <s, a, r, s'>
    16     def learn(self, state, action, reward, next_state):
    17         current_q = self.q_table[state][action]
    18         # using Bellman Optimality Equation to update q function
    19         new_q = reward + self.discount_factor * max(self.q_table[next_state])
    20         self.q_table[state][action] += self.learning_rate * (new_q - current_q)#更新公式,off-policy
    21 
    22     # get action for the state according to the q function table
    23     # agent pick action of epsilon-greedy policy
    24     def get_action(self, state):
    25         #epsilon-greedy policy
    26         if np.random.rand() < self.epsilon:
    27             # take random action
    28             action = np.random.choice(self.actions)
    29         else:
    30             # take action according to the q function table
    31             state_action = self.q_table[state]
    32             action = self.arg_max(state_action)
    33         return action
    34 
    35     @staticmethod
    36     def arg_max(state_action):
    37         max_index_list = []
    38         max_value = state_action[0]
    39         for index, value in enumerate(state_action):
    40             if value > max_value:
    41                 max_index_list.clear()
    42                 max_value = value
    43                 max_index_list.append(index)
    44             elif value == max_value:
    45                 max_index_list.append(index)
    46         return random.choice(max_index_list)
    47 
    48 if __name__ == "__main__":
    49     env = Env()
    50     agent = QLearningAgent(actions=list(range(env.n_actions)))
    51 
    52     for episode in range(1000):
    53         state = env.reset()
    54 
    55         while True:
    56             env.render()
    57 
    58             # take action and proceed one step in the environment
    59             action = agent.get_action(str(state))
    60             next_state, reward, done = env.step(action)
    61 
    62             # with sample <s,a,r,s'>, agent learns new q function
    63             agent.learn(str(state), action, reward, str(next_state))
    64 
    65             state = next_state
    66             env.print_value_all(agent.q_table)
    67 
    68             # if episode ends, then break
    69             if done:
    70                 break
  • 相关阅读:
    专业实训项目需求分析
    2015年秋季个人阅读计划
    场景调研
    二维数组最大连通子数组
    单元测试
    《大道至简——软件工程实践者的思想》阅读笔记之三
    《大道至简——软件工程实践者的思想》阅读笔记之二
    人机交互-输入法使用评价
    第一阶段个人总结10
    第一阶段个人总结09
  • 原文地址:https://www.cnblogs.com/buyizhiyou/p/10250121.html
Copyright © 2011-2022 走看看