  • Genetic Algorithm (Part 3): The Traveling Salesman Problem (TSP)

    The crux of a genetic algorithm (GA) is working out what its DNA is and how to evaluate each individual (its fitness).

    Fitness and DNA

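    This time the DNA is encoded differently again. We give every city an ID, so the order in which the cities are visited is simply an ordering of those IDs. For example, if the salesman has to visit 3 cities, we have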

    • 0-1-2
    • 0-2-1
    • 1-0-2
    • 1-2-0
    • 2-0-1
    • 2-1-0

    these 6 permutations. Each permutation can be treated as one DNA sequence, and numpy makes generating such a sequence easy.

    >>> import numpy as np
    >>> np.random.permutation(3)
    # array([1, 2, 0])  (random, differs on each call)
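    A minimal sketch of this encoding, assuming only NumPy and the standard library: it lists every ordering of 3 cities with itertools.permutations (3! = 6, in the same order as the list above) and builds a random population the way the full code's GA.__init__ does, one np.random.permutation row per individual. The population size here is illustrative.

```python
import itertools

import numpy as np

n_cities = 3
# Every possible DNA for 3 cities: each route is one permutation of the city IDs.
all_routes = list(itertools.permutations(range(n_cities)))
print(len(all_routes))  # 6

# A random population: one permutation per row (pop_size = 5 is arbitrary).
pop_size = 5
pop = np.vstack([np.random.permutation(n_cities) for _ in range(pop_size)])
print(pop.shape)  # (5, 3)

# Every row is a valid route: each city ID appears exactly once.
assert all(np.array_equal(np.sort(row), np.arange(n_cities)) for row in pop)
```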

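    When computing fitness, we connect the cities in the order given by the DNA, compute the total length of that path, and adopt the rule that a shorter total path is better. fitness0 below is the straightforward way to compute that. Because we want shorter paths to be favored much more strongly, I use the exponential form fitness1 instead (the full code below also multiplies the exponent by DNA_size * 2 to strengthen this effect).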

    fitness0 = 1/total_distance
    fitness1 = np.exp(1/total_distance)
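    A small sketch comparing the two fitness forms on a made-up example (4 cities on the corners of a unit square). route_length and exp_ratio are hypothetical helpers; route_length measures the open path with no edge back to the start, as get_fitness in the full code does. It computes how much fitter the shorter route looks under each form.

```python
import numpy as np


def route_length(order, coords):
    """Length of the open path visiting coords in the given order.

    Like get_fitness in the full code, there is no edge back to the start.
    """
    pts = coords[np.asarray(order)]
    steps = np.diff(pts, axis=0)                  # vectors between consecutive cities
    return float(np.sum(np.sqrt(np.sum(steps ** 2, axis=1))))


# Hypothetical example: 4 cities on the corners of a unit square.
coords = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
d_short = route_length([0, 1, 2, 3], coords)      # 3 unit edges
d_long = route_length([0, 2, 1, 3], coords)       # 1 + 2 * sqrt(2)

# How much fitter the short route looks under fitness0 = 1 / d.
ratio_inverse = (1 / d_short) / (1 / d_long)


def exp_ratio(scale):
    """Same ratio under np.exp(scale / d); scale=1 is fitness1 as written above."""
    return np.exp(scale / d_short) / np.exp(scale / d_long)


print(ratio_inverse, exp_ratio(1), exp_ratio(2 * len(coords)))
```

    With distances above 1, exp(1/d) on its own is actually flatter than 1/d here (exp_ratio(1) < ratio_inverse). Scaling the exponent, as the full code does with DNA_size * 2, is what makes the exponential form favor short routes more sharply.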

    Crossover and mutation

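    Note that crossover and mutation work a little differently here, because the cities along a route cannot be changed arbitrarily. For example, an ordinary crossover might produce this result: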

    p1=[0,1,2,3] (father)

    p2=[3,2,1,0] (mother)

    cp=[m,b,m,b] (crossover points; m: from the mother, b: from the father)

    c1=[3,1,1,3] (child)
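    The ordinary crossover just described can be reproduced in a few lines. This sketch assumes the crossover points are a boolean mask (True = take the father's city), the same representation the code later in this post uses; np.where copies the father's city where the mask is True and the mother's city elsewhere.

```python
import numpy as np

p1 = np.array([0, 1, 2, 3])                         # father
p2 = np.array([3, 2, 1, 0])                         # mother
from_father = np.array([False, True, False, True])  # cp = [m, b, m, b]

# Ordinary position-wise crossover: each slot comes from one parent.
c1 = np.where(from_father, p1, p2)
print(c1)  # [3 1 1 3]

# A valid route must contain every city exactly once; this one does not.
is_valid = np.array_equal(np.sort(c1), np.arange(len(p1)))
print(is_valid)  # False
```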

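    Such a c1 visits city 3 twice and city 1 twice, and never visits cities 2 or 0, which is clearly invalid. So crossover and mutation both have to work differently. Here is one workable approach, using the same example.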

    p1=[0,1,2,3] (father)

    cp=[_,b,_,b] (choose the positions taken from the father)

    c1=[1,3,_,_] (place the father's cities at the front of the child first)

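    Besides cities 1 and 3 from the father, cities 0 and 2 remain, and they are placed in the order in which they appear in the mother's DNA. In p2=[3,2,1,0], city 2 comes before city 0, so that is the order in which we fill the rest of the child's DNA.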

    c1=[1,3,2,0]

    This avoids the problem that ordinary crossover causes, namely visiting the same city more than once. In Python it is short:

    if np.random.rand() < self.cross_rate:
        i_ = np.random.randint(0, self.pop_size, size=1)                        # select another individual (the mother) from pop
        cross_points = np.random.randint(0, 2, self.DNA_size).astype(bool)      # choose crossover points
        keep_city = parent[cross_points]                                        # cities kept from the father
        swap_city = pop[i_, np.isin(pop[i_].ravel(), keep_city, invert=True)]   # the mother's remaining cities, in her order
        parent[:] = np.concatenate((keep_city, swap_city))
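    Because the snippet above depends on self and the population, here is a standalone sketch of the same fill-from-the-other-parent idea with the mask fixed to the worked example, so the result can be checked against c1=[1,3,2,0]. The function name fill_crossover is hypothetical. np.isin with invert=True keeps the mother's cities that the father did not contribute, in the mother's order.

```python
import numpy as np


def fill_crossover(father, mother, keep_mask):
    """Child = father's cities where keep_mask is True, followed by the
    remaining cities in the order they appear in mother (hypothetical helper)."""
    keep_city = father[keep_mask]                                 # cities kept from the father
    swap_city = mother[np.isin(mother, keep_city, invert=True)]  # the rest, in the mother's order
    return np.concatenate((keep_city, swap_city))


p1 = np.array([0, 1, 2, 3])                  # father
p2 = np.array([3, 2, 1, 0])                  # mother
cp = np.array([False, True, False, True])    # cp = [_, b, _, b]
child = fill_crossover(p1, p2, cp)
print(child)  # [1 3 2 0]
```

    Whatever mask is drawn, every city appears exactly once in the child, which is the property the ordinary crossover lacked.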

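    For mutation, we pick a second position at random and swap the cities at the two positions. (If the random position happens to be the same one, the swap changes nothing.)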

    for point in range(self.DNA_size):
        if np.random.rand() < self.mutate_rate:
            swap_point = np.random.randint(0, self.DNA_size)
            swapA, swapB = child[point], child[swap_point]
            child[point], child[swap_point] = swapB, swapA
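    A standalone sketch of the same swap mutation, assuming a NumPy Generator (np.random.default_rng) in place of the global np.random calls so the run is reproducible; swap_mutate and the seed are illustrative. Since every change is a swap of two positions, the child always remains a valid permutation.

```python
import numpy as np


def swap_mutate(child, mutate_rate, rng):
    """For each position, with probability mutate_rate swap it with a
    randomly chosen position (possibly itself, which changes nothing)."""
    child = child.copy()                       # leave the caller's array untouched
    for point in range(len(child)):
        if rng.random() < mutate_rate:
            swap_point = rng.integers(0, len(child))
            child[point], child[swap_point] = child[swap_point], child[point]
    return child


rng = np.random.default_rng(42)
route = np.arange(10)
mutated = swap_mutate(route, mutate_rate=0.3, rng=rng)
print(mutated)
```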

    Full code:

    """
    Visualize Genetic Algorithm to find the shortest path for travel sales problem.
    Visit my tutorial website for more: https://morvanzhou.github.io/tutorials/
    """
    import matplotlib.pyplot as plt
    import numpy as np
    
    N_CITIES = 20  # DNA size
    CROSS_RATE = 0.1
    MUTATE_RATE = 0.02
    POP_SIZE = 500
    N_GENERATIONS = 500
    
    
    class GA(object):
        def __init__(self, DNA_size, cross_rate, mutation_rate, pop_size, ):
            self.DNA_size = DNA_size
            self.cross_rate = cross_rate
            self.mutate_rate = mutation_rate
            self.pop_size = pop_size
    
            self.pop = np.vstack([np.random.permutation(DNA_size) for _ in range(pop_size)])
    
        def translateDNA(self, DNA, city_position):     # get cities' coord in order
            line_x = np.empty_like(DNA, dtype=np.float64)
            line_y = np.empty_like(DNA, dtype=np.float64)
            for i, d in enumerate(DNA):
                city_coord = city_position[d]
                line_x[i, :] = city_coord[:, 0]
                line_y[i, :] = city_coord[:, 1]
            return line_x, line_y
    
        def get_fitness(self, line_x, line_y):
            total_distance = np.empty((line_x.shape[0],), dtype=np.float64)
            for i, (xs, ys) in enumerate(zip(line_x, line_y)):
                total_distance[i] = np.sum(np.sqrt(np.square(np.diff(xs)) + np.square(np.diff(ys))))
            fitness = np.exp(self.DNA_size * 2 / total_distance)
            return fitness, total_distance
    
        def select(self, fitness):
            idx = np.random.choice(np.arange(self.pop_size), size=self.pop_size, replace=True, p=fitness / fitness.sum())
            return self.pop[idx]
    
        def crossover(self, parent, pop):
            if np.random.rand() < self.cross_rate:
                i_ = np.random.randint(0, self.pop_size, size=1)                        # select another individual from pop
                cross_points = np.random.randint(0, 2, self.DNA_size).astype(bool)      # choose crossover points
                keep_city = parent[~cross_points]                                       # find the city number
                swap_city = pop[i_, np.isin(pop[i_].ravel(), keep_city, invert=True)]
                parent[:] = np.concatenate((keep_city, swap_city))
            return parent
    
        def mutate(self, child):
            for point in range(self.DNA_size):
                if np.random.rand() < self.mutate_rate:
                    swap_point = np.random.randint(0, self.DNA_size)
                    swapA, swapB = child[point], child[swap_point]
                    child[point], child[swap_point] = swapB, swapA
            return child
    
        def evolve(self, fitness):
            pop = self.select(fitness)
            pop_copy = pop.copy()
            for parent in pop:  # for every parent
                child = self.crossover(parent, pop_copy)
                child = self.mutate(child)
                parent[:] = child
            self.pop = pop
    
    
    class TravelSalesPerson(object):
        def __init__(self, n_cities):
            self.city_position = np.random.rand(n_cities, 2)
            plt.ion()
    
        def plotting(self, lx, ly, total_d):
            plt.cla()
            plt.scatter(self.city_position[:, 0].T, self.city_position[:, 1].T, s=100, c='k')
            plt.plot(lx.T, ly.T, 'r-')
            plt.text(-0.05, -0.05, "Total distance=%.2f" % total_d, fontdict={'size': 20, 'color': 'red'})
            plt.xlim((-0.1, 1.1))
            plt.ylim((-0.1, 1.1))
            plt.pause(0.01)
    
    
    ga = GA(DNA_size=N_CITIES, cross_rate=CROSS_RATE, mutation_rate=MUTATE_RATE, pop_size=POP_SIZE)
    
    env = TravelSalesPerson(N_CITIES)
    for generation in range(N_GENERATIONS):
        lx, ly = ga.translateDNA(ga.pop, env.city_position)
        fitness, total_distance = ga.get_fitness(lx, ly)
        ga.evolve(fitness)
        best_idx = np.argmax(fitness)
        print('Gen:', generation, '| best fit: %.2f' % fitness[best_idx],)
    
        env.plotting(lx[best_idx], ly[best_idx], total_distance[best_idx])
    
    plt.ioff()
    plt.show()
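    The select method in the full code is fitness-proportional (roulette-wheel) selection: it draws pop_size indices with replacement, each with probability fitness / fitness.sum(). A small check with made-up fitness values that individuals are drawn roughly in proportion to their fitness; it uses a Generator's choice in place of np.random.choice, and the seed and numbers are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
fitness = np.array([1.0, 2.0, 7.0])   # made-up fitness for 3 individuals
p = fitness / fitness.sum()           # selection probabilities: 0.1, 0.2, 0.7

# Same call pattern as GA.select, but with a Generator and many draws.
idx = rng.choice(np.arange(len(fitness)), size=100_000, replace=True, p=p)
freq = np.bincount(idx, minlength=len(fitness)) / idx.size
print(freq)  # close to [0.1 0.2 0.7]
```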

    Reference: 莫烦PYTHON (Morvan Python), Travel Sales Problem tutorial

  • Original post: https://www.cnblogs.com/lfri/p/12240751.html