Deep Reinforcement Learning -- Deep Q Network

    Starting here, we switch to a different game for the demo: CartPole. In CartPole-v1 the state is a 4-dimensional vector (cart position, cart velocity, pole angle, pole angular velocity), there are 2 discrete actions (push the cart left or right), and an episode ends when the pole tips past a threshold angle, the cart leaves the track, or 500 steps are reached.

    Deep Q Network
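
    In brief, the agent below approximates the Q function with a small fully connected network, picks actions epsilon-greedily, stores transitions (s, a, r, s', done) in a replay memory, and trains on random mini-batches drawn from it. For each sampled transition the regression target for the chosen action is y = r if the episode terminated at s', and y = r + gamma * max_a' Q_target(s', a') otherwise, where Q_target is a periodically synchronized copy of the main network (here it is synchronized at the end of every episode).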

    Example code

    import sys
    import gym
    import pylab
    import random
    import numpy as np
    from collections import deque
    from keras.layers import Dense
    from keras.optimizers import Adam
    from keras.models import Sequential

    EPISODES = 300


    # DQN agent for CartPole.
    # It uses a neural network to approximate the Q function of Q-learning,
    # together with an experience replay memory and a fixed target Q network.
    class DQNAgent:
        def __init__(self, state_size, action_size):
            # set render to True to watch CartPole while it learns
            self.render = True
            self.load_model = False

            # get size of state and action
            self.state_size = state_size
            self.action_size = action_size

            # hyperparameters for the DQN
            self.discount_factor = 0.99
            self.learning_rate = 0.001
            self.epsilon = 1.0
            self.epsilon_decay = 0.999
            self.epsilon_min = 0.01
            self.batch_size = 64
            self.train_start = 1000
            # create replay memory using a deque
            self.memory = deque(maxlen=2000)

            # create main model and target model
            self.model = self.build_model()
            self.target_model = self.build_model()

            # initialize target model
            self.update_target_model()

            if self.load_model:
                self.model.load_weights("./save_model/cartpole_dqn.h5")

        # approximate the Q function using a neural network:
        # the state is the input and the Q value of each action is the output
        def build_model(self):
            model = Sequential()
            model.add(Dense(24, input_dim=self.state_size, activation='relu',
                            kernel_initializer='he_uniform'))
            model.add(Dense(24, activation='relu',
                            kernel_initializer='he_uniform'))
            model.add(Dense(self.action_size, activation='linear',
                            kernel_initializer='he_uniform'))
            model.summary()
            model.compile(loss='mse', optimizer=Adam(lr=self.learning_rate))
            return model

        # after some time interval, update the target model to match the main model
        def update_target_model(self):
            self.target_model.set_weights(self.model.get_weights())

        # get action from the model using an epsilon-greedy policy
        def get_action(self, state):
            if np.random.rand() <= self.epsilon:
                return random.randrange(self.action_size)
            else:
                # predict Q(s, a) for every action and take the action with the largest Q value
                q_value = self.model.predict(state)
                return np.argmax(q_value[0])

        # save sample <s, a, r, s', done> to the replay memory
        def append_sample(self, state, action, reward, next_state, done):
            self.memory.append((state, action, reward, next_state, done))
            if self.epsilon > self.epsilon_min:
                self.epsilon *= self.epsilon_decay

        # pick batch_size samples randomly from the replay memory and train on them
        def train_model(self):
            if len(self.memory) < self.train_start:
                return
            batch_size = min(self.batch_size, len(self.memory))
            mini_batch = random.sample(self.memory, batch_size)
            # a stored sample looks like:
            # (array([[-0.04263461, -0.00657423,  0.00506589, -0.00200269]]), 0, 1.0, array([[-0.04276609, -0.20176846,  0.00502584,  0.29227427]]), False)

            update_input = np.zeros((batch_size, self.state_size))
            update_target = np.zeros((batch_size, self.state_size))
            action, reward, done = [], [], []

            for i in range(batch_size):
                update_input[i] = mini_batch[i][0]
                action.append(mini_batch[i][1])
                reward.append(mini_batch[i][2])
                update_target[i] = mini_batch[i][3]
                done.append(mini_batch[i][4])

            target = self.model.predict(update_input)               # shape (batch_size, action_size)
            target_val = self.target_model.predict(update_target)   # shape (batch_size, action_size)

            for i in range(batch_size):
                # Q-learning: get the maximum Q value at s' from the target model (off-policy update)
                if done[i]:
                    target[i][action[i]] = reward[i]
                else:
                    target[i][action[i]] = reward[i] + self.discount_factor * (
                        np.amax(target_val[i]))

            # and do the model fit!
            self.model.fit(update_input, target, batch_size=batch_size,
                           epochs=1, verbose=0)


    if __name__ == "__main__":
        # for CartPole-v1, the maximum episode length is 500
        env = gym.make('CartPole-v1')
        # get the size of state and action from the environment
        state_size = env.observation_space.shape[0]  # 4 for CartPole
        action_size = env.action_space.n             # 2 for CartPole

        agent = DQNAgent(state_size, action_size)

        scores, episodes = [], []

        for e in range(EPISODES):
            done = False
            score = 0
            state = env.reset()
            state = np.reshape(state, [1, state_size])

            while not done:
                if agent.render:
                    env.render()

                # get the action for the current state and take one step in the environment
                action = agent.get_action(state)
                next_state, reward, done, info = env.step(action)
                next_state = np.reshape(next_state, [1, state_size])
                # if an action ends the episode before step 500, give a penalty of -100
                reward = reward if not done or score == 499 else -100

                # save the sample <s, a, r, s', done> to the replay memory
                agent.append_sample(state, action, reward, next_state, done)
                # train at every time step
                agent.train_model()
                score += reward
                state = next_state

                if done:
                    # at the end of every episode, update the target model to match the main model
                    agent.update_target_model()

                    # every episode, plot the score (undo the -100 penalty so the plot reflects the episode length)
                    score = score if score == 500 else score + 100
                    scores.append(score)
                    episodes.append(e)
                    pylab.plot(episodes, scores, 'b')
                    # note: the ./save_graph and ./save_model directories must already exist
                    pylab.savefig("./save_graph/cartpole_dqn.png")
                    print("episode:", e, "  score:", score, "  memory length:",
                          len(agent.memory), "  epsilon:", agent.epsilon)

                    # if the mean score of the last 10 episodes is greater than 490, stop training
                    if np.mean(scores[-min(10, len(scores)):]) > 490:
                        sys.exit()

            # save the model every 50 episodes
            if e % 50 == 0:
                agent.model.save_weights("./save_model/cartpole_dqn.h5")
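
    The listing above targets the standalone Keras package and a Gym version whose step() returns four values. On a newer stack a few calls change. Below is a minimal sketch of the differences, assuming gym >= 0.26 (or gymnasium) and TensorFlow 2's tf.keras; treat the version boundaries and keyword names as assumptions to verify against your installed packages, not as part of the original code.

    # Sketch of the API differences under gym >= 0.26 (or gymnasium) and tf.keras (TensorFlow 2.x).
    # Only the calls that changed are shown; the DQNAgent logic above stays the same.
    import gym  # with gymnasium: import gymnasium as gym
    import numpy as np
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import Adam

    env = gym.make('CartPole-v1', render_mode='human')  # rendering is requested at construction time
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    model = Sequential([
        Dense(24, input_dim=state_size, activation='relu', kernel_initializer='he_uniform'),
        Dense(24, activation='relu', kernel_initializer='he_uniform'),
        Dense(action_size, activation='linear', kernel_initializer='he_uniform'),
    ])
    # tf.keras uses learning_rate= instead of the old lr= argument
    model.compile(loss='mse', optimizer=Adam(learning_rate=0.001))

    state, info = env.reset()  # reset() now returns (observation, info)
    state = np.reshape(state, [1, state_size])

    done = False
    while not done:
        action = int(np.argmax(model.predict(state, verbose=0)[0]))
        # step() now returns five values; terminated/truncated replace the single done flag
        next_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        state = np.reshape(next_state, [1, state_size])

    env.close()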
Original article: https://www.cnblogs.com/buyizhiyou/p/10250150.html