import matplotlib.pyplot as plt decisionNode=dict(boxstyle="sawtooth",fc="11.8"); leafNode=dict(boxstyle="round4",fc="0.8"); arrow_args=dict(arrowstyle="<-"); def plotNode(nodeTxt,centerPt,parentPt,nodeType): createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',va="center",ha="center",bbox=nodeType,arrowprops=arrow_args); def createPlot(): fig=plt.figure(1,facecolor='white'); fig.clf(); createPlot.ax1=plt.subplot(111,frameon=False); plotNode('a decision node',(0.5,0.1),(0.1,0.5),decisionNode); plotNode('a leaf node',(0.8,0.1),(0.3,0.8),leafNode); plt.show(); def getNumLeafs(myTree): numLeafs=0; firstStr=myTree.keys()[0]; print firstStr; secondDict=myTree[firstStr]; print secondDict; for key in secondDict.keys(): print key; print secondDict[key]; if type(secondDict[key]).__name__=='dict': numLeafs+=getNumLeafs(secondDict[key]); else: numLeafs+=1; return numLeafs; def getTreeDepth(myTree): maxDepth=0; firstStr=myTree.keys()[0]; secondDict=myTree[firstStr]; for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict': thisDepth=1+getTreeDepth(secondDict[key]); else: thisDepth=1; if thisDepth >maxDepth: maxDepth=thisDepth; return maxDepth; def retrieveTree(i): listOfTrees=[{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},{'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]; return listOfTrees[i]; def plotMidText(cntrPt,parentPt,txtString): #计算父亲和孩子中间地方 放文本 xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]; yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]; createPlot.ax1.text(xMid,yMid,txtString); def plotTree(myTree,parentPt,nodeTxt): numLeafs=getNumLeafs(myTree); depth=getTreeDepth(myTree); firstStr=myTree.keys()[0]; cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff); plotMidText(cntrPt,parentPt,nodeTxt); #画中间 plotNode(firstStr,cntrPt,parentPt,decisionNode);#画这个点 secondDict=myTree[firstStr]; plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD; for key in secondDict.keys(): #不是叶子 if type(secondDict[key]).__name__=='dict': plotTree(secondDict[key],cntrPt,str(key)); else: plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW; #是叶子 plotNode(secondDict[key],(plotTree.xOff,plot.yOff),cntrPt,leafNode); plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key)); plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD; def createPlot(inTree): fig=plt.figure(1,facecolor='white'); fig.clf(); axprops=dict(xticks=[],yticks=[]); createPlot.ax1=plt.subplot(111,frameon=False,**axprops); plotTree.totalW=float(getNumLeafs(inTree)); plotTree.totalD=float(getTreeDepth(inTree)); plotTree.xOff=-0.5/plotTree.totalW; #个人感觉这边是有点神奇的 plotTree.yOff=1.0; print plotTree.xOff,plotTree.yOff; plotTree(inTree,(0.5,1.0),''); plt.show();