zoukankan      html  css  js  c++  java
  • 增强学习--Q-learning

    Q-learning

    实例代码

     1 import numpy as np
     2 import random
     3 from environment import Env
     4 from collections import defaultdict
     5 
     6 class QLearningAgent:
     7     def __init__(self, actions):
     8         # actions = [0, 1, 2, 3]
     9         self.actions = actions
    10         self.learning_rate = 0.01
    11         self.discount_factor = 0.9
    12         self.epsilon = 0.1
    13         self.q_table = defaultdict(lambda: [0.0, 0.0, 0.0, 0.0])#待更新q表
    14 
    15     # update q function with sample <s, a, r, s'>
    16     def learn(self, state, action, reward, next_state):
    17         current_q = self.q_table[state][action]
    18         # using Bellman Optimality Equation to update q function
    19         new_q = reward + self.discount_factor * max(self.q_table[next_state])
    20         self.q_table[state][action] += self.learning_rate * (new_q - current_q)#更新公式,off-policy
    21 
    22     # get action for the state according to the q function table
    23     # agent pick action of epsilon-greedy policy
    24     def get_action(self, state):
    25         #epsilon-greedy policy
    26         if np.random.rand() < self.epsilon:
    27             # take random action
    28             action = np.random.choice(self.actions)
    29         else:
    30             # take action according to the q function table
    31             state_action = self.q_table[state]
    32             action = self.arg_max(state_action)
    33         return action
    34 
    35     @staticmethod
    36     def arg_max(state_action):
    37         max_index_list = []
    38         max_value = state_action[0]
    39         for index, value in enumerate(state_action):
    40             if value > max_value:
    41                 max_index_list.clear()
    42                 max_value = value
    43                 max_index_list.append(index)
    44             elif value == max_value:
    45                 max_index_list.append(index)
    46         return random.choice(max_index_list)
    47 
    48 if __name__ == "__main__":
    49     env = Env()
    50     agent = QLearningAgent(actions=list(range(env.n_actions)))
    51 
    52     for episode in range(1000):
    53         state = env.reset()
    54 
    55         while True:
    56             env.render()
    57 
    58             # take action and proceed one step in the environment
    59             action = agent.get_action(str(state))
    60             next_state, reward, done = env.step(action)
    61 
    62             # with sample <s,a,r,s'>, agent learns new q function
    63             agent.learn(str(state), action, reward, str(next_state))
    64 
    65             state = next_state
    66             env.print_value_all(agent.q_table)
    67 
    68             # if episode ends, then break
    69             if done:
    70                 break
  • 相关阅读:
    拷贝目录下文件,但某种类型文件例外
    编译个性化的openwrt固件
    -exec和|xargs
    OpenMP多线程linux下的使用,简单化
    clock_gettime的使用,计时比clock()精确
    openvswitch安装和使用 --修订通用教程的一些错误
    树莓派配置AP模式
    win7下的mstsc ubuntu下的rdesktop
    微信小程序-商品列表左=>右联动
    Vue.js最佳实践(五招让你成为Vue.js大师)
  • 原文地址:https://www.cnblogs.com/buyizhiyou/p/10250121.html
Copyright © 2011-2022 走看看