zoukankan      html  css  js  c++  java
  • (六)Value Function Approximation-LSPI code (4)

    本篇是solver.py

      1 # -*- coding: utf-8 -*-
      2 """Contains main LSPI method and various LSTDQ solvers."""
      3 
      4 import abc
      5 import logging
      6 
      7 import numpy as np
      8 
      9 import scipy.linalg
     10 
     11 
     12 class Solver(object):#这里也出现一个继承ABC类的类了
     13 
     14     r"""ABC for LSPI solvers.
     15 
     16     Implementations of this class will implement the various LSTDQ algorithms
     17     with various linear algebra solving techniques. This solver will be used
     18     by the lspi.learn method. The instance will be called iteratively until
     19     the convergence parameters are satisified.
     20 
     21     """
     22 
     23     __metaclass__ = abc.ABCMeta#继承
     24 
     25     @abc.abstractmethod#必须覆盖的函数
     26     def solve(self, data, policy):#求解函数
     27         r"""Return one-step update of the policy weights for the given data.
     28             #该函数对于给出的数据更新一步权重
     29         Parameters#输入参数
     30         ----------
     31         data:#数据
     32             #求解器需要的数据,通常是一个元素是采样的列表,当然也可以是各种求解器支持的方法
     33             This is the data used by the solver. In most cases this will be
     34             a list of samples. But it can be anything supported by the specific
     35             Solver implementation's solve method.
     36         policy: Policy#策略
     37             当前的策略,要对它进行提升
     38             The current policy to find an improvement to.
     39 
     40         Returns
     41         -------
     42         numpy.array#输出的权重
     43             Return the new weights as determined by this method.
     44 
     45         """
     46         pass  # pragma: no cover
     47 
     48 
     49 class LSTDQSolver(Solver):#最小二乘TDQ求解器
     50 
     51     """LSTDQ Implementation with standard matrix solvers.
     52     #用矩阵的形式实现
     53     #算法根据文献的第五张图,如果矩阵A是满秩的,那么就用scipy的库来计算
     54     #如果不满秩,就用最小二乘的方法
     55     Uses the algorithm from Figure 5 of the LSPI paper. If the A matrix
     56     turns out to be full rank then scipy's standard linalg solver is used. If
     57     the matrix turns out to be less than full rank then least squares method
     58     will be used.
     59     #通常矩阵A的对角线值是小的正数值,这用来保证即使是很少的采样,矩阵A也能满秩,如果
     60     #不想要这样的前提,可以让前提条件值为0
     61     By default the A matrix will have its diagonal preconditioned with a small
     62     positive value. This will help to ensure that even with few samples the
     63     A matrix will be full rank. If you do not want the A matrix to be
     64     preconditioned then you can set this value to 0.
     65 
     66     Parameters前提条件值
     67     ----------
     68     precondition_value: float
     69         Value to set A matrix diagonals to. Should be a small positive number.
     70         If you do not want preconditioning enabled then set it 0.
     71     """
     72 
     73     def __init__(self, precondition_value=.1):#初始化
     74         """Initialize LSTDQSolver."""
     75         self.precondition_value = precondition_value#对前提条件值赋值
     76 
     77     def solve(self, data, policy):#求解函数
     78         """Run LSTDQ iteration.
     79 
     80         See Figure 5 of the LSPI paper for more information.
     81         """
     82         k = policy.basis.size()#k是特征phi向量的长度
     83         a_mat = np.zeros((k, k))#建立A矩阵,k行k列
     84         np.fill_diagonal(a_mat, self.precondition_value)#向矩阵A中填充前提条件值
     85         #说明前提条件值是用来保证矩阵是满秩的
     86 
     87         b_vec = np.zeros((k, 1))#b向量
     88 
     89         for sample in data:#对于data中的每一个采样进行循环
     90             phi_sa = (policy.basis.evaluate(sample.state, sample.action)
     91                       .reshape((-1, 1)))#通过basisfunction求出phi值
     92 
     93             if not sample.absorb:
     94                 best_action = policy.best_action(sample.next_state)#计算下一个状态下的最佳动作
     95                 phi_sprime = (policy.basis
     96                               .evaluate(sample.next_state, best_action)
     97                               .reshape((-1, 1)))#计算一个新的phi
     98             else:
     99                 phi_sprime = np.zeros((k, 1))
    100 
    101             a_mat += phi_sa.dot((phi_sa - policy.discount*phi_sprime).T)#计算a矩阵
    102             b_vec += phi_sa*sample.reward#计算b矩阵
    103 
    104         a_rank = np.linalg.matrix_rank(a_mat)
    105         if a_rank == k:#如果满秩
    106             w = scipy.linalg.solve(a_mat, b_vec)#求逆解出w值
    107         else:
    108             logging.warning('A matrix is not full rank. %d < %d', a_rank, k)
    109             w = scipy.linalg.lstsq(a_mat, b_vec)[0]
    110         return w.reshape((-1, ))#返回已经优化后的w值.
  • 相关阅读:
    Java中WebService实例
    Sublime Text 3 史上最性感的编辑器
    win2003的IIS無法使用,又一次安裝提示找不到iisadmin.mfl文件
    [Unity3D]Unity3D游戏开发之刀光剑影特效的实现
    fopen()函数
    UML学习(一)类图和对象图
    AfxMessageBox和MessageBox差别
    【课程分享】基于plusgantt的项目管理系统实战开发(Spring3+JDBC+RMI的架构、自己定义工作流)
    Android自己主动化測试之Monkeyrunner用法及实例
    Java工厂模式
  • 原文地址:https://www.cnblogs.com/lijiajun/p/5490041.html
Copyright © 2011-2022 走看看