mxnet (gluon): a simple DQN example

References

莫凡 (Morvan) course video series

Introduction to Reinforcement Learning: Q-Learning (增强学习入门之Q-Learning)


The basics of reinforcement learning are covered in the second link, which is quite an entertaining read. For DQN itself, see the relevant videos in the first link; that course provides example code in TensorFlow and PyTorch. This post mainly rewrites the example in gluon.

The Q-learning algorithm flow:
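In short, Q-learning maintains a table Q[s, a] and repeatedly moves each entry towards the bootstrapped target r + gamma * max_a' Q[s', a']. A minimal sketch on a small discrete-state environment (FrozenLake-v0 is just an illustrative choice here, and the hyperparameters are placeholders, not values from the course):

import gym
import numpy as np

env = gym.make('FrozenLake-v0')                                # any small discrete-state env works
Q = np.zeros((env.observation_space.n, env.action_space.n))   # the Q-table itself
alpha, gamma, epsilon = 0.1, 0.9, 0.9                          # placeholder hyperparameters

for episode in range(2000):
    s = env.reset()
    done = False
    while not done:
        # epsilon-greedy: mostly exploit the table, sometimes explore
        if np.random.uniform() < epsilon:
            a = int(Q[s].argmax())
        else:
            a = env.action_space.sample()
        s_, r, done, _ = env.step(a)
        # the core update: nudge Q(s, a) towards r + gamma * max_a' Q(s', a')
        Q[s, a] += alpha * (r + gamma * Q[s_].max() - Q[s, a])
        s = s_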


The DQN algorithm flow: interact with the environment, store transitions in a replay memory, sample random minibatches from it to update the evaluation network towards targets computed by a periodically refreshed target network. This is exactly what the code below implements.


My understanding of DQN:

What reinforcement learning has to learn is the Q-table, i.e. the decision table. When the state space is very large, building this table explicitly is hard or even impossible. But the table is really just a mapping (s, a) -> R, and such a mapping can be represented by a neural network instead; that is the idea behind DQN.
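To make the mapping concrete: the network takes a state and outputs one value per action, so "looking up Q(s, a)" becomes indexing the network output, and the greedy policy is an argmax over it. A toy sketch of just this idea (CartPole-like sizes, not the actual network defined below):

import mxnet.ndarray as nd
import mxnet.gluon.nn as nn

# a small network standing in for one row of the Q-table
q_net = nn.Sequential()
q_net.add(nn.Dense(16, activation='relu'), nn.Dense(2))   # 2 = number of actions
q_net.initialize()

s = nd.random_uniform(shape=(1, 4))                       # one 4-dimensional state
q_values = q_net(s)                                       # "the row of the table": shape (1, 2)
greedy_action = int(nd.argmax(q_values, axis=1).asscalar())

During training (the learn() method further down), the value of the action actually taken is regressed towards the bootstrapped target r + GAMMA * max_a' Q_target(s', a').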


Now let's look at the code.

import mxnet as mx
import mxnet.ndarray as nd
import mxnet.gluon as gluon
import numpy as np
import mxnet.gluon.nn as nn
import gym


BATCH_SIZE = 64                  # batch size used when training the network
LR = 0.01                        # learning rate for the weight updates
EPSILON = 0.9                    # probability of picking the currently best-valued action (epsilon-greedy), loosely similar in spirit to bio-inspired algorithms
GAMMA = 0.5                      # discount factor: how much the next state's value contributes to q_target
TARGET_REPLACE_ITER = 100        # how often to snapshot the current parameters into the target network (the "previous" mapping)
MEMORY_CAPACITY = 1000           # capacity of the replay memory of past transitions
env = gym.make('CartPole-v0')    # build the environment with OpenAI gym
env = env.unwrapped
N_ACTIONS = env.action_space.n               # number of available actions
N_STATES = env.observation_space.shape[0]    # length of the state vector
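For CartPole-v0 these come out to 2 actions (push the cart left or right) and a 4-dimensional state (cart position, cart velocity, pole angle, pole angular velocity):

print(N_ACTIONS, N_STATES)   # CartPole-v0: 2 4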


# Define the network. The layers here are chosen fairly arbitrarily, just as an example.

class Net(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        with self.name_scope():
            self.fc1 = nn.Dense(16, activation='relu')
            self.fc2 = nn.Dense(32, activation='relu')
            self.fc3 = nn.Dense(16, activation='relu')
            self.out = nn.Dense(N_ACTIONS)

    def hybrid_forward(self, F, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        actions_value = self.out(x)
        return actions_value
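A quick shape sanity check (assuming the CartPole numbers above, so N_STATES = 4 and N_ACTIONS = 2):

net = Net()
net.initialize()                                   # parameters are allocated on the first forward pass
y = net(nd.random_uniform(shape=(3, N_STATES)))    # a batch of 3 states
print(y.shape)                                     # -> (3, 2): one estimated value per action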


# Parameter copying between the two networks. DQN learns off-policy: the targets are computed from an older snapshot of the mapping, held in a separate network that is only used for forward passes (like a table lookup) and is never updated by gradients. Its weights are refreshed from the current network once a certain condition is met.

def copy_params(src, dst):
    # src and dst are the ParameterDicts of eval_net and target_net.
    # Each dict prefixes these layer names with its own block prefix,
    # so get() resolves to the matching parameter in each network.
    dst.initialize(force_reinit=True, ctx=mx.cpu())
    layer_names = ['dense0_weight', 'dense0_bias', 'dense1_weight', 'dense1_bias',
                   'dense2_weight', 'dense2_bias', 'dense3_weight', 'dense3_bias']
    for name in layer_names:
        dst.get(name).set_data(src.get(name).data())
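The hard-coded names above rely on gluon's automatic parameter naming. As an aside, a more generic sketch (not what this post uses) pairs the parameters of the two ParameterDicts by position, assuming both networks have identical architectures and are already initialized:

def copy_params_generic(src, dst):
    # src and dst: ParameterDicts of two structurally identical, initialized networks
    for p_src, p_dst in zip(src.values(), dst.values()):
        p_dst.set_data(p_src.data())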



# The DQN agent: the two networks, action selection, experience storage, and learning.

class DQN(object):
    def __init__(self):
        self.eval_net, self.target_net = Net(), Net()
        self.eval_net.initialize()
        self.target_net.initialize()
        # because of mxnet's deferred initialization, run one forward pass
        # so the parameters are actually allocated
        x = nd.random_uniform(shape=(1, N_STATES))
        _ = self.eval_net(x)
        _ = self.target_net(x)
        self.learn_step_counter = 0
        self.memory_counter = 0
        # each row stores: current state, chosen action, reward, next state
        self.memory = np.zeros(shape=(MEMORY_CAPACITY, N_STATES*2 + 2))
        self.trainer = gluon.Trainer(self.eval_net.collect_params(), 'sgd',
                                     {'learning_rate': LR, 'wd': 1e-4})
        self.loss_func = gluon.loss.L2Loss()
        self.cost_his = []

    def choose_action(self, x):
        if np.random.uniform() < EPSILON:
            # with probability EPSILON pick the action with the highest estimated value
            x = nd.array([x])
            actions_value = self.eval_net(x)
            action = int(nd.argmax(actions_value, axis=1).asscalar())
        else:
            # otherwise explore with a random action
            action = np.random.randint(0, N_ACTIONS)
        return action

    def store_transition(self, s, a, r, s_):
        # store one transition, overwriting the oldest entries once the memory is full
        transition = np.hstack((s, [a, r], s_))
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memory_counter += 1

    def learn(self):
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            # every TARGET_REPLACE_ITER learning steps, refresh the target network
            # with the current evaluation network's parameters
            copy_params(self.eval_net.collect_params(), self.target_net.collect_params())
        self.learn_step_counter += 1

        # sample a random batch of transitions from the replay memory
        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
        b_memory = self.memory[sample_index, :]

        b_s  = nd.array(b_memory[:, :N_STATES])
        b_a  = nd.array(b_memory[:, N_STATES:N_STATES+1])
        b_r  = nd.array(b_memory[:, N_STATES+1:N_STATES+2])
        b_s_ = nd.array(b_memory[:, -N_STATES:])

        with mx.autograd.record():
            # Q(s, a) of the action actually taken, shape (BATCH_SIZE, 1)
            q_eval = nd.pick(self.eval_net(b_s), b_a, axis=1, keepdims=True)
            with mx.autograd.pause():
                # values from the frozen target network, shape (BATCH_SIZE, N_ACTIONS)
                q_next = self.target_net(b_s_)
            # bootstrapped target r + GAMMA * max_a' Q_target(s', a'), shape (BATCH_SIZE, 1)
            q_target = b_r + GAMMA * nd.max(q_next, axis=1, keepdims=True)
            loss = self.loss_func(q_eval, q_target)

        self.cost_his.append(nd.mean(loss).asscalar())
        loss.backward()
        self.trainer.step(BATCH_SIZE)

    def plot_cost(self):
        import matplotlib.pyplot as plt
        plt.plot(np.arange(len(self.cost_his)), self.cost_his)
        plt.ylabel('Cost')
        plt.xlabel('training steps')
        plt.show()


# Training
dqn = DQN()
for i_episode in range(500):
    s = env.reset()
    while True:
        env.render()
        a = dqn.choose_action(s)
        s_, r, done, info = env.step(a)   # next state, reward, episode-finished flag

        # Reshape the reward: CartPole's own reward is always 1, so instead reward
        # keeping the cart near the centre and the pole close to upright.
        x, x_dot, theta, theta_dot = s_
        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
        r = r1 + r2

        dqn.store_transition(s, a, r, s_)
        if dqn.memory_counter > MEMORY_CAPACITY:
            # only start learning once the replay memory has been filled
            dqn.learn()

        if done:
            break

        s = s_
dqn.plot_cost()
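After training, a quick way to watch the learned policy is a purely greedy rollout (a small sketch reusing the dqn and env objects above, with no exploration and no learning):

s = env.reset()
done = False
while not done:
    env.render()
    a = int(nd.argmax(dqn.eval_net(nd.array([s])), axis=1).asscalar())
    s, r, done, _ = env.step(a)
env.close()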


The loss curve (figure):


The training loss does not seem to be converging yet; I am still looking into why.


P.S. This is my first time writing a post with Open Live Writer, and the experience has been terrible! I need proper support for formulas, code, and images... still looking for a better tool.

Original post: https://www.cnblogs.com/YiXiaoZhou/p/8145499.html