zoukankan      html  css  js  c++  java
  • tensorflow增强学习应用于一个小游戏

    首先需要安装gym模块,提供游戏的。

    1,所需模块

    import tensorflow as tf
    import numpy as np
    import gym
    import random
    from collections import deque
    from keras.utils.np_utils import to_categorical

    2,自定义一个简单的3层Dense Model

    # 自定义Model
    class QNetwork(tf.keras.Model):
        def __init__(self):
            super().__init__()
    #         简单的3个Dense
            self.dense1=tf.keras.layers.Dense(24,activation='relu')
            self.dense2=tf.keras.layers.Dense(24,activation='relu')
            self.dense3=tf.keras.layers.Dense(2)
        def call(self,inputs):
            x=self.dense1(inputs)
            x=self.dense2(x)
            x=self.dense3(x)
            return x
        def predict(self,inputs):
            q_values=self(inputs)#调用call
            return tf.argmax(q_values,axis=-1)

    3,定义相关参数

    # 游戏环境,实例化一个游戏
    env=gym.make('CartPole-v1')
    model=QNetwork()
    
    # 循环轮数设置小一点,50就可以了
    num_episodes=500
    num_exploration=100
    max_len=1000
    batch_size=32
    lr=1e-3
    gamma=1.
    initial_epsilon=1.
    final_epsilon=0.01
    replay_buffer=deque(maxlen=10000)
    
    epsilon=initial_epsilon
    # tensorflow2.0
    optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=lr)

    4,训练,测试

    for i in range(num_episodes):
        # 初始化环境
        state=env.reset()
    #     逐渐衰减,至final_epsilon
        epsilon=max(initial_epsilon*(num_exploration-i)/num_exploration,final_epsilon)
        for t in range(max_len):
    #         当前帧绘制到屏幕
            env.render()
    #         以epsilon的概率随机行动,epsilon是衰减的,说明游戏动作会越来越稳定
            if random.random()<epsilon:
                action=env.action_space.sample()
            else:
    #             从当前状态预测一个动作
                action=model.predict(tf.constant(np.expand_dims(state,axis=0),dtype=tf.float32)).numpy()
                action=action[0]
    #         执行一步动作
            next_state,reward,done,info=env.step(action)
    #         奖励
            reward=-10.if done else reward
    #         缓存
            replay_buffer.append((state,action,reward,next_state,done))
            state=next_state
            if done:
                print('episode %d,epsilon %f,score %d'%(i,epsilon,t))
                break
    #         预测batch_size步后执行
            if len(replay_buffer)>=batch_size:
                # 随机获取一个batch的数据
                batch_state,batch_action,batch_reward,batch_next_state,batch_done=
                [np.array(a,dtype=np.float32) for a in zip(*random.sample(replay_buffer,batch_size))]
    #             下一个状态,由此得到的y为真实值
    #             预测值与真实值的计算看不太懂
                q_value=model(tf.constant(batch_next_state,dtype=tf.float32))
                y=batch_reward+(gamma*tf.reduce_max(q_value,axis=1))*(1-batch_done)
                with tf.GradientTape() as tape:
    #                 loss=tf.losses.mean_squared_error(labels=y,predictions=tf.reduce_sum(
    #                     model(tf.constant(batch_state))*tf.one_hot(batch_action,depth=2),axis=1))
                    loss=tf.losses.mean_squared_error(y,tf.reduce_sum(
                        model(tf.constant(batch_state))*to_categorical(batch_action,num_classes=2),axis=1))
                grads=tape.gradient(loss,model.variables)
                optimizer.apply_gradients(grads_and_vars=zip(grads,model.variables))

    最终会出现一个窗口,平衡游戏不断进行。。。

    上面注释部分因为tf.one_hot方法会报错。

  • 相关阅读:
    [题解] [NOIP2008] 双栈排序——关系的冲突至图论解法
    [搬运] [贪心]NOIP2011 观光公交
    [总结] 最短路径数问题
    [持续更新]一些zyys的题的集合
    [教程]Ubuntu下完整配置自动壁纸切换
    在NOILINUX下的简易VIM配置
    [模板]ST表浅析
    21、Android--RecyclerView
    20、Android--GridView
    19、Android--ListView
  • 原文地址:https://www.cnblogs.com/lunge-blog/p/11644598.html
Copyright © 2011-2022 走看看