  • DeepQNetwork --- a TensorFlow implementation

    https://github.com/zle1992/Reinforcement_Learning_Game

    DeepQNetwork.py
      import numpy as np
      import tensorflow as tf
      from abc import ABCMeta, abstractmethod
      np.random.seed(1)
      tf.set_random_seed(1)

      import logging  # logging module
      logging.basicConfig(level=logging.DEBUG,
                          format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')  # configure the log output format
      # with the level set to DEBUG, all of the messages below are printed to the console

      tfconfig = tf.ConfigProto()
      tfconfig.gpu_options.allow_growth = True  # allocate GPU memory on demand
      session = tf.Session(config=tfconfig)

      class DeepQNetwork(object):
          """Abstract Deep Q-Network: subclasses provide the network via _build_q_net()."""
          __metaclass__ = ABCMeta

          def __init__(self,
                  n_actions,
                  n_features,
                  learning_rate,
                  reward_decay,
                  e_greedy,
                  replace_target_iter,
                  memory_size,
                  e_greedy_increment,
                  output_graph,
                  log_dir,
                  ):
              super(DeepQNetwork, self).__init__()

              self.n_actions = n_actions
              self.n_features = n_features
              self.lr = learning_rate
              self.gamma = reward_decay
              self.epsilon_max = e_greedy
              self.replace_target_iter = replace_target_iter
              self.memory_size = memory_size
              self.epsilon_increment = e_greedy_increment
              # current greedy probability: starts at 0 and is annealed towards epsilon_max,
              # or stays fixed at epsilon_max when no increment is given
              self.epsilon = 0.0 if e_greedy_increment is not None else self.epsilon_max
              self.output_graph = output_graph
              self.log_dir = log_dir
              # total learning steps taken so far
              self.learn_step_counter = 0

              # placeholders for state, next state, reward and action
              self.s = tf.placeholder(tf.float32, [None] + self.n_features, name='s')
              self.s_next = tf.placeholder(tf.float32, [None] + self.n_features, name='s_next')
              self.r = tf.placeholder(tf.float32, [None, ], name='r')
              self.a = tf.placeholder(tf.int32, [None, ], name='a')

              # two networks with identical architecture: a trainable evaluation net
              # and a frozen target net that is only updated by hard replacement
              self.q_eval = self._build_q_net(self.s, scope='eval_net', trainable=True)
              self.q_next = self._build_q_net(self.s_next, scope='target_net', trainable=False)

              with tf.variable_scope('q_target'):
                  # Bellman target: r + gamma * max_a' Q_target(s', a')
                  self.q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_')    # shape=(None, )
              with tf.variable_scope('q_eval'):
                  # pick the Q-value of the action that was actually taken in each row
                  a_indices = tf.stack([tf.range(tf.shape(self.a)[0], dtype=tf.int32), self.a], axis=1)
                  self.q_eval_wrt_a = tf.gather_nd(params=self.q_eval, indices=a_indices)    # shape=(None, )
              with tf.variable_scope('loss'):
                  self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))
              with tf.variable_scope('train'):
                  self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)

              # ops that copy the evaluation-net weights into the target net
              t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
              e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')
              with tf.variable_scope("hard_replacement"):
                  self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

              self.sess = tf.Session()
              if self.output_graph:
                  # write the graph so it can be inspected with TensorBoard
                  tf.summary.FileWriter(self.log_dir, self.sess.graph)
              self.sess.run(tf.global_variables_initializer())

              self.cost_his = []

          @abstractmethod
          def _build_q_net(self, x, scope, trainable):
              raise NotImplementedError

          def learn(self, data):
              # periodically copy the evaluation-net parameters into the target net
              if self.learn_step_counter % self.replace_target_iter == 0:
                  self.sess.run(self.target_replace_op)
                  print('\ntarget_params_replaced\n')

              batch_memory_s = data['s']
              batch_memory_a = data['a']
              batch_memory_r = data['r']
              batch_memory_s_ = data['s_']
              _, cost = self.sess.run(
                  [self._train_op, self.loss],
                  feed_dict={
                      self.s: batch_memory_s,
                      self.a: batch_memory_a,
                      self.r: batch_memory_r,
                      self.s_next: batch_memory_s_,
                  })
              self.cost_his.append(cost)

              # anneal epsilon towards epsilon_max (no-op when e_greedy_increment is None)
              if self.epsilon_increment is not None and self.epsilon < self.epsilon_max:
                  self.epsilon = min(self.epsilon + self.epsilon_increment, self.epsilon_max)
              self.learn_step_counter += 1

          def choose_action(self, s):
              s = s[np.newaxis, :]
              # epsilon-greedy: exploit with probability epsilon, otherwise explore randomly
              if np.random.uniform() < self.epsilon:
                  action_value = self.sess.run(self.q_eval, feed_dict={self.s: s})
                  action = np.argmax(action_value)
              else:
                  action = np.random.randint(0, self.n_actions)
              return action
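
    The q_target / q_eval_wrt_a wiring above is the core of the update, and it is easy to check with plain NumPy. The following sketch is not part of the repository; the Q-values, rewards and actions are made-up numbers. It mimics what the tf.stack + tf.gather_nd indexing and the reduce_max target compute for a batch of two transitions:

      import numpy as np

      gamma = 0.9
      q_eval = np.array([[1.0, 2.0],    # Q(s, .) from eval_net: batch of 2 states, 2 actions
                         [0.5, 0.1]])
      q_next = np.array([[0.3, 0.7],    # Q(s', .) from the frozen target_net
                         [0.2, 0.9]])
      a = np.array([1, 0])              # actions actually taken
      r = np.array([1.0, 0.0])          # rewards received

      # gather_nd with the stacked [row, action] indices picks Q(s, a) per sample
      q_eval_wrt_a = q_eval[np.arange(len(a)), a]      # -> [2.0, 0.5]
      # Bellman target: r + gamma * max_a' Q_target(s', a')
      q_target = r + gamma * q_next.max(axis=1)        # -> [1.63, 0.81]
      loss = np.mean((q_target - q_eval_wrt_a) ** 2)   # mean squared TD error
      print(q_eval_wrt_a, q_target, loss)

    Because only the taken action's Q-value is gathered, the TD error (and therefore the gradient) touches exactly one output per sample; the Q-values of the other actions are left untouched.
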
    Memory.py
      import numpy as np
      np.random.seed(1)

      class Memory(object):
          """Simple ring-buffer replay memory for (s, a, r, s_) transitions."""
          def __init__(self,
                  n_actions,
                  n_features,
                  memory_size):
              super(Memory, self).__init__()
              self.memory_size = memory_size
              self.cnt = 0

              self.s = np.zeros([memory_size] + n_features)
              self.a = np.zeros([memory_size, ])
              self.r = np.zeros([memory_size, ])
              self.s_ = np.zeros([memory_size] + n_features)

          def store_transition(self, s, a, r, s_):
              # overwrite the oldest transition once the buffer is full
              index = self.cnt % self.memory_size
              self.s[index] = s
              self.a[index] = a
              self.r[index] = r
              self.s_[index] = s_
              self.cnt += 1

          def sample(self, n):
              # sample only from the slots that have actually been filled
              N = min(self.memory_size, self.cnt)
              indices = np.random.choice(N, size=n)
              d = {}
              d['s'] = self.s[indices]
              d['s_'] = self.s_[indices]
              d['r'] = self.r[indices]
              d['a'] = self.a[indices]
              return d
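
    A quick sanity check of the buffer (a minimal sketch, assuming the 4-dimensional CartPole observation used later in this post): store a few transitions, then confirm that sample() returns batches with the shapes that DeepQNetwork.learn() feeds into its placeholders.

      import numpy as np
      from Memory import Memory

      mem = Memory(n_actions=2, n_features=[4], memory_size=100)
      for i in range(10):
          s = np.random.randn(4)
          s_ = np.random.randn(4)
          mem.store_transition(s, i % 2, 1.0, s_)

      batch = mem.sample(5)
      print(batch['s'].shape, batch['a'].shape, batch['r'].shape, batch['s_'].shape)
      # -> (5, 4) (5,) (5,) (5, 4)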

    Main program (CartPole)

      import gym
      import numpy as np
      import tensorflow as tf

      from Memory import Memory
      from DeepQNetwork import DeepQNetwork

      np.random.seed(1)
      tf.set_random_seed(1)

      import logging  # logging module
      logging.basicConfig(level=logging.DEBUG,
                          format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')  # configure the log output format
      # with the level set to DEBUG, all of the messages below are printed to the console

      tfconfig = tf.ConfigProto()
      tfconfig.gpu_options.allow_growth = True  # allocate GPU memory on demand
      session = tf.Session(config=tfconfig)

      class DeepQNetwork4CartPole(DeepQNetwork):
          """DQN whose Q-network is a small two-layer MLP for CartPole."""
          def __init__(self, **kwargs):
              super(DeepQNetwork4CartPole, self).__init__(**kwargs)

          def _build_q_net(self, x, scope, trainable):
              w_initializer, b_initializer = tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1)

              with tf.variable_scope(scope):
                  e1 = tf.layers.dense(inputs=x,
                          units=32,
                          bias_initializer=b_initializer,
                          kernel_initializer=w_initializer,
                          activation=tf.nn.relu,
                          trainable=trainable)
                  q = tf.layers.dense(inputs=e1,
                          units=self.n_actions,
                          bias_initializer=b_initializer,
                          kernel_initializer=w_initializer,
                          activation=tf.nn.sigmoid,
                          trainable=trainable)

              return q


      batch_size = 64
      memory_size = 2000

      #env = gym.make('Breakout-v0')  # discrete action space
      env = gym.make('CartPole-v0')   # discrete action space

      n_features = list(env.observation_space.shape)
      n_actions = env.action_space.n

      env = env.unwrapped

      def run():

          RL = DeepQNetwork4CartPole(
              n_actions=n_actions,
              n_features=n_features,
              learning_rate=0.01,
              reward_decay=0.9,
              e_greedy=0.9,
              replace_target_iter=200,
              memory_size=memory_size,
              e_greedy_increment=None,
              output_graph=True,
              log_dir='log/DeepQNetwork4CartPole/',
              )

          memory = Memory(n_actions, n_features, memory_size=memory_size)

          step = 0
          ep_r = 0
          for episode in range(2000):
              # initial observation
              observation = env.reset()

              while True:
                  # RL chooses an action based on the current observation
                  action = RL.choose_action(observation)

                  # RL takes the action and gets the next observation and reward
                  observation_, reward, done, info = env.step(action)

                  # reward shaping: the smaller theta and the closer to the center, the better
                  x, x_dot, theta, theta_dot = observation_
                  r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
                  r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
                  reward = r1 + r2

                  memory.store_transition(observation, action, reward, observation_)

                  # start learning after 200 warm-up steps, then learn every 5 steps
                  if (step > 200) and (step % 5 == 0):
                      data = memory.sample(batch_size)
                      RL.learn(data)
                      #print('step:%d----reward:%f---action:%d' % (step, reward, action))

                  # swap observation
                  observation = observation_
                  ep_r += reward

                  if episode > 700:
                      env.render()  # render on the screen
                  # break the while loop at the end of this episode
                  if done:
                      print('episode: ', episode,
                            'ep_r: ', round(ep_r, 2),
                            ' epsilon: ', round(RL.epsilon, 2))
                      ep_r = 0
                      break
                  step += 1

          # end of game
          print('game over')
          env.close()


      def main():
          run()


      if __name__ == '__main__':
          main()
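
    Because output_graph=True, the DeepQNetwork constructor writes the computation graph to log/DeepQNetwork4CartPole/ via tf.summary.FileWriter. Pointing TensorBoard at that directory should let you inspect it, e.g.

      tensorboard --logdir log/DeepQNetwork4CartPole/

    and then opening the address it prints (typically http://localhost:6006) in a browser.
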
  • Original post: https://www.cnblogs.com/zle1992/p/10241794.html