zoukankan      html  css  js  c++  java
  • 遗传算法

      遗传算法(genetic algorithm)是进化算法的一种。来源于达尔文的生物进化学——“物竞天择,适者生存”。一个种群在繁衍的过程中,通过交叉繁衍和个体变异产生了新的一代。新一代中有的个体能适应当前环境很好的生存从而继续繁衍,而有的个体因无法适应环境而被环境淘汰。

    如何用计算机表示?

      一个种群中的每个个体,都可以用DNA来表示,DNA的计算机表示可以用固定长度的二进制码表示,如010101。标准的交叉繁衍过程可以分别使用两个个体一半的DNA序列拼接而成。如:

    000111,010101 ——> 000101

      个体变异过程可以通过对生成的子个体的某个位置的DNA进行变化,如:

    000101 ——> 000111

      新生成的个体对环境的适应性,根据实际任务的目标函数而定,以DNA为自变量得到因变量目标函数的值。

    能解决什么问题?

      遗传算法通常认为是一种搜索算法,在以DNA为变量定义的解空间中,根据目标函数逼近近似最优解的过程。使用遗传算法可以解决旅行商(TSP)问题、最小生成树问题等。

    主要特点

      其主要特点是通过生物进化规律定义了一种搜索策略,不存在求导和函数连续性的限定;具有内在的隐并行性和更好的全局寻优能力;采用概率化的寻优方法,不需要确定的规则就能自动获取和指导优化的搜索空间,自适应地调整搜索方向。

    实践

      使用遗传算法解决TSP问题,这里的交叉算子和变异算子与标准遗传算法不同,因为得保证路径不能重复,这样做避免了无效个体的产生,且以较高概率搜索解空间中各个可行解。

    """
    Visualize Genetic Algorithm to find the shortest path for travel sales problem.
    Visit my tutorial website for more: https://morvanzhou.github.io/tutorials/
    """
    import matplotlib.pyplot as plt
    import numpy as np
    
    N_CITIES = 20  # DNA size
    CROSS_RATE = 0.1
    MUTATE_RATE = 0.02
    POP_SIZE = 500
    N_GENERATIONS = 500
    
    
    class GA(object):
        def __init__(self, DNA_size, cross_rate, mutation_rate, pop_size, ):
            self.DNA_size = DNA_size
            self.cross_rate = cross_rate
            self.mutate_rate = mutation_rate
            self.pop_size = pop_size
    
            self.pop = np.vstack([np.random.permutation(DNA_size) for _ in range(pop_size)])
    
        def translateDNA(self, DNA, city_position):     # get cities' coord in order
            line_x = np.empty_like(DNA, dtype=np.float64)
            line_y = np.empty_like(DNA, dtype=np.float64)
            for i, d in enumerate(DNA):
                city_coord = city_position[d]
                line_x[i, :] = city_coord[:, 0]
                line_y[i, :] = city_coord[:, 1]
            return line_x, line_y
    
        def get_fitness(self, line_x, line_y):
            total_distance = np.empty((line_x.shape[0],), dtype=np.float64)
            for i, (xs, ys) in enumerate(zip(line_x, line_y)):
                total_distance[i] = np.sum(np.sqrt(np.square(np.diff(xs)) + np.square(np.diff(ys))))
            fitness = np.exp(self.DNA_size * 2 / total_distance)
            return fitness, total_distance
    
        def select(self, fitness):
            idx = np.random.choice(np.arange(self.pop_size), size=self.pop_size, replace=True, p=fitness / fitness.sum())
            return self.pop[idx]
    
        def crossover(self, parent, pop):
            if np.random.rand() < self.cross_rate:
                i_ = np.random.randint(0, self.pop_size, size=1)                        # select another individual from pop
                cross_points = np.random.randint(0, 2, self.DNA_size).astype(np.bool)   # choose crossover points
                keep_city = parent[~cross_points]                                       # find the city number
                swap_city = pop[i_, np.isin(pop[i_].ravel(), keep_city, invert=True)]
                parent[:] = np.concatenate((keep_city, swap_city))
            return parent
    
        def mutate(self, child):
            for point in range(self.DNA_size):
                if np.random.rand() < self.mutate_rate:
                    swap_point = np.random.randint(0, self.DNA_size)
                    swapA, swapB = child[point], child[swap_point]
                    child[point], child[swap_point] = swapB, swapA
            return child
    
        def evolve(self, fitness):
            pop = self.select(fitness)
            pop_copy = pop.copy()
            for parent in pop:  # for every parent
                child = self.crossover(parent, pop_copy)
                child = self.mutate(child)
                parent[:] = child
            self.pop = pop
    
    
    class TravelSalesPerson(object):
        def __init__(self, n_cities):
            self.city_position = np.random.rand(n_cities, 2)
            plt.ion()
    
        def plotting(self, lx, ly, total_d):
            plt.cla()
            plt.scatter(self.city_position[:, 0].T, self.city_position[:, 1].T, s=100, c='k')
            plt.plot(lx.T, ly.T, 'r-')
            plt.text(-0.05, -0.05, "Total distance=%.2f" % total_d, fontdict={'size': 20, 'color': 'red'})
            plt.xlim((-0.1, 1.1))
            plt.ylim((-0.1, 1.1))
            plt.pause(0.01)
    
    
    ga = GA(DNA_size=N_CITIES, cross_rate=CROSS_RATE, mutation_rate=MUTATE_RATE, pop_size=POP_SIZE)
    
    env = TravelSalesPerson(N_CITIES)
    for generation in range(N_GENERATIONS):
        lx, ly = ga.translateDNA(ga.pop, env.city_position)
        fitness, total_distance = ga.get_fitness(lx, ly)
        ga.evolve(fitness)
        best_idx = np.argmax(fitness)
        print('Gen:', generation, '| best fit: %.2f' % fitness[best_idx],)
    
        env.plotting(lx[best_idx], ly[best_idx], total_distance[best_idx])
    
    plt.ioff()
    plt.show()

    参考文献

    莫烦python

  • 相关阅读:
    MySql msi安装
    C# TextBox文本内容选中
    SQL 删除时间最靠前的几条数据
    Layui表格工具栏绑定事件失效问题
    Layui我提交表单时,table.reload(),表格会请求2次,是为什么?按下面的做
    table 中数据行循环滚动
    html 3D反转效果
    网页电子表数字样式
    power tool 强制撤销
    GHOST -ntexact 正常还原
  • 原文地址:https://www.cnblogs.com/majiale/p/9678262.html
Copyright © 2011-2022 走看看