zoukankan      html  css  js  c++  java
  • python遗传算法实现数据拟合

    python据说功能强大,触角伸到各个领域,网上搜了一下其科学计算和工程计算能力也相当强,具备各种第三方包,除了性能软肋外,其他无可指摘,甚至可以同matlab等专业工具一较高下。

    从网上找了一个使用遗传算法实现数据拟合的例子学习了一下,确实Python相当贴合自然语言,终于编程语言也能说人话了,代码整体简洁、优雅。。

    代码功能:给出一个隐藏函数 例如 z=x^2+y^2,生成200个数据,利用这200个数据,使用遗传算法猜测这些数据是什么公式生成的。 (说的太直白,一点都不高大上)

    代码如下:

      1 # coding=utf-8
      2 from random import random, randint, choice,uniform
      3 from copy import deepcopy
      4 import numpy as np
      5 import matplotlib.pyplot as plt
      6 
      7 from random import random, randint, choice
      8 from copy import deepcopy
      9 import numpy as np
     10 
     11 # 运算类
     12 class fwrapper:
     13     def __init__(self, function, childcount, name):
     14         self.function = function
     15         self.childcount = childcount
     16         self.name = name
     17 
     18 # 节点类
     19 class node:
     20     def __init__(self, fw, children):
     21         self.function = fw.function
     22         self.name = fw.name
     23         self.children = children
     24 #将inp指定的运算符作用到子节点上
     25     def evaluate(self, inp):
     26         # 循环调用子节点的子节点的子节点....的evaluate方法
     27         results = [n.evaluate(inp) for n in self.children]
     28         # 返回运算结果
     29         return self.function(results)
     30 #打印本节点及所属节点的操作运算符
     31     def display(self, indent=0):
     32         print(' ' * indent) + self.name
     33         for c in self.children:
     34             c.display(indent + 1)
     35 
     36 #参数节点类,x+y 其中x,y都是参数节点
     37 class paramnode:
     38     def __init__(self, idx):
     39         self.idx = idx
     40 # evaluate方法返回paramnode节点值本身
     41     def evaluate(self, inp):
     42         return inp[self.idx]
     43 
     44     def display(self, indent=0):
     45         print '%sp%d' % (' ' * indent, self.idx)
     46 
     47 # 常数节点
     48 class constnode:
     49     def __init__(self, v):
     50         self.v = v
     51 
     52     def evaluate(self, inp):
     53         return self.v
     54 
     55     def display(self, indent=0):
     56         print '%s%d' % (' ' * indent, self.v)
     57 
     58 # 操作运算符类
     59 class opera:
     60     # 使用前面定义的fwrapper类生产常用的加减乘除运算,第一个参数是本运算执行方式,第二个参数是本运算接受的参数个数,第三个参数是本运算名称
     61     addw = fwrapper(lambda l: l[0] + l[1], 2, 'add')
     62     subw = fwrapper(lambda l: l[0] - l[1], 2, 'subtract')
     63     mulw = fwrapper(lambda l: l[0] * l[1], 2, 'multiply')
     64 
     65     def iffunc(l):
     66         if l[0] > 0:
     67             return l[1]
     68         else:
     69             return l[2]
     70     #定义if运算
     71     ifw = fwrapper(iffunc, 3, 'if')
     72 
     73     def isgreater(l):
     74         if l[0] > l[1]:
     75             return 1
     76         else:
     77             return 0
     78     #定义greater运算
     79     gtw = fwrapper(isgreater, 2, 'isgreater')
     80     #构建运算符集合
     81     flist = [addw, mulw, ifw, gtw, subw]
     82 
     83     #使用node,paramnode,fwrapper构建一个example
     84     def exampletree(self):
     85         return node(self.ifw, [node(self.gtw, [paramnode(0), constnode(3)]), node(self.addw, [paramnode(1), constnode(5)]),
     86                           node(self.subw, [paramnode(1), constnode(2)]), ])
     87 
     88 
     89     # 构建一颗随机运算数,pc为参数(分叉)个数,maxdepth为树的深度,fpr为运算符个数在运算符加节点总数中所占比例,ppr为参数个数在参数加常数个数总数中所占的比例
     90     def makerandomtree(self,pc, maxdepth=4, fpr=0.5, ppr=0.6):
     91         if random() < fpr and maxdepth > 0:
     92             f = choice(self.flist)
     93             # 递归调用makerandomtree实现子节点的创建
     94             children = [self.makerandomtree(pc, maxdepth - 1, fpr, ppr) for i in range(f.childcount)]
     95             return node(f, children)
     96         elif random() < ppr:
     97             return paramnode(randint(0, pc - 1))
     98         else:
     99             return constnode(randint(0, 10))
    100 
    101 
    102     #变异,变异概率probchange=0.1
    103     def mutate(self,t, pc, probchange=0.1):
    104         # 变异后返回一颗随机树
    105         if random() < probchange:
    106             return self.makerandomtree(pc)
    107         else:
    108             result = deepcopy(t)
    109             # 递归调用,给其子节点变异的机会
    110             if isinstance(t, node):
    111                 result.children = [self.mutate(c, pc, probchange) for c in t.children]
    112             return result
    113 
    114     #交叉
    115     def crossover(self,t1, t2, probswap=0.7, top=1):
    116         # 如果符合交叉概率,就将t2的值返回,实现交叉;
    117         if random() < probswap and not top:
    118             return deepcopy(t2)
    119         else:
    120             #如果本层节点未实现交配,递归询问子节点是否符合交配条件
    121             #首先使用deepcopy保存本节点
    122             result = deepcopy(t1)
    123             if isinstance(t1, node) and isinstance(t2, node):
    124                 #依次递归询问t1下的各子孙节点交配情况,交配对象为t2的各子孙;t1,t2家族同辈交配
    125                 result.children = [self.crossover(c, choice(t2.children), probswap, 0) for c in t1.children]
    126             return result
    127 
    128     # random2.display()
    129     # muttree=mutate(random2,2)
    130     # muttree.display()
    131     # cross=crossover(random1,random2)
    132     # cross.display()
    133 
    134     #设置一个隐藏函数,也就是原始函数
    135     def hiddenfunction(self,x, y):
    136         return x ** 2+ y**2
    137 
    138     #依照隐藏函数,生成坐标数据
    139     def buildhiddenset(self):
    140             rows = []
    141             for i in range(200):
    142                 x = randint(0, 10)
    143                 x=uniform(-1,1)
    144                 y = randint(0, 10)
    145                 y=uniform(-1,1)
    146                 rows.append([x, y, self.hiddenfunction(x, y)])
    147             print("rows:",rows)
    148             return rows
    149 
    150     #拟合成绩函数,判定拟合函数(实际是一颗图灵树)贴近原始函数的程度
    151     def scorefunction(self,tree, s):
    152         dif = 0
    153         # print("tree:",tree)
    154         # print("s:",s)
    155         for data in s:
    156             # print("data[0]:",data[0])
    157             # print("data[1]:",data[1])
    158             v = tree.evaluate([data[0],data[1]])
    159             #累加每个数据的绝对值偏差
    160             dif += abs(v - data[2])
    161         return dif
    162 
    163     #返回一个成绩判定函数rankfunction的句柄
    164     def getrankfunction(self,dataset):
    165         #此函数调用拟合成绩函数,并对成绩排序,返回各个种群的成绩
    166         def rankfunction(population):
    167             scores = [(self.scorefunction(t, dataset), t) for t in population]
    168             scores.sort()
    169             return scores
    170         return rankfunction
    171 
    172     # hiddenset=buildhiddenset()
    173     # scorefunction(random2,hiddenset)
    174     # scorefunction(random1,hiddenset)
    175 
    176     def evolve(self,pc, popsize, rankfunction, maxgen=500, mutationrate=0.1, breedingrate=0.4, pexp=0.7, pnew=0.05):
    177         #轮盘算法
    178         def selectindex():
    179             return int(np.log(random()) / np.log(pexp))
    180         #使用随机树生成第一代各种群
    181         population = [self.makerandomtree(pc) for i in range(popsize)]
    182         #计算每一代各种群的成绩,
    183         for i in range(maxgen):
    184             scores = rankfunction(population)
    185             #打印历代最好成绩
    186             print('the best score in genneration ',i,':',scores[0][0])
    187             #成绩完全吻合原函数的话,退出函数
    188             if scores[0][0] == 0:
    189                 break
    190             #创建新一代各种群,成绩前两名的直接进入下一代
    191             newpop = [scores[0][1], scores[1][1]]
    192             while len(newpop) < popsize:
    193                 #pnew为引进随机种群概率,未达此概率的,使用原种群的交配、变异生成新种群
    194                 if random() > pnew:
    195                     newpop.append(
    196                         self.mutate(self.crossover(scores[selectindex()][1], scores[selectindex()][1], probswap=breedingrate), pc,
    197                                probchange=mutationrate))
    198                 #引入随机种群
    199                 else:
    200                     newpop.append(self.makerandomtree(pc))
    201             population = newpop
    202             #打印历代最好种群
    203             # scores[0][1].display()
    204         return scores[0][1]
    205 
    206 
    207 
    208 def main(argv):
    209     e=opera()
    210     def exampletree():
    211         return node(e.ifw,[node(e.gtw,[paramnode(0),constnode(3)]),node(e.addw,[paramnode(1),constnode(5)]),node(e.subw,[paramnode(1),constnode(2)])])
    212 
    213 
    214     # random1=e.makerandomtree(2)
    215     # random1.evaluate([7,1])
    216     # random1.evaluate([2,4])
    217     # random2=e.makerandomtree(2)
    218     # random2.evaluate([5,3])
    219     # random2.evaluate([5,20])
    220     # random1.display()
    221     # random2.display()
    222 
    223     # exampletree = e.exampletree()
    224     # exampletree.display()
    225     # print(exampletree.evaluate([6, 1]))
    226     # print('after evaluate:')
    227     # exampletree.display()
    228     # exampletree.evaluate([2, 3])
    229     # exampletree.evaluate([5, 3])
    230     # exampletree.display()
    231 
    232     a=opera()
    233     row2=a.buildhiddenset()
    234     # fig=plt.figure()
    235     # ax=fig.add_subplot(1,1,1)
    236     # plt.plot(np.random.randn(1000).cumsum())
    237     # plt.show()
    238 
    239 
    240 
    241     from mpl_toolkits.mplot3d import Axes3D
    242     fig = plt.figure()
    243     ax = fig.add_subplot(111, projection='3d')
    244     X = [1, 1, 2, 2]
    245     Y = [3, 4, 4, 3]
    246     Z = [1, 2, 1, 1]
    247     rx=[]
    248     ry=[]
    249     rz=[]
    250     for i in row2:
    251         rx.append(i[0])
    252         ry.append(i[1])
    253         rz.append(i[2])
    254 
    255     ax.plot_trisurf(rx, ry, rz)
    256     rz2=[]
    257     rf = a.getrankfunction(row2)
    258     final = a.evolve(2, 100, rf, mutationrate=0.2, breedingrate=0.1, pexp=0.7, pnew=0.1,maxgen=500)
    259     print('__________________is it?_________________________')
    260     final.display()
    261     for j in range(0,len(rx)):
    262         rz2.append(final.evaluate([rx[j],ry[j]]))
    263     fig2 = plt.figure()
    264     ax2 = fig2.add_subplot(111, projection='3d')
    265     ax2.plot_trisurf(rx, ry, rz2)
    266 
    267     plt.show()
    268 
    269 
    270     # print(rf)
    271     # final = a.evolve(2, 500, rf, mutationrate=0.2, breedingrate=0.1, pexp=0.7, pnew=0.1)
    272     # print("final:",final)
    273     # print(final.evaluate([1,8]))
    274     # print(final.evaluate([2,9]))
    275 
    276 
    277 
    278 
    279 
    280 if __name__=="__main__":

    281     main(0) 

    看懂不一定写的出来,这是这次写这个程序最大的体会, 得定时拿出来复习复习。

  • 相关阅读:
    Ubuntu下面MySQL的参数文件my.cnf浅析
    Ubuntu下创建XFS文件系统的LVM
    Linux LVM学习总结——Insufficient Free Extents for a Logical Volume
    SQL Server中通用数据库角色权限处理
    Key Lookup开销过大导致聚集索引扫描
    SQL Server OPTION (OPTIMIZE FOR UNKNOWN) 测试总结
    ERROR 1071 (42000): Specified key was too long; max key length is 767 bytes
    一次存储过程参数嗅探定位流程总结
    ORACLE如何检查找出损坏索引(Corrupt Indexes)
    MySQL索引扩展(Index Extensions)学习总结
  • 原文地址:https://www.cnblogs.com/javajava/p/4988706.html
Copyright © 2011-2022 走看看