Reinforcement Learning -- Monte Carlo Methods

Monte Carlo Methods

Example Code

The code below implements constant-α Monte Carlo; here is a brief introduction.
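In constant-α Monte Carlo, the value estimate of each visited state is moved toward the sampled discounted return by a fixed step size α at the end of every episode:

    V(s) <- V(s) + α * (G_t - V(s))

where G_t is the return observed from state s to the end of the episode and α is the learning rate (0.01 in the code below). Because α is a constant rather than 1/N(s), recent episodes are weighted more heavily than older ones.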

import numpy as np
import random
from collections import defaultdict
from environment import Env


# Monte Carlo agent which learns from the collected samples at the end of every episode
class MCAgent:
    def __init__(self, actions):
        self.width = 5
        self.height = 5
        self.actions = actions
        self.learning_rate = 0.01
        self.discount_factor = 0.9
        self.epsilon = 0.1
        self.samples = []
        self.value_table = defaultdict(float)  # value-function table, initialized to 0

    # append a sample (state, reward, done) to memory
    def save_sample(self, state, reward, done):
        self.samples.append([state, reward, done])

    # at the end of every episode, the agent updates the values of the visited states
    def update(self):
        G_t = 0
        visit_state = []
        for sample in reversed(self.samples):  # walk the episode backwards
            state = str(sample[0])
            G_t = self.discount_factor * (sample[1] + G_t)  # accumulate the discounted return
            if state not in visit_state:  # update each state only once per episode (first-visit MC)
                visit_state.append(state)
                value = self.value_table[state]
                # constant-α Monte Carlo update: V(s) <- V(s) + α * (G_t - V(s))
                self.value_table[state] = (value +
                                           self.learning_rate * (G_t - value))

    # choose an action for the given state using the value table
    # the agent follows an epsilon-greedy policy
    def get_action(self, state):
        if np.random.rand() < self.epsilon:  # with probability epsilon, explore
            # take a random action
            action = np.random.choice(self.actions)
        else:
            # otherwise act greedily with respect to the value table
            next_state = self.possible_next_state(state)
            action = self.arg_max(next_state)
        return int(action)

    # compute arg max; if multiple candidates exist, pick one randomly
    @staticmethod
    def arg_max(next_state):
        max_index_list = []
        max_value = next_state[0]
        for index, value in enumerate(next_state):
            if value > max_value:
                max_index_list.clear()
                max_value = value
                max_index_list.append(index)
            elif value == max_value:
                max_index_list.append(index)
        return random.choice(max_index_list)

    # get the values of the possible next states (up, down, left, right)
    def possible_next_state(self, state):
        col, row = state
        next_state = [0.0] * 4  # value of the neighboring state in each of the four directions

        if row != 0:
            next_state[0] = self.value_table[str([col, row - 1])]
        else:
            next_state[0] = self.value_table[str(state)]

        if row != self.height - 1:
            next_state[1] = self.value_table[str([col, row + 1])]
        else:
            next_state[1] = self.value_table[str(state)]

        if col != 0:
            next_state[2] = self.value_table[str([col - 1, row])]
        else:
            next_state[2] = self.value_table[str(state)]

        if col != self.width - 1:
            next_state[3] = self.value_table[str([col + 1, row])]
        else:
            next_state[3] = self.value_table[str(state)]

        return next_state


# main loop
if __name__ == "__main__":
    env = Env()
    agent = MCAgent(actions=list(range(env.n_actions)))

    for episode in range(1000):  # episodic task
        # import pdb; pdb.set_trace()  # uncomment to step through an episode
        state = env.reset()
        action = agent.get_action(state)

        while True:
            env.render()

            # advance to the next state; reward is a number and done is a boolean
            next_state, reward, done = env.step(action)
            agent.save_sample(next_state, reward, done)

            # get the next action
            action = agent.get_action(next_state)

            # at the end of each episode, update the value table
            if done:
                print("episode : ", episode)
                agent.update()
                agent.samples.clear()
                break
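
The Env class comes from the accompanying environment.py (the 5x5 grid world used by the main loop) and is not shown in this post. To see what update() does in isolation, the agent can be fed a hand-made episode and updated without the environment. This is only a minimal sketch: it assumes the MCAgent class above is already defined in the current file or session, and the states and rewards below are made up for illustration.

# Minimal sketch (assumption: MCAgent from the listing above is already defined;
# the episode below is fabricated purely for illustration).
agent = MCAgent(actions=[0, 1, 2, 3])

# One short fake episode; each entry mirrors what save_sample() stores: [state, reward, done]
fake_episode = [
    [[0, 0], 0, False],
    [[0, 1], 0, False],
    [[1, 1], 1, True],   # rewarding terminal state
]
for state, reward, done in fake_episode:
    agent.save_sample(state, reward, done)

agent.update()           # backward pass over the episode with constant-α updates
agent.samples.clear()    # clear memory before the next episode

for state, value in agent.value_table.items():
    print(state, round(value, 5))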
Original article: https://www.cnblogs.com/buyizhiyou/p/10250103.html