Card Problem
Card Problem 要求将 1~10 这 10 个数分成 A、B 两个集合(本文的实现中每个集合各 5 个数),使其中一个集合的和为 36,另一个集合的乘积为 360。
其中 genetic.py 完整代码修改后如下:
import random
import statistics
import sys
import time
def _generate_parent(length, geneSet, get_fitness):
    """Build a random starting chromosome of ``length`` genes.

    Genes are drawn from ``geneSet`` with ``random.sample``; when ``length``
    exceeds ``len(geneSet)`` further samples are appended until enough genes
    have been collected.

    Args:
        length: number of genes the chromosome must contain.
        geneSet: sequence of candidate gene values.
        get_fitness: callable taking the gene list and returning its fitness.

    Returns:
        A ``Chromosome`` holding the generated genes and their fitness.

    Raises:
        ValueError: if ``geneSet`` is empty while ``length`` is positive.
    """
    # An empty gene set would make every sample empty and the loop below
    # would never terminate, so fail loudly instead.
    if length > 0 and len(geneSet) == 0:
        raise ValueError("geneSet must not be empty when length is positive")
    genes = []
    while len(genes) < length:
        sampleSize = min(length - len(genes), len(geneSet))
        genes.extend(random.sample(geneSet, sampleSize))
    return Chromosome(genes, get_fitness(genes))
def _mutate(parent, geneSet, get_fitness):
    """Return a child that differs from ``parent`` in one random position.

    Two distinct candidates are sampled from ``geneSet``; the second one is
    used only when the first equals the gene already at the chosen position,
    so the selected gene always changes.
    """
    offspring = parent.Genes[:]
    position = random.randrange(0, len(parent.Genes))
    first_choice, fallback = random.sample(geneSet, 2)
    if first_choice == offspring[position]:
        offspring[position] = fallback
    else:
        offspring[position] = first_choice
    return Chromosome(offspring, get_fitness(offspring))
def _mutate_custom(parent, custom_mutate, get_fitness):
    """Return a child produced by applying ``custom_mutate`` to a gene copy.

    ``custom_mutate`` receives a copy of the parent's genes and is expected
    to modify it in place; its return value is ignored.
    """
    offspring = parent.Genes[:]
    custom_mutate(offspring)
    return Chromosome(offspring, get_fitness(offspring))
def get_best(get_fitness, targetLen, optimalFitness, geneSet, display,
             custom_mutate=None):
    """Search until a chromosome at least as good as ``optimalFitness`` is found.

    Args:
        get_fitness: callable mapping a gene list to its fitness.
        targetLen: number of genes in each chromosome.
        optimalFitness: fitness at which the search stops.
        geneSet: candidate gene values for the initial parent and for the
            default mutation.
        display: callable invoked with every improvement that is produced.
        custom_mutate: optional in-place mutation of a gene list; when omitted
            the default single-gene mutation is used.

    Returns:
        The first improvement whose fitness ``optimalFitness`` does not exceed.
    """
    def mutate_step(parent):
        # Pick the mutation strategy per call; the choice never changes
        # during a run because custom_mutate is fixed.
        if custom_mutate is None:
            return _mutate(parent, geneSet, get_fitness)
        return _mutate_custom(parent, custom_mutate, get_fitness)

    def make_parent():
        return _generate_parent(targetLen, geneSet, get_fitness)

    for candidate in _get_improvement(mutate_step, make_parent):
        display(candidate)
        if optimalFitness > candidate.Fitness:
            continue
        return candidate
def _get_improvement(new_child, generate_parent):
    """Yield the initial parent and then every strictly better child.

    Children that are worse than the current parent are discarded. Children
    of equal fitness silently replace the parent without being yielded, and
    strictly better children are yielded and become the new parent.
    """
    current = generate_parent()
    yield current
    while True:
        candidate = new_child(current)
        if current.Fitness > candidate.Fitness:
            continue
        is_better = candidate.Fitness > current.Fitness
        # Equal or better candidates both take over as the parent.
        current = candidate
        if is_better:
            yield candidate
class Chromosome:
    """Pairs a gene sequence with the fitness value supplied for it."""

    def __init__(self, genes, fitness):
        self.Genes, self.Fitness = genes, fitness
class Benchmark:
    """Helper for timing repeated runs of a function."""

    @staticmethod
    def run(function):
        """Call ``function`` 100 times and print running timing statistics.

        Output produced by ``function`` is suppressed while it runs. After
        each of the first 10 runs and every 10th run thereafter, one line is
        printed with the run count, the mean duration in seconds and the
        sample standard deviation (reported as 0 for the first two runs).

        Any exception raised by ``function`` propagates to the caller, with
        ``sys.stdout`` restored first.
        """
        timings = []
        stdout = sys.stdout
        for i in range(100):
            # print() is a no-op while sys.stdout is None, which silences
            # the benchmarked function.
            sys.stdout = None
            try:
                # perf_counter is monotonic and high-resolution, unlike the
                # wall clock, so it is the right tool for short intervals.
                startTime = time.perf_counter()
                function()
                seconds = time.perf_counter() - startTime
            finally:
                # Always restore stdout, even if function() raised.
                sys.stdout = stdout
            timings.append(seconds)
            mean = statistics.mean(timings)
            if i < 10 or i % 10 == 9:
                print("{} {:3.2f} {:3.2f}".format(
                    1 + i, mean,
                    statistics.stdev(timings, mean) if i > 1 else 0))
上面这段代码最大的改变在于 get_best 过程。对于本次的问题,如果每次只更新一个基因,就会长时间得不到最优解,因此本次代码需要自定义一个突变函数;相应地,get_best 需要判断是否传入了自定义的突变函数,并调用对应的突变函数。后续过程稍微做了一些调整,但是目的不变。
此外 cardTests.py 的完整代码如下:
import datetime
import functools
import operator
import random
import unittest
import genetic
def get_fitness(genes):
    """Score a gene list for the card problem.

    The first five genes form the sum group and genes five through nine form
    the product group; repeated values across the whole list are counted as
    duplicates.
    """
    sum_group, product_group = genes[0:5], genes[5:10]
    total = sum(sum_group)
    product = functools.reduce(operator.mul, product_group)
    repeats = len(genes) - len(set(genes))
    return Fitness(total, product, repeats)
def display(candidate, startTime):
    """Print both gene groups, the fitness and the time elapsed since startTime."""
    elapsed = datetime.datetime.now() - startTime
    sum_group = ', '.join(str(gene) for gene in candidate.Genes[0:5])
    product_group = ', '.join(str(gene) for gene in candidate.Genes[5:10])
    line = "{} - {} {} {}".format(
        sum_group, product_group, candidate.Fitness, elapsed)
    print(line)
def mutate(genes, geneset):
    """Mutate ``genes`` in place.

    If the list contains a repeated value, one random position is overwritten
    with a random value from ``geneset``. Otherwise one to four random swaps
    of two distinct positions are performed.
    """
    contains_repeat = len(set(genes)) != len(genes)
    if contains_repeat:
        target = random.randrange(0, len(genes))
        source = random.randrange(0, len(geneset))
        genes[target] = geneset[source]
        return
    for _ in range(random.randint(1, 4)):
        first, second = random.sample(range(len(genes)), 2)
        genes[first], genes[second] = genes[second], genes[first]
class CardTests(unittest.TestCase):
    """Runs the genetic engine on the card problem for the numbers 1 to 10."""

    def test(self):
        """Search for a split with sum 36, product 360 and no duplicates."""
        geneset = list(range(1, 11))
        startTime = datetime.datetime.now()
        target = Fitness(36, 360, 0)
        best = genetic.get_best(
            get_fitness, 10, target, geneset,
            lambda candidate: display(candidate, startTime),
            custom_mutate=lambda genes: mutate(genes, geneset))
        self.assertTrue(not target > best.Fitness)

    def test_benchmark(self):
        """Time repeated runs of ``test`` through the engine's benchmark helper."""
        genetic.Benchmark.run(self.test)
class Fitness:
    """Fitness of a candidate split for the card problem.

    Stores the sum of the first group, the product of the second group, the
    number of duplicated values, and the total distance from the targets
    (sum 36, product 360).
    """

    def __init__(self, group1Sum, group2Product, duplicateCount):
        self.Group1Sum = group1Sum
        self.Group2Product = group2Product
        sumDifference = abs(36 - group1Sum)
        productDifference = abs(360 - group2Product)
        self.TotalDifference = sumDifference + productDifference
        self.DuplicateCount = duplicateCount

    def __gt__(self, other):
        """Return True when this fitness is better than ``other``.

        Fewer duplicates always wins; with equal duplicate counts the smaller
        total difference wins.
        """
        # Let Python try the reflected operation (or raise TypeError) rather
        # than failing with AttributeError on unrelated operand types.
        if not isinstance(other, Fitness):
            return NotImplemented
        if self.DuplicateCount != other.DuplicateCount:
            return self.DuplicateCount < other.DuplicateCount
        return self.TotalDifference < other.TotalDifference

    def __str__(self):
        return "sum: {} prod: {} dups: {}".format(
            self.Group1Sum,
            self.Group2Product,
            self.DuplicateCount)
# Start the unittest runner only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
接下来对上述代码的重要部分进行简单的讲解:
class Fitness:
    """Fitness of a candidate split for the card problem.

    Stores the sum of the first group, the product of the second group, the
    number of duplicated values, and the total distance from the targets
    (sum 36, product 360).
    """

    def __init__(self, group1Sum, group2Product, duplicateCount):
        self.Group1Sum = group1Sum
        self.Group2Product = group2Product
        sumDifference = abs(36 - group1Sum)
        productDifference = abs(360 - group2Product)
        self.TotalDifference = sumDifference + productDifference
        self.DuplicateCount = duplicateCount

    def __gt__(self, other):
        """Return True when this fitness is better than ``other``.

        Fewer duplicates always wins; with equal duplicate counts the smaller
        total difference wins.
        """
        # Let Python try the reflected operation (or raise TypeError) rather
        # than failing with AttributeError on unrelated operand types.
        if not isinstance(other, Fitness):
            return NotImplemented
        if self.DuplicateCount != other.DuplicateCount:
            return self.DuplicateCount < other.DuplicateCount
        return self.TotalDifference < other.TotalDifference

    def __str__(self):
        return "sum: {} prod: {} dups: {}".format(
            self.Group1Sum,
            self.Group2Product,
            self.DuplicateCount)
适应值包含了第一组数据的和、第二组数据的积、重复数据个数,以及总差异(即第一组数据的和与36的差的绝对值 + 第二组数据的积与360的差的绝对值)。重复数越小越好;如果重复数相同,则总差异越小越好。
另一个比较重要的更改是,上述程序自定义了突变函数,并在调用 engine 时传入了该函数。
def mutate(genes, geneset):
    """Mutate ``genes`` in place.

    If the list contains a repeated value, one random position is overwritten
    with a random value from ``geneset``. Otherwise one to four random swaps
    of two distinct positions are performed.
    """
    contains_repeat = len(set(genes)) != len(genes)
    if contains_repeat:
        target = random.randrange(0, len(genes))
        source = random.randrange(0, len(geneset))
        genes[target] = geneset[source]
        return
    for _ in range(random.randint(1, 4)):
        first, second = random.sample(range(len(genes)), 2)
        genes[first], genes[second] = genes[second], genes[first]
在函数 mutate 中,首先判断这个基因中是否包含重复的数字:如果包含,就随机选择一个位置,用基因集中随机的一个数替换;如果不包含,则进行 1~4 次突变,每次突变将两个位置上的数进行交换。