zoukankan      html  css  js  c++  java
  • Q-learning简明实例Java代码实现

    在《Q-learning简明实例》中我们介绍了Q-learning算法的简单例子,从中我们可以总结出Q-learning算法的基本思想

    本次选择的经验得分 = 本次选择的反馈得分 + 本次选择后场景的历史最佳经验得分

    其中反馈得分是单个步骤的价值分值(固定的分值),经验得分是完成目标的学习分值(动态的分值)。

    简明实例的Java实现如下

    package com.coshaho.learn.qlearning;
    
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;
    import java.util.Random;
    
    /**
     * 
     * QLearning.java Create on 2017年9月4日 下午10:08:49    
     *    
     * 类功能说明:   QLearning简明例子实现
     *
     * Copyright: Copyright(c) 2013 
     * Company: COSHAHO
     * @Version 1.0
     * @Author coshaho
     */
    public class QLearning 
    {
        FeedbackMatrix R = new FeedbackMatrix();
        
        ExperienceMatrix Q = new ExperienceMatrix();
        
        public static void main(String[] args)
        {
            QLearning ql = new QLearning();
            
            for(int i = 0; i < 500; i++)
            {
                Random random = new Random();
                int x = random.nextInt(100) % 6;
                
                System.out.println("第" + i + "次学习, 初始房间是" + x);
                ql.learn(x);
                System.out.println();
            }
        }
        
        public void learn(int x)
        {
            do
            {
                // 随机选择一个联通的房间进入
                int y =  chooseRandomRY(x);
                
                // 获取以进入的房间为起始点的历史最佳得分
                int qy = getMaxQY(y);
                
                // 计算此次移动的得分
                int value = calculateNewQ(x, y, qy);
                Q.set(x, y, value);
                x = y;
            }
            // 走出房间则学习结束
            while(5 != x);
            
            Q.print();
        }
        
        public int chooseRandomRY(int x)
        {
            int[] qRow = R.getRow(x);
            List<Integer> yValues = new ArrayList<Integer>();
            for(int i = 0; i < qRow.length; i++)
            {
                if(qRow[i] >= 0)
                {
                    yValues.add(i);
                }
            }
    
            Random random = new Random();
            int i = random.nextInt(yValues.size()) % yValues.size();
            return yValues.get(i);
        }
        
        public int getMaxQY(int x)
        {
            int[] qRow = Q.getRow(x);
            int length = qRow.length;
            List<YAndValue> yValues = new ArrayList<YAndValue>();
            for(int i = 0; i < length; i++)
            {
                YAndValue yv = new YAndValue(i, qRow[i]);
                yValues.add(yv);
            }
            
            Collections.sort(yValues);
            int num = 1;
            int value = yValues.get(0).getValue();
            for(int i = 1; i < length; i++)
            {
                if(yValues.get(i).getValue() == value)
                {
                    num = i + 1;
                }
                else
                {
                    break;
                }
            }
            
            Random random = new Random();
            int i = random.nextInt(num) % num;
            return yValues.get(i).getY();
        }
        
        // Q(x,y) = R(x,y) + 0.8 * max(Q(y,i))
        public int calculateNewQ(int x, int y, int qy)
        {
            return (int) (R.get(x, y) + 0.8 * Q.get(y, qy));
        }
        
        public static class YAndValue implements Comparable<YAndValue>
        {
            int y;
            int value;
            
            public int getY() {
                return y;
            }
            public void setY(int y) {
                this.y = y;
            }
            public int getValue() {
                return value;
            }
            public void setValue(int value) {
                this.value = value;
            }
            public YAndValue(int y, int value)
            {
                this.y = y;
                this.value = value;
            }
            public int compareTo(YAndValue o) 
            {
                return o.getValue() - this.value;
            }
        }
    }
    
    package com.coshaho.learn.qlearning;
    
    /**
     * 
     * FeedbackMatrix.java Create on 2017年9月4日 下午9:52:41    
     *    
     * 类功能说明:   反馈矩阵
     *
     * Copyright: Copyright(c) 2013 
     * Company: COSHAHO
     * @Version 1.0
     * @Author coshaho
     */
    public class FeedbackMatrix 
    {
        public int get(int x, int y)
        {
            return R[x][y];
        }
        
        public int[] getRow(int x)
        {
            return R[x];
        }
        
        private static int[][] R = new int[6][6];
        static 
        {
            R[0][0] = -1;
            R[0][1] = -1;
            R[0][2] = -1;
            R[0][3] = -1;
            R[0][4] = 0;
            R[0][5] = -1;
            
            R[1][0] = -1;
            R[1][1] = -1;
            R[1][2] = -1;
            R[1][3] = 0;
            R[1][4] = -1;
            R[1][5] = 100;
            
            R[2][0] = -1;
            R[2][1] = -1;
            R[2][2] = -1;
            R[2][3] = 0;
            R[2][4] = -1;
            R[2][5] = -1;
            
            R[3][0] = -1;
            R[3][1] = 0;
            R[3][2] = 0;
            R[3][3] = -1;
            R[3][4] = 0;
            R[3][5] = -1;
            
            R[4][0] = 0;
            R[4][1] = -1;
            R[4][2] = -1;
            R[4][3] = 0;
            R[4][4] = -1;
            R[4][5] = 100;
            
            R[5][0] = -1;
            R[5][1] = 0;
            R[5][2] = -1;
            R[5][3] = -1;
            R[5][4] = 0;
            R[5][5] = 100;
        }
    }
    
    package com.coshaho.learn.qlearning;
    
    /**
     * 
     * ExperienceMatrix.java Create on 2017年9月4日 下午10:03:08    
     *    
     * 类功能说明:   经验矩阵
     *
     * Copyright: Copyright(c) 2013 
     * Company: COSHAHO
     * @Version 1.0
     * @Author coshaho
     */
    public class ExperienceMatrix 
    {
        public int get(int x, int y)
        {
            return Q[x][y];
        }
        
        public int[] getRow(int x)
        {
            return Q[x];
        }
        
        public void set(int x, int y, int value)
        {
            Q[x][y] = value;
        }
        
        public void print()
        {
            for(int i = 0; i < 6; i++)
            {
                for(int j = 0; j < 6; j++)
                {
                    String s = Q[i][j] + "  ";
                    if(Q[i][j] < 10)
                    {
                        s = s + "  ";
                    }
                    else if(Q[i][j] < 100)
                    {
                        s = s + " ";
                    }
                    System.out.print(s);
                }
                System.out.println();
            }
        }
        
        private static int[][] Q = new int[6][6];
        static
        {
            Q[0][0] = 0;
            Q[0][1] = 0;
            Q[0][2] = 0;
            Q[0][3] = 0;
            Q[0][4] = 0;
            Q[0][5] = 0;
            
            Q[1][0] = 0;
            Q[1][1] = 0;
            Q[1][2] = 0;
            Q[1][3] = 0;
            Q[1][4] = 0;
            Q[1][5] = 0;
            
            Q[2][0] = 0;
            Q[2][1] = 0;
            Q[2][2] = 0;
            Q[2][3] = 0;
            Q[2][4] = 0;
            Q[2][5] = 0;
            
            Q[3][0] = 0;
            Q[3][1] = 0;
            Q[3][2] = 0;
            Q[3][3] = 0;
            Q[3][4] = 0;
            Q[3][5] = 0;
            
            Q[4][0] = 0;
            Q[4][1] = 0;
            Q[4][2] = 0;
            Q[4][3] = 0;
            Q[4][4] = 0;
            Q[4][5] = 0;
            
            Q[5][0] = 0;
            Q[5][1] = 0;
            Q[5][2] = 0;
            Q[5][3] = 0;
            Q[5][4] = 0;
            Q[5][5] = 0;
        }
    }

    经过500次计算得到如下结果

    第499次学习, 初始房间是1
    0    0    0    0    396  0    
    0    0    0    316  0    496  
    0    0    0    316  0    0    
    0    396  252  0    396  0    
    316  0    0    316  0    496  
    0    396  0    0    396  496  
    

    此时,我们从任意一个房间进入,每次选取最高分值步骤移动,总可以找到最短的逃离路径。

  • 相关阅读:
    《锋利的jQuery》补充笔记
    sass学习笔记
    《HTML5与CSS3基础教程》学习笔记 ——补充
    ajax常见问题(部分)
    html新特性(部分)
    less 笔记
    《JavaScript高级程序设计》补充笔记2
    《JavaScript高级程序设计》补充笔记1
    《CSS3秘笈》备忘录
    显示实现接口的好处c#比java好的地方
  • 原文地址:https://www.cnblogs.com/coshaho/p/7497025.html
Copyright © 2011-2022 走看看