Deep Reinforcement Learning -- Actor-Critic

Actor-Critic combines the value-based and policy-based families of methods: an actor network learns the policy π(a|s) directly, while a critic network learns a state-value estimate V(s) that serves as a baseline for the actor's policy-gradient updates and lowers their variance.
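In standard one-step (TD(0)) notation, a textbook summary rather than anything stated in the original post, the two updates the code below performs are

    y = r + \gamma V(s')                          critic regression target
    A(s, a) = r + \gamma V(s') - V(s)             advantage fed to the actor
    L_{actor} = -A(s, a) \log \pi(a \mid s)       policy-gradient loss
    L_{critic} = (y - V(s))^2                     mean-squared TD error

At a terminal state the bootstrap term \gamma V(s') is dropped, which is exactly the "if done:" branch of train_model in the listing.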

Example code

import os  # added: used below to create the save directories
import sys
import gym
import pylab  # matplotlib's pylab interface, used for the score plot
import numpy as np
from keras.layers import Dense
from keras.models import Sequential
from keras.optimizers import Adam

EPISODES = 1000


# A2C (Advantage Actor-Critic) agent for CartPole
# the actor-critic algorithm combines value-based and policy-based methods
class A2CAgent:
    def __init__(self, state_size, action_size):
        # if you want to watch CartPole learning, leave this True
        self.render = True
        self.load_model = False
        # sizes of the state and action spaces
        self.state_size = state_size
        self.action_size = action_size
        self.value_size = 1

        # hyperparameters for the policy-gradient updates
        self.discount_factor = 0.99
        self.actor_lr = 0.001
        self.critic_lr = 0.005

        # create the policy (actor) and value (critic) networks
        self.actor = self.build_actor()
        self.critic = self.build_critic()

        if self.load_model:
            self.actor.load_weights("./save_model/cartpole_actor.h5")
            self.critic.load_weights("./save_model/cartpole_critic.h5")

    # approximate the policy and the value function with neural networks
    # actor: state is the input, a probability for each action is the output
    def build_actor(self):
        actor = Sequential()
        actor.add(Dense(24, input_dim=self.state_size, activation='relu',
                        kernel_initializer='he_uniform'))
        actor.add(Dense(self.action_size, activation='softmax',
                        kernel_initializer='he_uniform'))
        actor.summary()
        # see the note regarding crossentropy in cartpole_reinforce.py
        actor.compile(loss='categorical_crossentropy',
                      optimizer=Adam(lr=self.actor_lr))
        return actor

    # critic: state is the input, the state value V(s) is the output
    # (a single scalar baseline, not a per-action Q value)
    def build_critic(self):
        critic = Sequential()
        critic.add(Dense(24, input_dim=self.state_size, activation='relu',
                         kernel_initializer='he_uniform'))
        critic.add(Dense(self.value_size, activation='linear',
                         kernel_initializer='he_uniform'))
        critic.summary()
        critic.compile(loss="mse", optimizer=Adam(lr=self.critic_lr))
        return critic

    # using the output of the policy network, pick an action stochastically
    def get_action(self, state):
        # ask the actor for the action distribution and sample from it
        policy = self.actor.predict(state, batch_size=1).flatten()
        return np.random.choice(self.action_size, 1, p=policy)[0]

    # update both networks after every single environment step
    def train_model(self, state, action, reward, next_state, done):
        target = np.zeros((1, self.value_size))       # shape (1, 1)
        advantages = np.zeros((1, self.action_size))  # shape (1, action_size)

        value = self.critic.predict(state)[0][0]            # V(s) from the critic (indexed down to a scalar)
        next_value = self.critic.predict(next_state)[0][0]  # V(s') from the critic

        # key step: the TD target trains the critic and the advantage trains
        # the actor; at a terminal state the bootstrap term is dropped
        if done:
            advantages[0][action] = reward - value
            target[0][0] = reward
        else:
            advantages[0][action] = reward + self.discount_factor * next_value - value
            target[0][0] = reward + self.discount_factor * next_value

        # with categorical crossentropy, placing the advantage in the chosen
        # action's slot makes the loss -A * log(pi(a|s)), i.e. the
        # policy-gradient update weighted by the advantage
        self.actor.fit(state, advantages, epochs=1, verbose=0)
        self.critic.fit(state, target, epochs=1, verbose=0)


if __name__ == "__main__":
    # in CartPole-v1 the maximum length of an episode is 500 steps
    env = gym.make('CartPole-v1')
    # get the sizes of state and action from the environment
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    # added: make sure the output directories exist before saving to them
    os.makedirs("./save_model", exist_ok=True)
    os.makedirs("./save_graph", exist_ok=True)

    # make the A2C agent
    agent = A2CAgent(state_size, action_size)
    scores, episodes = [], []

    for e in range(EPISODES):
        done = False
        score = 0
        state = env.reset()
        state = np.reshape(state, [1, state_size])

        while not done:
            if agent.render:
                env.render()

            action = agent.get_action(state)
            next_state, reward, done, info = env.step(action)
            next_state = np.reshape(next_state, [1, state_size])
            # if an action ends the episode before the 500-step cap,
            # give a penalty of -100
            reward = reward if not done or score == 499 else -100

            # train once after every environment step
            agent.train_model(state, action, reward, next_state, done)

            score += reward
            state = next_state

            if done:
                # every episode, plot the play time
                # (undo the -100 penalty so the plot shows the true score)
                score = score if score == 500.0 else score + 100
                scores.append(score)
                episodes.append(e)
                pylab.plot(episodes, scores, 'b')
                pylab.savefig("./save_graph/cartpole_a2c.png")
                print("episode:", e, "  score:", score)

                # if the mean score of the last 10 episodes is above 490,
                # stop training
                if np.mean(scores[-min(10, len(scores)):]) > 490:
                    sys.exit()

        # save the model weights every 50 episodes
        if e % 50 == 0:
            agent.actor.save_weights("./save_model/cartpole_actor.h5")
            agent.critic.save_weights("./save_model/cartpole_critic.h5")
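A note on running this listing today: it targets the old standalone keras package and the pre-0.26 gym API. In gym >= 0.26 (and in gymnasium), env.reset() returns an (observation, info) pair, env.step() returns five values, rendering is requested via gym.make('CartPole-v1', render_mode='human') instead of env.render(), and recent Keras releases spell the optimizer argument learning_rate rather than lr. A minimal sketch of the adapted interaction loop, assuming gymnasium and otherwise unchanged agent code (a random action stands in for the agent here):

import gymnasium as gym
import numpy as np

env = gym.make("CartPole-v1")  # add render_mode="human" to watch
state_size = env.observation_space.shape[0]

state, info = env.reset()      # reset now returns (obs, info)
state = np.reshape(state, [1, state_size])
done = False
while not done:
    action = env.action_space.sample()  # stand-in for agent.get_action(state)
    next_state, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated      # two flags replace the old `done`
    state = np.reshape(next_state, [1, state_size])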