zoukankan      html  css  js  c++  java
  • AStar 路径规划之初级二

    1、A* 中的 g、h 代价函数使用两点之间五次多项式（五次样条）曲线的弧长，并根据需要加入微调项。

    2、五次多项式的末状态为 (l, l′=0, l″=0)，即每段曲线在终点处的横向斜率和二阶导数都为 0，所以中间各点相连时轨迹在连接处是水平的。更好的做法是先得到各个轨迹点，再对这些点重新拟合；或者在搜索过程中就考虑中间点的状态不是水平的。

    3、如果车辆模型为低速模型，这条轨迹应该是可以使用的。

    # -*- coding: utf-8 -*-
    """
    Created on Fri Dec 27 11:02:55 2019
    
    @author: leizhen.liu
    """
    from scipy.integrate import quad
    import numpy as np
    import matplotlib.pyplot as plt
    
    class VehicleState:
        """Lateral boundary state used by the quintic fit.

        Attributes:
            s: longitudinal coordinate.
            l: lateral coordinate.
            vpl: first derivative dl/ds (matched by the second row of the fit).
            apl: second derivative d2l/ds2 (matched by the third row of the fit).
        """

        def __init__(self, s, l, vpl, apl):
            self.s = s
            self.l = l
            self.vpl = vpl
            self.apl = apl


    class trajectoryCost:
        """Quintic lateral curve l(s) between two grid nodes and its arc length.

        The curve lives in a local frame whose origin is the start node: s runs
        from 0 to ``(end.s - start.s) * detas`` and l ends at
        ``(end.l - start.l) * detal``.  ``sorg``/``lorg`` hold the start node's
        position (scaled by ``detas``/``detal``) so ``show`` can draw the curve
        in road coordinates.

        Args:
            startVehicleState, endVehicleState: objects exposing ``s``, ``l``,
                ``vpl`` and ``apl`` (s and l as grid indices).
            totals: multiplier used to build node keys ``s + l * totals``.
            totall: not used by this class.
            detas, detal: grid spacing along s and along l.

        Attributes:
            id: ``"<startKey>id<endKey>"``.
            matP: (6, 1) array of polynomial coefficients p0..p5.
            cost: arc length of the curve over the segment.
        """

        def __init__(self, startVehicleState, endVehicleState, totals, totall, detas, detal):
            self.startVehicleState = VehicleState(0, 0, startVehicleState.vpl, startVehicleState.apl)
            self.endVehicleState = VehicleState(
                (endVehicleState.s - startVehicleState.s) * detas,
                (endVehicleState.l - startVehicleState.l) * detal,
                endVehicleState.vpl,
                endVehicleState.apl,
            )
            self.sorg = startVehicleState.s * detas
            self.lorg = startVehicleState.l * detal
            self.id = (str(startVehicleState.s + startVehicleState.l * totals) + 'id'
                       + str(endVehicleState.s + endVehicleState.l * totals))
            self.matP = np.zeros((6, 1))
            self.arcLength()

        def calQuintic(self):
            """Solve for ``matP`` so l(s) meets both boundary states.

            Conditions: l(0)=0, l'(0)=start.vpl, l''(0)=start.apl,
            l(S)=end.l, l'(S)=end.vpl, l''(S)=end.apl with S = end.s.
            A (near-)singular system, e.g. S == 0, leaves ``matP`` all zeros.
            """
            s = self.endVehicleState.s
            matS = np.array([[1, 0, 0, 0, 0, 0],
                             [0, 1, 0, 0, 0, 0],
                             [0, 0, 2, 0, 0, 0],
                             [1, s, s**2, s**3, s**4, s**5],
                             [0, 1, 2*s, 3*s**2, 4*s**3, 5*s**4],
                             [0, 0, 2, 6*s, 12*s**2, 20*s**3]], dtype=float)

            # Matrix invertibility check.  det(matS) == 4 * s**9, so test its
            # magnitude: the sign says nothing about singularity.
            if abs(np.linalg.det(matS)) < 0.001:
                self.matP = np.zeros((6, 1))
                return

            matL = np.array([0.0,
                             self.startVehicleState.vpl, self.startVehicleState.apl,
                             self.endVehicleState.l, self.endVehicleState.vpl,
                             self.endVehicleState.apl], dtype=float)
            # solve() is more accurate than multiplying by the explicit inverse;
            # plain ndarrays also avoid np.mat, which NumPy 2.0 removed.
            self.matP = np.linalg.solve(matS, matL).reshape(6, 1)

        def f(self, s):
            """Arc-length integrand sqrt(1 + l'(s)**2), returned as a Python float."""
            p = np.asarray(self.matP, dtype=float).ravel()
            dl = p[1] + 2*p[2]*s + 3*p[3]*s**2 + 4*p[4]*s**3 + 5*p[5]*s**4
            # quad needs a scalar; the old 1x1 matrix result was not one.
            return float(np.sqrt(1.0 + dl**2))

        def arcLength(self):
            """Fit the quintic and store its arc length over [0, end.s] in ``self.cost``."""
            self.calQuintic()
            self.cost, _err = quad(self.f, self.startVehicleState.s, self.endVehicleState.s)

        def show(self):
            """Plot the curve in road coordinates, sampled every 0.3 along s."""
            s = list()
            l = list()
            print('stop =', self.endVehicleState.s)
            p = np.asarray(self.matP, dtype=float).ravel()
            for i in np.arange(0, self.endVehicleState.s, 0.3):
                s.append(i + self.sorg)
                l.append(float(sum(p[k] * i**k for k in range(6))) + self.lorg)
            plt.plot(s, l, '-')
                
            
    # -*- coding: utf-8 -*-
    """
    Spyder Editor
    
    This is a temporary script file.
    """
    
    import numpy as  np
    import matplotlib.pyplot as plt
    import math
    import queue
    import time
    import copy
    import huchang
    
    # Number of s columns ahead of a node that are considered as successors.
    MaxsSearch  = 5

    detaWidth = 0.3    # lateral grid spacing
    totalWidth = 6.0   #road width
    detaLength = 3.0   # longitudinal grid spacing
    totalLength = 60.0 # road predict length
    vehicleHarfWidth = 0.8

    # s,l half width,half height /m
    # Column 2 pads the obstacle along l, column 3 along s (see createObstacle).
    obstacle = np.array([[40,4,0.3,0.3],
                         [16,1.5,0.6,0.3],#[30,2.5,0.3,0.3]
                         ])

    # Grid size in cells.  Plain int(): the np.int alias was removed in NumPy 1.24.
    numberOfL = int(totalWidth /detaWidth)
    numberOfS = int(totalLength /detaLength)
    
    class StrPosition:
        """Grid cell (s, l) plus the lateral derivatives handed to the quintic fit.

        ``vpl`` and ``apl`` default to 0.0; nodeCost forwards them to
        huchang.VehicleState as the segment's boundary derivatives.
        """

        def __init__(self,s,l,vpl =0.0,apl =0.0):
            self.s, self.l = s, l
            self.vpl, self.apl = vpl, apl
    
    
    class StrReachablePosition:
        """A* search node: ``reachNode`` reached from ``parentNode``.

        ``f``, ``g`` and ``h`` all start at 999; AStart/reachPoints overwrite
        them with real costs before the node is used.
        """

        def __init__(self,parentNode,reachNode):
            self.parentNode, self.reachNode = parentNode, reachNode
            self.f = self.g = self.h = 999
            
    
    def nodeCost(startPosition,endPosition,node2nodeCost):
        """Return the cached quintic arc-length cost between two grid nodes.

        The pair is ordered so the node with the smaller ``s`` comes first.  The
        cache key is ``"<beginKey>id<endKey>"`` with ``key = s + l * numberOfS``.
        A missing entry is computed once with ``huchang.trajectoryCost`` and
        stored in ``node2nodeCost``.

        Returns:
            (node2nodeCost, cost): the updated cache and the segment cost.
        """
        if startPosition.s > endPosition.s:
            # Always fit the curve forward along s.  Rebinding locals is enough;
            # neither position is mutated below, so no copies are needed.
            startPosition, endPosition = endPosition, startPosition

        begin = startPosition.s + startPosition.l*numberOfS
        end = endPosition.s + endPosition.l*numberOfS
        ikey = str(begin) +'id'+str(end)

        if ikey not in node2nodeCost:
            startState = huchang.VehicleState(startPosition.s,startPosition.l,startPosition.vpl,startPosition.apl)
            endState = huchang.VehicleState(endPosition.s,endPosition.l,endPosition.vpl,endPosition.apl)
            node2nodeCost[ikey] = huchang.trajectoryCost(startState,endState,numberOfS,numberOfL,detaLength,detaWidth)

        return node2nodeCost,node2nodeCost[ikey].cost
            
    def findInQueue(closeMap,node):
        """Return True if ``node``'s cell key (l * numberOfS + s) is in ``closeMap``.

        Uses a membership test: the stored value is a parent key, and a parent
        key of 0 (cell s=0, l=0) must still count as closed.
        """
        ikey = node.reachNode.l * numberOfS + node.reachNode.s
        return ikey in closeMap
    
    
    def refreshMap(node,openList):
        """Store ``node`` in ``openList`` under its cell key and return the dict."""
        cell = node.reachNode
        openList[cell.l * numberOfS + cell.s] = node
        return openList
          
    def updateNode(node,openList,closeMap):
        """Add ``node`` to the open list, or improve the entry for its cell.

        * Cell already open: replace the entry only if ``node`` has a lower ``g``.
        * Cell already closed (its key is in ``closeMap``): ignore ``node``.
        * Otherwise: insert it.

        A deep copy of ``node`` is stored, so the caller may keep mutating the
        object it passed in.  Copying only that node keeps the cost per call
        constant, instead of copying the whole open list every time.

        Returns:
            ``openList``, updated in place.
        """
        ikey = node.reachNode.l * numberOfS + node.reachNode.s
        if ikey in openList:
            if openList[ikey].g > node.g:
                openList[ikey] = copy.deepcopy(node)
        elif ikey not in closeMap:
            openList[ikey] = copy.deepcopy(node)
        return openList
    
    def delMap(node,openList):
        """Remove ``node``'s cell from ``openList`` if present; return the dict.

        Key 0 (cell s=0, l=0) is removed like any other.  The old guard skipped
        it, which would leave that node in the open list to be picked again
        after it was closed.
        """
        openList.pop(node.reachNode.l * numberOfS + node.reachNode.s, None)
        return openList
    
    
    def environmentShowGraph(environment,startIndex,endIndex):
        """Draw the occupancy grid in cell-index units and mark occupied cells.

        ``environment`` is indexed ``[l][s]``.  Each row in
        ``startIndex..endIndex-1`` gets a horizontal line, and each column gets a
        vertical line.  Cells in those rows with a value above 0.1 are drawn as
        red circles.
        """
        row = environment.shape[0]
        col = environment.shape[1]
        rows = range(startIndex, endIndex)
        occupied_s = list()
        occupied_l = list()
        for rowi in rows:
            if col > 0:
                plt.plot(np.array([0, col - 1]), np.array([rowi, rowi]), 'b-')
            for coli in range(col):
                if environment[rowi][coli] > 0.1:
                    occupied_s.append(coli)
                    occupied_l.append(rowi)
        if len(rows) > 0:
            # Vertical lines depend only on the column, so draw each one once
            # rather than once per row.
            for coli in range(col):
                plt.plot(np.array([coli, coli]), np.array([0, row - 1]), 'b-')
        # 'ro' = red circle markers.  Passing 'o' and 'r' as separate positional
        # arguments does not make the preceding series red.
        plt.plot(occupied_s, occupied_l, 'ro')
    
    
    def environmentShow(environment,startIndex,endIndex):
        """Draw the road grid scaled by detaWidth/detaLength and mark occupied cells.

        ``environment`` is indexed ``[l][s]``.  Each row in
        ``startIndex..endIndex-1`` gets a horizontal line spanning
        ``totalLength``, and each column gets a vertical line spanning
        ``totalWidth``.  Cells in those rows with a value above 0.1 are drawn as
        red circles.
        """
        col = environment.shape[1]
        rows = range(startIndex, endIndex)
        occupied_s = list()
        occupied_l = list()
        for rowi in rows:
            y = rowi * detaWidth
            if col > 0:
                plt.plot(np.array([0, totalLength]), np.array([y, y]), 'b-')
            for coli in range(col):
                if environment[rowi][coli] > 0.1:
                    occupied_s.append(coli * detaLength)
                    occupied_l.append(y)
        if len(rows) > 0:
            # Vertical lines depend only on the column, so draw each one once.
            for coli in range(col):
                x = coli * detaLength
                plt.plot(np.array([x, x]), np.array([0, totalWidth]), 'b-')
        # 'ro' = red circle markers.  Passing 'o' and 'r' as separate positional
        # arguments does not make the preceding series red.
        plt.plot(occupied_s, occupied_l, 'ro')
    
    
    
    def createObstacle(environment,obstacle,startIndex,endIndex):
        """Mark obstacle cells with 1.0 in ``environment``, in place.

        Args:
            environment: grid indexed ``[l][s]``.
            obstacle: rows of ``[s, l, pad_l, pad_s]``.  Column 2 is the
                half-extent along l and column 3 the half-extent along s.
            startIndex: extra cells added on both lateral sides of every
                obstacle.  The ``__main__`` caller passes
                ``int(vehicleHarfWidth / detaWidth)``.  The old code ignored this
                parameter and read the global ``startIndexOfL``, which exists
                only when run as a script.
            endIndex: unused; kept for interface compatibility.
        """
        erow = environment.shape[0]
        ecol = environment.shape[1]
        for center_s, center_l, half_l, half_s in obstacle:
            start_l = math.floor((center_l - half_l) / detaWidth) - startIndex
            end_l = math.ceil((center_l + half_l) / detaWidth) + startIndex
            start_s = math.floor((center_s - half_s) / detaLength)
            end_s = math.ceil((center_s + half_s) / detaLength)
            # Clip the inclusive ranges to the grid instead of testing every cell.
            for obsl_i in range(max(start_l, 0), min(end_l, erow - 1) + 1):
                for obss_i in range(max(start_s, 0), min(end_s, ecol - 1) + 1):
                    environment[obsl_i][obss_i] = 1.0
                        
                
    def createStopPosition(x,startIndexaOfL,endIndexOfL):
        """Build candidate goal cells at column ``x`` across the drivable lateral band.

        Returns:
            float array of shape ``(max(endIndexOfL - startIndexaOfL, 0), 2)``.
            Each row is ``[x, l]`` with l running from ``startIndexaOfL`` to
            ``endIndexOfL - 1``; that band is the one reachPoints searches.

        The old version sized the array from the global ``startIndexOfL`` rather
        than its parameter, and numbered l from 0.  That put the goals below the
        searchable band.
        """
        count = max(endIndexOfL - startIndexaOfL, 0)
        stopPosition = np.zeros((count, 2))
        stopPosition[:, 0] = x
        stopPosition[:, 1] = np.arange(startIndexaOfL, startIndexaOfL + count)
        return stopPosition
    
    def getHCost(startPosition,endPosition,node2nodeCost):
        """Heuristic cost from ``startPosition`` to the goal ``endPosition``.

        The cost is the quintic arc length from nodeCost (which also fills the
        cache), plus 0.1 times a tuning term.  The tuning term is the signed
        lateral offset scaled by detaWidth, minus 0.1 times the signed
        longitudinal offset scaled by detaLength.

        NOTE(review): the lateral term is signed, so the heuristic can drop below
        the arc length.  Confirm this is the intended tuning.

        Returns:
            (node2nodeCost, h)
        """
        node2nodeCost, arc = nodeCost(startPosition, endPosition, node2nodeCost)
        lateral = (startPosition.l - endPosition.l) * detaWidth
        longitudinal = (startPosition.s - endPosition.s) * detaLength * 0.1
        return node2nodeCost, arc + (lateral - longitudinal) * 0.1
    
    def getGCost(node,endPosition,node2nodeCost):
        """Accumulated cost of reaching ``endPosition`` through ``node``.

        g = node.g + arc length of the edge node.reachNode -> endPosition
        + 0.01 * an edge tuning term.  The tuning term has the same form as in
        getHCost and is measured from the parent cell ``node.reachNode``.  The
        old code read the module-level ``startPosition`` here instead, which is
        only defined when the file runs as a script.

        Returns:
            (node2nodeCost, g)
        """
        parent = node.reachNode
        node2nodeCost, cost = nodeCost(parent, endPosition, node2nodeCost)
        cost0 = (parent.l - endPosition.l)*detaWidth - (parent.s - endPosition.s)*detaLength *0.1
        return node2nodeCost, node.g + cost + 0.01*cost0
    
    def getFCost(node):
        """A* priority of ``node``: f = g + h."""
        g, h = node.g, node.h
        return g + h
    
    def checkcollision(environment,startPosition,endPosition):
        """Return True when the box spanned by the two cells is free.

        Every cell with s and l between the two positions (inclusive) is checked
        in ``environment[l][s]``.  Any value above 0.1 counts as an obstacle, and
        the function then returns False.
        """
        low_l, high_l = sorted((startPosition.l, endPosition.l))
        low_s, high_s = sorted((startPosition.s, endPosition.s))
        blocked = any(environment[cell_l][cell_s] > 0.1
                      for cell_s in range(low_s, high_s + 1)
                      for cell_l in range(low_l, high_l + 1))
        return not blocked
                
    
    
    def reachPoints(openList,environment,lastNode,stopPosition,closeMap, startIndexOfL,endIndexOfL,node2nodeCost):
        """Expand ``lastNode`` by pushing every collision-free successor into the open list.

        Candidate cells:
            s from ``parent.s + 1`` up to but excluding
              ``min(parent.s + 1 + MaxsSearch, numberOfS)``;
            l in ``[startIndexOfL, endIndexOfL)``.
        A candidate is kept when the box between it and the parent cell is free
        (checkcollision).  Its g, h and f are then computed, and updateNode
        merges it into the open list.

        Returns:
            (openList, node2nodeCost)
        """
        print("reachPoints-------------")
        parentNode = lastNode.reachNode
        nexts = parentNode.s + 1
        print("nexts,s of env",nexts , environment.shape[1],numberOfS,numberOfL)
        maxsearchOfs = min(nexts+MaxsSearch,numberOfS)
        for s in range(nexts,maxsearchOfs):
            for l in range(startIndexOfL,endIndexOfL):
                # Fresh objects per candidate.  Reusing one mutable node made
                # every open-list entry alias the same instance unless
                # updateNode happened to copy it.
                reachNode = StrPosition(s, l)
                if not checkcollision(environment,parentNode,reachNode):
                    continue
                node = StrReachablePosition(parentNode,reachNode)
                node2nodeCost, node.g = getGCost(lastNode,reachNode,node2nodeCost)
                node2nodeCost, node.h = getHCost(reachNode,stopPosition,node2nodeCost)
                node.f = getFCost(node)
                openList = updateNode(node,openList,closeMap)
        return openList,node2nodeCost
        
    def freshCloseList(node,openList,closeList,closeMap):
        """Move ``node`` from the open list into the closed set.

        The node is put on the ``closeList`` queue.  ``closeMap`` records
        ``cell key -> parent cell key`` (key = l * numberOfS + s), which getTrack
        later follows back to the start.

        Returns:
            (openList, closeList, closeMap)
        """
        print("freshCloseList-------------")
        child, parent = node.reachNode, node.parentNode
        closeList.put(node)
        closeMap[child.l * numberOfS + child.s] = parent.l * numberOfS + parent.s
        return delMap(node,openList),closeList,closeMap
         
     
    def getMinFFromOpenList(openList):
        """Return ``(node, True)`` for the open node with the smallest ``f``.

        Ties go to the entry that comes first in the dict.  An empty open list
        yields a placeholder node at (0, 0) together with ``False``.
        """
        print("getMinFFromOpenList-------------")
        if len(openList) == 0:
            return StrReachablePosition(StrPosition(0,0), StrPosition(0,0)), False
        # min() needs no sentinel.  The old 999999 cap returned the (0, 0)
        # placeholder flagged as found whenever every f was at least that large.
        noden = min(openList.values(), key=lambda candidate: candidate.f)
        print("min",noden.f,noden.reachNode.s,noden.reachNode.l)
        return noden,True
     
    
    
    def AStart(environment,startPosition,stopPosition,startIndexOfL,endIndexOfL):
        """A* search over the lattice from ``startPosition`` to ``stopPosition``.

        Returns:
            (closeList, closeMap, node2NodeCost, found).  ``closeMap`` maps each
            closed cell key to its parent key.  ``found`` is True once the stop
            cell has been closed.
        """
        print("AStart-------------")
        node2NodeCost = dict()
        # Seed the open list with the start node (g = 0).  Its h comes from
        # getGCost over the zero-length start -> start segment.
        node = StrReachablePosition(startPosition,startPosition)
        node.g = 0
        node2NodeCost, node.h = getGCost(node,startPosition,node2NodeCost)
        node.f = getFCost(node)
        openList = refreshMap(node,dict())
        print("dict len",len(openList))
        print("reachPosition",node.reachNode.s,node.reachNode.l)
        closeList = queue.Queue()
        closeMap = dict()

        def reachedGoal(candidate):
            return (candidate.reachNode.s == stopPosition.s
                    and candidate.reachNode.l == stopPosition.l)

        # Parenthesised on purpose.  The original split ``A or B and C`` over
        # two lines with no continuation (a SyntaxError).  By precedence it
        # would also have applied the open-list check to the l comparison only.
        while not reachedGoal(node) and len(openList) != 0:
            openList,node2NodeCost = reachPoints(openList,environment,node,stopPosition,closeMap,startIndexOfL,endIndexOfL,node2NodeCost)
            node,flag = getMinFFromOpenList(openList)
            if not flag:
                print("openlist null")
                return closeList,closeMap,node2NodeCost,False
            openList,closeList,closeMap = freshCloseList(node,openList,closeList,closeMap)
            if reachedGoal(node):
                return closeList,closeMap,node2NodeCost,True
        print("openlist null",len(openList))
        return closeList, closeMap,node2NodeCost,False
      
        
    def getTrack(closeMap,startPosition,stopPosition):
        """Follow ``closeMap`` back-pointers from the stop cell to the start cell.

        Returns:
            (slist, llist, track).  ``slist``/``llist`` hold the grid s and l
            indices ordered stop -> start.  ``track`` is a deque of cell keys
            (key = l * numberOfS + s) in the same order, so ``track.pop()``
            yields the start key first.

        Raises:
            KeyError: if the stop cell, or a cell on the chain, is not in
                ``closeMap``.
        """
        from collections import deque

        startkey = startPosition.l * numberOfS + startPosition.s
        endkey = stopPosition.l * numberOfS + stopPosition.s

        slist = [stopPosition.s]
        llist = [stopPosition.l]
        # int() replaces np.int, an alias NumPy 1.24 removed.
        track = deque([int(endkey)])

        ikey = closeMap[endkey]
        while ikey != startkey:
            s = int(ikey % numberOfS)
            l = int((ikey - s) / numberOfS)
            track.append(int(ikey))
            slist.append(s)
            llist.append(l)
            ikey = closeMap[ikey]

        slist.append(startPosition.s)
        llist.append(startPosition.l)
        track.append(int(startkey))
        return slist,llist,track
    
    def showTrajackPointGraph(slist,llist):
        """Scatter the track points in grid-index units."""
        xs, ys = slist, llist
        plt.plot(xs, ys, 'o')
        
    
    def showTrajackPoint(slist,llist):
        """Scatter the track points scaled by detaLength/detaWidth, printing each value.

        All scaled s values are printed first, then all scaled l values.
        """
        scaled_s = [index * detaLength for index in slist]
        for value in scaled_s:
            print('s', value)
        scaled_l = [index * detaWidth for index in llist]
        for value in scaled_l:
            print('l', value)
        plt.plot(scaled_s, scaled_l, 'o')
    
    def showTrack(track,node2nodeCost):
        """Draw every cached segment along ``track``; ``track`` is emptied.

        Keys are popped from the right end of ``track``.  Each consecutive pair
        (a, b) selects ``node2nodeCost["<a>id<b>"]``, whose ``show()`` is called.
        """
        previous = track.pop()
        print('node2nodeCost',len(node2nodeCost))
        while track:
            current = track.pop()
            node2nodeCost[str(previous) + 'id' + str(current)].show()
            previous = current
        
        
    def showCloseList(closeList):
        """Drain ``closeList`` and plot each closed cell, scaled by detaLength/detaWidth, as a red star."""
        while True:
            if closeList.empty():
                break
            closed = closeList.get()
            plt.plot(closed.reachNode.s*detaLength,closed.reachNode.l*detaWidth,'r*')
    
    def showCloseListGraph(closeList):
        """Drain ``closeList`` and plot each closed cell in grid-index units as a red star."""
        while True:
            if closeList.empty():
                break
            closed = closeList.get()
            plt.plot(closed.reachNode.s,closed.reachNode.l,'r*')
        
    
    if __name__ == '__main__':

        # vehicle to road region  of l
        # Lateral cells reserved for the vehicle half-width on each side.
        # int() replaces np.int, an alias NumPy 1.24 removed.
        startIndexOfL = int(vehicleHarfWidth / detaWidth)
        endIndexOfL = numberOfL - int(vehicleHarfWidth / detaWidth)

        environment = np.zeros((numberOfL,numberOfS))
        createObstacle(environment,obstacle,startIndexOfL,endIndexOfL)
        environmentShow(environment,startIndexOfL,endIndexOfL)

        #simple end a*--------------------------------------------

        startPosition = StrPosition(0,12,0.1,0)
        stopPosition = StrPosition(18,8)
        closeList,closeMap,node2NodeCost,Flag = AStart(environment,startPosition,stopPosition,startIndexOfL,endIndexOfL)

        if Flag:
            slist,llist,track = getTrack(closeMap,startPosition,stopPosition)

            for t in track:
                print('result track',t)

            showTrajackPoint(slist,llist)
            showTrack(track,node2NodeCost)
        else:
            print("no route-----------------")
            showCloseList(closeList)

        #multi end a*------------------------------------
        '''
        stopPosition = createStopPosition(numberOfS-1,startIndexOfL,endIndexOfL)  
        startPosition = StrPosition(0,9)
    
        shapeOfEnd = stopPosition.shape
        for i  in  range(0,shapeOfEnd[0]):
            stop = StrPosition(stopPosition[i][0],stopPosition[i][1])
            closeList,closeMap,node2NodeCost,Flag = AStart(environment,startPosition,stop,startIndexOfL,endIndexOfL)
        
            if Flag == True:
                slist,llist,track = getTrack(closeMap,startPosition,stop)   
                for t in track:
                    print('result track',t)
            
                showTrajackPoint(slist,llist)
                showTrack(track,node2NodeCost)
            else:
                print("no route-----------------")
                showCloseList(closeList)
            #showCloseList(closeList)   
        '''
    
        
    
    
    
    
              
        
            
        

  • 相关阅读:
    219. Contains Duplicate II
    189. Rotate Array
    169. Majority Element
    122. Best Time to Buy and Sell Stock II
    121. Best Time to Buy and Sell Stock
    119. Pascal's Triangle II
    118. Pascal's Triangle
    88. Merge Sorted Array
    53. Maximum Subarray
    CodeForces 359D Pair of Numbers (暴力)
  • 原文地址:https://www.cnblogs.com/kabe/p/12112373.html
Copyright © 2011-2022 走看看