zoukankan      html  css  js  c++  java
  • Reinforcement Learning (DQN) 中经验池详细解释

    一般DQN中的经验池类,都类似于下面这段代码。

    import random
    from collections import namedtuple, deque
    
    Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward'))
    
    # 经验池类
    class ReplayMemory(object):
    
        def __init__(self, capacity):
            self.capacity = capacity        # 容量
            self.memory = []
            self.position = 0
    
        # 将四元组压入经验池
        def push(self, *args):
            if len(self.memory) < self.capacity:
                self.memory.append(None)
            self.memory[self.position] = Transition(*args)
            self.position = (self.position + 1) % self.capacity
    
        # 从经验池中随机压出一个四元组
        def sample(self, batch_size):
            transitions = random.sample(self.memory, batch_size)
            batch = Transition(*zip(*transitions))
            return batch
    
        def __len__(self):
            return len(self.memory)
    

    对Python不太熟悉的我里边就有两点比较迷惑,一个是namedtuple()方法,一个是sample方法的倒数第二行,为什么要这样处理。

    第一点,namedtuple()是继承自tuple的子类,namedtuple()方法能够创建一个和tuple类似的对象,而且对象拥有可访问的属性。

    第二点,也就是sample方法中的倒数第二行,这里进行了一个转换, 将batch_size个四元组,转换成,四个元祖,每个元祖一共有batch_size项,这里放个程序解释一下。

    import random
    from collections import namedtuple
    
    if __name__ == '__main__':
    
        batch_size = 3
        Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward'))
    
        a=Transition(state=1,next_state=2,action=3,reward=4)
        b=Transition(state=11,next_state=12,action=13,reward=14)
        c=Transition(state=21,next_state=22,action=23,reward=24)
        d=Transition(state=31,next_state=32,action=33,reward=34)
        e=Transition(state=41,next_state=42,action=43,reward=44)
    
        f=[a,b,c,d,e]
    
        # 从f中随机抽取batch_size个数据
        t=random.sample(f,batch_size)
    
        print("随机抽取的batch_size个四元祖是:")
        for i in range(batch_size):
            print(t[i])
        print()
    
        # 将t进行解压操作
        print("将四元组进行解压后是:")
        print(*zip(*t))
        print()
    
        # 将t进行解压操作,再进行Transition转换
        # 将batch_size个四元组,转换成,四个元组,每个元组一共有batch_size项
        print("将四元组进行解压后再进行Transition转换后是:")
        batch=Transition(*zip(*t))
        print(batch)
    

    输出结果:

    随机抽取的batch_size个四元祖是:
    Transition(state=21, next_state=22, action=23, reward=24)
    Transition(state=11, next_state=12, action=13, reward=14)
    Transition(state=41, next_state=42, action=43, reward=44)
    
    将四元组进行解压后是:
    (21, 11, 41) (22, 12, 42) (23, 13, 43) (24, 14, 44)
    
    将四元组进行解压后再进行Transition转换后是:
    Transition(state=(21, 11, 41), next_state=(22, 12, 42), action=(23, 13, 43), reward=(24, 14, 44))
    
  • 相关阅读:
    移动端touch事件获取事件坐标
    详解webpack中的hash、chunkhash、contenthash区别
    textarea placeholder 设置主动换行
    js-xlsx的使用
    关于Blob对象的介绍与使用
    spring boot zuul集成kubernetes等第三方登录
    Spring Boot 获取yaml配置文件信息
    spring boot @Value源码解析
    java.lang.StackOverflowError解决
    Jpa 重写方言dialect 使用oracle / mysql 数据库自定义函数
  • 原文地址:https://www.cnblogs.com/52dxer/p/14139911.html
Copyright © 2011-2022 走看看