zoukankan      html  css  js  c++  java
  • dopamine源码解析之dqn_agent

    目录

    • epsilon函数
    • DQNAgent构造函数核心参数
    • DQNAgent核心函数
    • tf.make_template
    • 核心数据流图

    epsilon函数

    linearly_decaying_epsilon,线性的对epsilon进行递减,先保持1.0一段时间(warmup_steps),然后线性递减,最后递减到最小值之后维持这个最小值;

    DQNAgent构造函数核心参数

    • update_horizon,n-step中的n,后向观察的步数;
    • min_replay_history,在智能体进行训练之前,必须经历的step数量,智能体不能一开始就进行训练;
    • update_period,当前网络参数更新的周期;
    • target_update_period,当前网络参数更新到目标网络参数上的周期;

    DQNAgent核心函数

    • init,解析构造参数,准备输入的placeholder,构建网络结构,其中输入状态的深度是4,也就是说,输入的不是一张图像,而是4张堆叠的图像;
    • _get_network_type,获取网络类型,返回一个collections.namedtuple;
    • _network_template,网络模板,三层卷积,两层全连接;
    • _build_networks,构造网络结构,设计了online_convnet和target_convnet两种操作,分别用于构建当前和目标网络,使用了tf.make_template函数,对于同样一个操作,不论输入的是什么,都共享同样的网络参数;
    • _build_replay_buffer,构建经验重放缓冲;
    • _build_target_q_op,为q-learning生成一个目标的操作;
    • _build_train_op,训练的操作;
    • _build_sync_op,同步的操作;
    • begin_episode,开始一段周期,初始化state和action;
    • step,选择动作,如果是训练过程,需要记录transition;
    • end_episode,如果是训练过程,需要记录transition;
    • _select_action,根据模型选择动作;
    • _train_step,运行单个训练步骤,需要满足两个条件,第一,在经验缓冲中的帧数已经达到要求,第二,training_steps是update_period的整数倍;
    • _record_observation,记录一个观察;
    • _store_transition,记录一次转换;

    tf.make_template

    • tf.make_template,输入一个函数,返回一个包裹了该函数的操作,这个操作在第一次被调用的时候创建变量,然后在之后的每一次调用中,重用这些变量,这是实现变量共享的一种方法。

    核心数据流图

    graph BT state_ph-->|1|online_convnet(online_convnet) online_convnet-->|1|_net_outputs _net_outputs-->|1|_q_argmax _replay.states-->|2|online_convnet online_convnet-->|2|_replay_net_outputs _replay.next_states-->|3|target_convnet(target_convnet) target_convnet-->|3|_replay_next_target_net_outputs _replay_next_target_net_outputs-->|3|replay_next_qt_max replay_next_qt_max-->|3|target _replay_net_outputs-->|2|replay_chosen_q target-->|3|loss replay_chosen_q-->|2|loss online_convnet-->|4|target_convnet

    图中的数据流动包含4条线,解释如下:

    • 线1,在线动作,根据当前的状态state_ph,以及当前的在线策略模型online_convnet,按照epsilon贪心方式选择最优的动作_q_argmax;
    • 线2,训练动作,根据replay buffer中的记忆,以及当前的在线策略模型online_convnet,计算实际选择的动作,注意这里的选择是根据贪心方式,而不是epsilon贪心的方式选择的,这也是q-learning和sarsa算法最大的不同,也是off-policy和on-policy的根本区别;
    • 线3,训练动作,根据replay buffer中的记忆,以及当前的目标策略模型target_convnet,计算Q-learning中的目标;
    • 线2+线3,训练动作,根据线2计算出实际的Q值,以及线3计算出的目标Q值,进行训练,注意训练时,只有在线策略模型online_convnet会迭代,目标策略模型target_convnet并不迭代;
    • 线4,每间隔一定的周期(即target_update_period),就会把当前在线策略模型online_convnet的参数同步给目标策略模型target_convnet,完成对目标模型的更新;

    关于线2和线3,再说明一下,还记得Bellman目标是:

    Q_t = R_t + gamma^N * Q'_t+1
    

    其中,

    Q'_t+1 = argmax_a Q(S_t+1, a) or 0 if S_t is a terminal state
    

    线3计算的就相当于Q_t,是我们希望通过现有的在线策略模型逼近的目标,而线2计算的是当前在线策略模型的输出,因此线2和线3的差距,就是损失,利用这个损失就可以对线2中的在线策略模型中的参数进行训练。

  • 相关阅读:
    ok6410驱动usb摄像头
    自己动手写CPU之第五阶段(1)——流水线数据相关问题
    ListView嵌套ListView时发生:View too large to fit into drawing cache的问题
    算法导论 第8章 线性时间排序(计数排序、基数排序、桶排序)
    Android_通过ContentObserver监听短信数据变化
    【MyEcplise】导入项目报错:Errors running builder 'JavaScript Validator' on project '项目名'. java.lang.ClassCastException
    【js】js中const,var,let区别
    【Node.js】2.开发Node.js选择哪个IDE 开发工具呢
    【Node.js】1.安装步骤
    【POI】对于POI无法处理超大xls等文件,官方解决方法【已解决】【多线程提升速率待定】
  • 原文地址:https://www.cnblogs.com/jicanghai/p/9746936.html
Copyright © 2011-2022 走看看