zoukankan      html  css  js  c++  java
  • 10、Python 数据分析-Matplotlib绘图大全详解

    第一章 Matplotlib 简介

    Matplotlib 能够创建多数类型的图表,如条形图,散点图,条形图,饼图,堆叠图,3D 图和地图图表。

    首先,为了实际使用 Matplotlib,我们需要安装它。

    pip install matplotlib
    

    一旦你安装了 Python,你就做好了准备,你可以编写任何你想要的逻辑。 我喜欢使用 IDLE 来编程,但你可以随意使用任何你喜欢的东西。

    import matplotlib.pyplot as plt
    

    这一行导入集成的 pyplot ,我们将在整个系列中使用它。 我们将 pyplot 导入为 plt ,这是使用 pylot 的 python 程序的传统惯例。

    plt.plot([1,2,3],[5,7,4])
    

    接下来,我们调用 plot 的 .plot 方法绘制一些坐标。 这个 .plot 需要许多参数,但前两个是 'x' 和 'y' 坐标,我们放入列表。 这意味着,根据这些列表我们拥有 3 个坐标: 1,5 2,7 和 3,4 。

    plt.plot 在后台『绘制』这个绘图,但绘制了我们想要的一切之后,当我们准备好的时候,我们需要把它带到屏幕上。

    plt.show()
    

    这样,应该弹出一个图形。 如果没有,有时它可以弹出,或者你可能得到一个错误。 你的图表应如下所示:

    img

    这个窗口是一个 matplotlib 窗口,它允许我们查看我们的图形,以及与它进行交互和访问。 你可以将鼠标悬停在图表上,并查看通常在右下角的坐标。 你也可以使用按钮。 它们可能在不同的位置,但在上图中,这些按钮在左下角。

    第二章 图例、标题和标签

    在本教程中,我们将讨论 Matplotlib 中的图例,标题和标签。 很多时候,图形可以不言自明,但是图形带有标题,轴域上的标签和图例,来解释每一行是什么非常必要。

    注:轴域(Axes)即两条坐标轴围城的区域。

    从这里开始:

    import matplotlib.pyplot as plt
     
    x = [1,2,3]
    y = [5,7,4]
     
    x2 = [1,2,3]
    y2 = [10,14,12]
    

    这样我们可以画出两个线条,接下来:

    plt.plot(x, y, label='First Line')
    plt.plot(x2, y2, label='Second Line')
    

    在这里,我们绘制了我们已经看到的东西,但这次我们添加另一个参数label。 这允许我们为线条指定名称,我们以后可以在图例中显示它。 我们的其余代码为:

    plt.xlabel('Plot Number')
    plt.ylabel('Important var')
    plt.title('Interesting Graph
    Check it out')
    plt.legend()
    plt.show()
    

    使用 plt.xlabel 和 plt.ylabel ,我们可以为这些相应的轴创建标签。 接下来,我们可以使用 plt.title 创建图的标题,然后我们可以使用 plt.legend() 生成默认图例。 结果图如下:

    img

    第三章 条形图和直方图

    这个教程中我们会涉及条形图和直方图。我们先来看条形图:

    import matplotlib.pyplot as plt
     
    plt.bar([1,3,5,7,9],[5,2,7,8,2], label="Example one")
     
    plt.bar([2,4,6,8,10],[8,6,2,5,6], label="Example two", color='g')
    plt.legend()
    plt.xlabel('bar number')
    plt.ylabel('bar height')
     
    plt.title('Epic Graph
    Another Line! Whoa')
     
    plt.show()
    

    plt.bar 为我们创建条形图。 如果你没有明确选择一种颜色,那么虽然做了多个图,所有的条看起来会一样。 这让我们有机会使用一个新的 Matplotlib 自定义选项。 你可以在任何类型的绘图中使用颜色,例如 g 为绿色, b 为蓝色, r 为红色,等等。 你还可以使用十六进制颜色代码,如 #191970 。

    img

    接下来,我们会讲解直方图。 直方图非常像条形图,倾向于通过将区段组合在一起来显示分布。 这个例子可能是年龄的分组,或测试的分数。 我们并不是显示每一组的年龄,而是按照 20 ~ 25,25 ~ 30… 等等来显示年龄。 这里有一个例子:

    import matplotlib.pyplot as plt
     
    population_ages =[22,55,62,45,21,22,34,42,42,4,99,102,110,120,121,122,130,111,115,112,80,75,65,54,44,43,42,48]
     
    bins = [0,10,20,30,40,50,60,70,80,90,100,110,120,130]
     
    plt.hist(population_ages, bins, histtype='bar', rwidth=0.8)
     
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Interesting Graph
    Check it out')
    plt.legend()
    plt.show()
    

    产生的图表为:

    img

    对于 plt.hist ``,你首先需要放入所有的值,然后指定放入哪个桶或容器。 在我们的例子中,我们绘制了一堆年龄,并希望以 10 年的增量来显示它们。 我们将条形的宽度设为 0.8,但是如果你想让条形变宽,或者变窄,你可以选择其他的宽度。

    第四章 散点图

    接下来,我们将介绍散点图。散点图通常用于比较两个变量来寻找相关性或分组,如果你在 3 维绘制则是 3 个。

    散点图的一些示例代码:

    import matplotlib.pyplot as plt
     
    x = [1,2,3,4,5,6,7,8]
    y = [5,2,4,2,1,4,5,2]
     
    plt.scatter(x,y, label='skitscat', color='k', s=25, marker="o")
     
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Interesting Graph
    Check it out')
    plt.legend()
    plt.show()
    

    结果为:

    img

    plt.scatter 不仅允许我们绘制 x 和 y ,而且还可以让我们决定所使用的标记颜色,大小和类型。 有一堆标记选项,请参阅 Matplotlib 标记文档中的所有选项。

    第五章 堆叠图

    在这篇 Matplotlib 数据可视化教程中,我们要介绍如何创建堆叠图。 堆叠图用于显示『部分对整体』随时间的关系。 堆叠图基本上类似于饼图,只是随时间而变化。

    让我们考虑一个情况,我们一天有 24 小时,我们想看看我们如何花费时间。 我们将我们的活动分为:睡觉,吃饭,工作和玩耍。

    我们假设我们要在 5 天的时间内跟踪它,因此我们的初始数据将如下所示:

    import matplotlib.pyplot as plt
     
    days = [1,2,3,4,5]
     
    sleeping = [7,8,6,11,7]
    eating =   [2,3,4,3,2]
    working =  [7,8,7,2,2]
    playing =  [8,5,7,8,13]
    

    因此,我们的 x 轴将包括 day 变量,即 1, 2, 3, 4 和 5。然后,日期的各个成分保存在它们各自的活动中。 像这样绘制它们:

    plt.stackplot(days, sleeping,eating,working,playing, colors=['m','c','r','k'])
     
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Interesting Graph
    Check it out')
    plt.show()
    

    img

    在这里,我们可以至少在颜色上看到,我们如何花费我们的时间。 问题是,如果不回头看代码,我们不知道什么颜色是什么。 下一个问题是,对于多边形来说,我们实际上不能为数据添加『标签』。 因此,在任何不止是线条,带有像这样的填充或堆叠图的地方,我们不能以固有方式标记出特定的部分。 这不应该阻止程序员。 我们可以解决这个问题:

    import matplotlib.pyplot as plt
     
    days = [1,2,3,4,5]
     
    sleeping = [7,8,6,11,7]
    eating =   [2,3,4,3,2]
    working =  [7,8,7,2,2]
    playing =  [8,5,7,8,13]
     
     
    plt.plot([],[],color='m', label='Sleeping', linewidth=5)
    plt.plot([],[],color='c', label='Eating', linewidth=5)
    plt.plot([],[],color='r', label='Working', linewidth=5)
    plt.plot([],[],color='k', label='Playing', linewidth=5)
     
    plt.stackplot(days, sleeping,eating,working,playing, colors=['m','c','r','k'])
     
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Interesting Graph
    Check it out')
    plt.legend()
    plt.show()
    

    img

    我们在这里做的是画一些空行,给予它们符合我们的堆叠图的相同颜色,和正确标签。 我们还使它们线宽为 5,使线条在图例中显得较宽。 现在,我们可以很容易地看到,我们如何花费我们的时间。

    第六章 饼图

    饼图很像堆叠图,只是它们位于某个时间点。 通常,饼图用于显示部分对于整体的情况,通常以%为单位。 幸运的是,Matplotlib 会处理切片大小以及一切事情,我们只需要提供数值。

    import matplotlib.pyplot as plt
     
    slices = [7,2,2,13]
    activities = ['sleeping','eating','working','playing']
    cols = ['c','m','r','b']
     
    plt.pie(slices,
            labels=activities,
            colors=cols,
            startangle=90,
            shadow= True,
            explode=(0,0.1,0,0),
            autopct='%1.1f%%')
     
    plt.title('Interesting Graph
    Check it out')
    plt.show()
    

    img

    在 plt.pie 中,我们需要指定『切片』,这是每个部分的相对大小。 然后,我们指定相应切片的颜色列表。 接下来,我们可以选择指定图形的『起始角度』。 这使你可以在任何地方开始绘图。 在我们的例子中,我们为饼图选择了 90 度角,这意味着第一个部分是一个竖直线条。 接下来,我们可以选择给绘图添加一个字符大小的阴影,然后我们甚至可以使用 explode ``拉出一个切片。

    我们总共有四个切片,所以对于 explode ``,如果我们不想拉出任何切片,我们传入 0,0,0,0 。 如果我们想要拉出第一个切片,我们传入 0.1,0,0,0 。

    最后,我们使用 autopct ,选择将百分比放置到图表上面。

    第七章 从文件加载数据

    很多时候,我们想要绘制文件中的数据。 有许多类型的文件,以及许多方法,你可以使用它们从文件中提取数据来图形化。 在这里,我们将展示几种方法。 首先,我们将使用内置的 csv 模块加载CSV文件,然后我们将展示如何使用 NumPy(第三方模块)加载文件。

    import matplotlib.pyplot as plt
    import csv
     
    x = []
    y = []
     
    with open('example.txt','r') as csvfile:
        plots = csv.reader(csvfile, delimiter=',')
        for row in plots:
            x.append(int(row[0]))
            y.append(int(row[1]))
     
    plt.plot(x,y, label='Loaded from file!')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Interesting Graph
    Check it out')
    plt.legend()
    plt.show()
    

    img

    这里,我们打开样例文件,包含以下数据:

    1,5
    2,3
    3,4
    4,7
    5,4
    6,3
    7,5
    8,7
    9,4
    10,4
    

    接下来,我们使用 csv 模块读取数据。 csv 读取器自动按行分割文件,然后使用我们选择的分隔符分割文件中的数据。 在我们的例子中,这是一个逗号。 注意: csv 模块和 csv reader 不需要文件在字面上是一个.csv文件。 它可以是任何具有分隔数据的简单的文本文件。

    一旦我们这样做了,我们将索引为 0 的元素存储到 x 列表,将索引为 1 的元素存储到 y 列表中。 之后,我们都设置好了,准备绘图,然后显示数据。

    虽然使用 CSV 模块是完全正常的,但使用 NumPy 模块来加载我们的文件和数据,可能对我们更有意义。 如果你没有 NumPy,你需要按下面的步骤来获取它。 为了了解安装模块的更多信息,请参阅 pip 教程。 大多数人应该都能打开命令行,并执行 pip install numpy 。

    如果不能,请参阅链接中的教程。

    一旦你安装了 NumPy,你可以编写如下代码:

    import matplotlib.pyplot as plt
    import numpy as np
     
    x, y = np.loadtxt('example.txt', delimiter=',', unpack=True)
    plt.plot(x,y, label='Loaded from file!')
     
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Interesting Graph
    Check it out')
    plt.legend()
    plt.show()
    

    结果应该是相同的图表。 稍后,当我们加载数据时,我们可以利用 NumPy 为我们做一些更多的工作,但这是教程未来的内容。 就像csv模块不需要一个特地的.csv一样,loadtxt函数不要求文件是一个.txt文件,它可以是一个.csv,它甚至可以是一个 python 列表对象。

    第八章 从网络加载数据

    除了从文件加载数据,另一个流行的数据源是互联网。 我们可以用各种各样的方式从互联网加载数据,但对我们来说,我们只是简单地读取网站的源代码,然后通过简单的拆分来分离数据。

    import matplotlib.pyplot as plt
    import numpy as np
    import urllib
    import matplotlib.dates as mdates
     
     
    def graph_data(stock):
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=10y/csv'
     
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
     
        stock_data = []
        split_source = source_code.split('
    ')
     
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line:
                    stock_data.append(line)
    

    这里有很多步骤。首先,我们看到importpyplot像往常一样导入,然后导入了numpy,然后是用于访问互联网的urllib,然后导入了matplotlib.dates作为mdates,它对于将日期戳转换为 matplotlib 可以理解的日期很有用。

    接下来,我们开始构建我们的graph_data函数。在这里,我们首先定义包含股票数据的网址。之后,我们写一些urllib代码来访问该 URL,然后使用.read读取源代码,之后我们继续解码该数据。如果你使用 Python 2,则不必使用decode

    然后,我们定义一个空列表,这是我们将要放置股票数据的地方,我们也开始使用split_source变量拆分数据,以换行符拆分。

    现在,如果你去看源代码,用stock替换 URL 中的+stock+,像 AAPL 那样,你可以看到大多数页面数据确实是股票定价信息,但有一些头信息我们需要过滤掉。为此,我们使用一些基本的过滤,检查它们来确保每行有 6 个数据点,然后确保术语values不在行中。

    现在,我们已经解析了数据,并做好了准备。我们将使用 NumPy:

    date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                          delimiter=',',
                                                          unpack=True,
                                                          # %Y = full year. 2015
                                                          # %y = partial year 15
                                                          # %m = number month
                                                          # %d = number day
                                                          # %H = hours
                                                          # %M = minutes
                                                          # %S = seconds
                                                          # 12-06-2014
                                                          # %m-%d-%Y
                                                          converters={0: bytespdate2num('%Y%m%d')})
    

    我们在这里所做的是,使用numpyloadtxt函数,并将这六个元素解构到六个变量。 这里的第一个参数是stock_data,这是我们加载的数据。 然后,我们指定delimiter(这里是逗号),然后我们指定我们确实想要在这里解包变量,不是一个变量,而是我们定义的这组变量。 最后,我们使用可选的converters参数来指定我们要转换的元素(0),以及我们打算要怎么做。 我们传递一个名为bytespdate2num的函数,它还不存在,但我们下面会编写它。

    第九章 时间戳的转换

    本教程的重点是将来自 Yahoo finance API 的日期转换为 Matplotlib 可理解的日期。 为了实现它,我们要写一个新的函数,bytespdate2num

    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
    

    此函数接受数据,基于编码来解码数据,然后返回它。

    将此应用于我们的程序的其余部分:

    import matplotlib.pyplot as plt
    import numpy as np
    import urllib
    import matplotlib.dates as mdates
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=10y/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              # %Y = full year. 2015
                                                              # %y = partial year 15
                                                              # %m = number month
                                                              # %d = number day
                                                              # %H = hours
                                                              # %M = minutes
                                                              # %S = seconds
                                                              # 12-06-2014
                                                              # %m-%d-%Y
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        plt.plot_date(date, closep,'-', label='Price')
     
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.title('Interesting Graph
    Check it out')
        plt.legend()
        plt.show()
     
     
    graph_data('TSLA')
    

    如果你绘制 TSLA,结果图应该看起来像这样:

    img

    第十章 基本的自定义

    在 Matplotlib 教程中,我们将讨论一些可能的图表自定义。 为了开始修改子图,我们必须定义它们。 我们很快会谈论他们,但有两种定义并构造子图的主要方法。 现在,我们只使用其中一个,但我们会很快解释它们。

    现在,修改我们的graph_data函数:

    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((1,1), (0,0))
    

    为了修改图表,我们需要引用它,所以我们将它存储到变量fig。 然后我们将ax1定义为图表上的子图。 我们在这里使用subplot2grid,这是获取子图的两种主要方法之一。 我们将深入讨论这些东西,但现在,你应该看到我们有 2 个元组,它们提供了(1,1)(0,0)1,1表明这是一个 1×1 网格。 然后0,0表明这个子图的『起点』将为0,0

    接下来,通过我们已经编写的代码来获取和解析数据:

    stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=10y/csv'
    source_code = urllib.request.urlopen(stock_price_url).read().decode()
    stock_data = []
    split_source = source_code.split('
    ')
    for line in split_source:
        split_line = line.split(',')
        if len(split_line) == 6:
            if 'values' not in line and 'labels' not in line:
                stock_data.append(line)
     
    date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                          delimiter=',',
                                                          unpack=True,
                                                          converters={0: bytespdate2num('%Y%m%d')})
    

    下面,我们这样绘制数据:

    ax1.plot_date(date, closep,'-', label='Price')
    

    现在,由于我们正在绘制日期,我们可能会发现,如果我们放大,日期会在水平方向上移动。但是,我们可以自定义这些刻度标签,像这样:

    for label in ax1.xaxis.get_ticklabels():
        label.set_rotation(45)
    

    这将使标签转动到对角线方向。 接下来,我们可以添加一个网格:

    ax1.grid(True)
    

    然后,其它东西我们保留默认,但我们也可能需要略微调整绘图,因为日期跑到了图表外面。 记不记得我们在第一篇教程中讨论的configure subplots按钮? 我们不仅可以以这种方式配置图表,我们还可以在代码中配置它们,以下是我们设置这些参数的方式:

    plt.subplots_adjust(left=0.09, bottom=0.20, right=0.94, top=0.90, wspace=0.2, hspace=0)
    

    现在,为了防止我们把你遗留在某个地方,这里是完整的代码:

    import matplotlib.pyplot as plt
    import numpy as np
    import urllib
    import matplotlib.dates as mdates
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((1,1), (0,0))
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=10y/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        ax1.plot_date(date, closep,'-', label='Price')
        for label in ax1.xaxis.get_ticklabels():
            label.set_rotation(45)
        ax1.grid(True)#, color='g', linestyle='-', linewidth=5)
     
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.title('Interesting Graph
    Check it out')
        plt.legend()
        plt.subplots_adjust(left=0.09, bottom=0.20, right=0.94, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('TSLA')
    

    结果为:

    img

    第十一章 Unix 时间

    在这个 Matplotlib 教程中,我们将介绍如何处理 unix 时间戳的转换,然后在图形中绘制日期戳。 使用 Yahoo Finance API,你会注意到,如果你使用较大的时间间隔,如1y(一年),你会得到我们一直在使用的日期戳,但如果你使用10d(10 天),反之你会得到 unix 时间的时间戳。

    import matplotlib.pyplot as plt
    import numpy as np
    import urllib
    import datetime as dt
    import matplotlib.dates as mdates
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((1,1), (0,0))
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=10d/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True)
        dateconv = np.vectorize(dt.datetime.fromtimestamp)
        date = dateconv(date)
     
    ##    date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
    ##                                                          delimiter=',',
    ##                                                          unpack=True,
    ##                                                          converters={0: bytespdate2num('%Y%m%d')})
     
        ax1.plot_date(date, closep,'-', label='Price')
        for label in ax1.xaxis.get_ticklabels():
            label.set_rotation(45)
        ax1.grid(True)#, color='g', linestyle='-', linewidth=5)
     
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.title('Interesting Graph
    Check it out')
        plt.legend()
        plt.subplots_adjust(left=0.09, bottom=0.20, right=0.94, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('TSLA')
    

    Unix 时间是 1970 年 1 月 1 日以后的秒数,它是跨程序的标准化时间表示方法。 也就是说,Matplotlib 并不欢迎 unix 时间戳。 这里是你可以使用 Matplotlib 来处理 unix 时间的方式:

    所以在这里,我们所做的是 unix 时间的写入处理,并注释掉我们以前的代码,因为我们为之后的使用而保存它。 这里的主要区别是:

    dateconv = np.vectorize(dt.datetime.fromtimestamp)
    date = dateconv(date)
    

    这里,我们将时间戳转换为日期戳,然后将其转换为 Matplotlib 想要的时间。

    现在,由于某些原因,我的 unix 时间带有另一行包含 6 个元素的数据,并且它包含了术语label,因此,在我们解析数据的for循环中,我们为你再添加一个需要注意的检查:

    for line in split_source:
        split_line = line.split(',')
        if len(split_line) == 6:
            if 'values' not in line and 'labels' not in line:
                stock_data.append(line)
    

    现在你的图表应该类似:

    img

    这里的所有扁平线条的原因是市场关闭。 有了这个短期数据,我们可以得到日内数据。 所以交易开放时有很多点,然后市场关闭时就没有了,然后又是一堆,然后又是没有。

    第十二章 颜色和填充

    在本教程中,我们将介绍一些更多的自定义,比如颜色和线条填充。

    我们要做的第一个改动是将plt.title更改为stock变量。

    plt.title(stock)
    

    现在,让我们来介绍一下如何更改标签颜色。 我们可以通过修改我们的轴对象来实现:

    ax1.xaxis.label.set_color('c')
    ax1.yaxis.label.set_color('r')
    

    如果我们运行它,我们会看到标签改变了颜色,就像在单词中那样。

    接下来,我们可以为要显示的轴指定具体数字,而不是像这样的自动选择:

    ax1.set_yticks([0,25,50,75])
    

    接下来,我想介绍填充。 填充所做的事情,是在变量和你选择的一个数值之间填充颜色。 例如,我们可以这样:

    ax1.fill_between(date, 0, closep)
    

    所以到这里,我们的代码为:

    import matplotlib.pyplot as plt
    import numpy as np
    import urllib
    import datetime as dt
    import matplotlib.dates as mdates
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((1,1), (0,0))
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=10y/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        ax1.fill_between(date, 0, closep)
     
        for label in ax1.xaxis.get_ticklabels():
            label.set_rotation(45)
        ax1.grid(True)#, color='g', linestyle='-', linewidth=5)
        ax1.xaxis.label.set_color('c')
        ax1.yaxis.label.set_color('r')
        ax1.set_yticks([0,25,50,75])
     
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.title(stock)
        plt.legend()
        plt.subplots_adjust(left=0.09, bottom=0.20, right=0.94, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('EBAY')
    

    结果为:

    img

    填充的一个问题是,我们可能最后会把东西都覆盖起来。 我们可以用透明度来解决它:

    ax1.fill_between(date, 0, closep)
    

    现在,让我们介绍条件填充。 让我们假设图表的起始位置是我们开始买入 eBay 的地方。 这里,如果价格低于这个价格,我们可以向上填充到原来的价格,然后如果它超过了原始价格,我们可以向下填充。 我们可以这样做:

    ax1.fill_between(date, closep[0], closep)
    

    会生成:

    img

    如果我们想用红色和绿色填充来展示收益/损失,那么与原始价格相比,绿色填充用于上升(注:国外股市的颜色和国内相反),红色填充用于下跌? 没问题! 我们可以添加一个where参数,如下所示:

    ax1.fill_between(date, closep, closep[0],where=(closep > closep[0]), facecolor='g', alpha=0.5)
    

    这里,我们填充当前价格和原始价格之间的区域,其中当前价格高于原始价格。 我们给予它绿色的前景色,因为这是一个上升,而且我们使用微小的透明度。

    我们仍然不能对多边形数据(如填充)应用标签,但我们可以像以前一样实现空线条,因此我们可以:

    ax1.plot([],[],linewidth=5, label='loss', color='r',alpha=0.5)
    ax1.plot([],[],linewidth=5, label='gain', color='g',alpha=0.5)
     
    ax1.fill_between(date, closep, closep[0],where=(closep > closep[0]), facecolor='g', alpha=0.5)
    ax1.fill_between(date, closep, closep[0],where=(closep < closep[0]), facecolor='r', alpha=0.5)
    

    这向我们提供了一些填充,以及用于处理标签的空线条,我们打算将其显示在图例中。这里完整的代码是:

    import matplotlib.pyplot as plt
    import numpy as np
    import urllib
    import datetime as dt
    import matplotlib.dates as mdates
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((1,1), (0,0))
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=10y/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        ax1.plot_date(date, closep,'-', label='Price')
     
        ax1.plot([],[],linewidth=5, label='loss', color='r',alpha=0.5)
        ax1.plot([],[],linewidth=5, label='gain', color='g',alpha=0.5)
     
        ax1.fill_between(date, closep, closep[0],where=(closep > closep[0]), facecolor='g', alpha=0.5)
        ax1.fill_between(date, closep, closep[0],where=(closep < closep[0]), facecolor='r', alpha=0.5)
     
        for label in ax1.xaxis.get_ticklabels():
            label.set_rotation(45)
        ax1.grid(True)#, color='g', linestyle='-', linewidth=5)
        ax1.xaxis.label.set_color('c')
        ax1.yaxis.label.set_color('r')
        ax1.set_yticks([0,25,50,75])
     
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.title(stock)
        plt.legend()
        plt.subplots_adjust(left=0.09, bottom=0.20, right=0.94, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('EBAY')
    

    现在我们的结果是:

    img

    第十三章 边框和水平线条

    欢迎阅读另一个定制教程,在这里我们使用 Matplotlib 讨论边框和水平线条。 有时候你可能想做的事情是改变边框的颜色,或者甚至完全删除它们。

    图形的边框基本上是图形的边界,其中有刻度线等东西。为了改变边框的颜色,你可以做一些类似这样的事情:

    ax1.spines['left'].set_color('c')
    

    在这里,我们引用了我们的边框字典,表示我们要调整左边框,然后我们使用set_color方法将颜色设置为'c',它是青色。

    如果我们想删除所有边框怎么办? 我们可以这样做:

    ax1.spines['right'].set_visible(False)ax1.spines['top'].set_visible(False)
    

    这是非常类似的代码,删除了右边框和上边框。

    很难看到我们修改了左边框的颜色,所以让我们通过修改线宽来使它变得很明显:

    ax1.spines['left'].set_linewidth(5)
    

    现在,左边框变成了非常粗也非常显眼的青色。 最后,如果我们想修改刻度参数怎么办? 假如不想要黑色的日期,我们想要一些橙色的日期? 没问题!

    ax1.tick_params(axis='x', colors='#f06215')
    

    现在我们的日期是橙色了! 接下来,让我们来看看我们如何绘制一条水平线。 你当然可以将你创建的一组新数据绘制成一条水平线,但你不需要这样做。 你可以:

    ax1.axhline(closep[0], color='k', linewidth=5)
    

    所以在这里,我们的整个代码是:

    import matplotlib.pyplot as plt
    import numpy as np
    import urllib
    import datetime as dt
    import matplotlib.dates as mdates
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((1,1), (0,0))
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=10y/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        ax1.plot_date(date, closep,'-', label='Price')
        ax1.plot([],[],linewidth=5, label='loss', color='r',alpha=0.5)
        ax1.plot([],[],linewidth=5, label='gain', color='g',alpha=0.5)
        ax1.axhline(closep[0], color='k', linewidth=5)
        ax1.fill_between(date, closep, closep[0],where=(closep > closep[0]), facecolor='g', alpha=0.5)
        ax1.fill_between(date, closep, closep[0],where=(closep < closep[0]), facecolor='r', alpha=0.5)
     
        for label in ax1.xaxis.get_ticklabels():
            label.set_rotation(45)
        ax1.grid(True)
        #ax1.xaxis.label.set_color('c')
        #ax1.yaxis.label.set_color('r')
        ax1.set_yticks([0,25,50,75])
     
        ax1.spines['left'].set_color('c')
        ax1.spines['right'].set_visible(False)
        ax1.spines['top'].set_visible(False)
        ax1.spines['left'].set_linewidth(5)
     
        ax1.tick_params(axis='x', colors='#f06215')
     
     
     
     
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.title(stock)
        plt.legend()
        plt.subplots_adjust(left=0.09, bottom=0.20, right=0.94, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('ebay')
    

    结果为:

    img

    第十四章 OHLC K 线图

    在 Matplotlib 教程中,我们将介绍如何在 Matplotlib 中创建开,高,低,关(OHLC)的 K 线图。 这些图表用于以精简形式显示时间序列股价信息。 为了实现它,我们首先需要导入一些模块:

    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
    

    我们引入了ticker,允许我们修改图表底部的ticker信息。 然后我们从matplotlib.finance模块中引入candlestick_ohlc功能。

    现在,我们需要组织我们的数据来和 matplotlib 协作。 如果你刚刚加入我们,我们得到的数据如下:

    stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1m/csv'
    source_code = urllib.request.urlopen(stock_price_url).read().decode()
    stock_data = []
    split_source = source_code.split('
    ')
    for line in split_source:
        split_line = line.split(',')
        if len(split_line) == 6:
            if 'values' not in line and 'labels' not in line:
                stock_data.append(line)
     
     
    date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                          delimiter=',',
                                                          unpack=True,
                                                          converters={0: bytespdate2num('%Y%m%d')})
    

    现在,我们需要构建一个 Python 列表,其中每个元素都是数据。 我们可以修改我们的loadtxt函数,使其不解构,但随后我们还是希望引用特定的数据点。 我们可以解决这个问题,但是我们最后可能只拥有两个单独的数据集。 为此,我们执行以下操作:

    x = 0
    y = len(date)
    ohlc = []
     
    while x < y:
        append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
        ohlc.append(append_me)
        x+=1
    

    有了这个,我们现在将 OHLC 数据列表存储到我们的变量ohlc。 现在我们可以这样绘制:

    candlestick_ohlc(ax1, ohlc)
    

    图表应该是这样:

    img

    不幸的是,x轴上的datetime数据不是日期戳的形式。 我们可以处理它:

    ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    

    此外,红/黑着色依我看不是最好的选择。 我们应该使用绿色表示上升和红色表示下降。 为此,我们可以:

    candlestick_ohlc(ax1, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')
    

    最后,我们可以将x标签设置为我们想要的数量,像这样:

    ax1.xaxis.set_major_locator(mticker.MaxNLocator(10))
    

    现在,完整代码现在是这样:

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
     
    import numpy as np
    import urllib
    import datetime as dt
     
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((1,1), (0,0))
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1m/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        x = 0
        y = len(date)
        ohlc = []
     
        while x < y:
            append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
            ohlc.append(append_me)
            x+=1
     
     
        candlestick_ohlc(ax1, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')
     
        for label in ax1.xaxis.get_ticklabels():
            label.set_rotation(45)
     
        ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax1.xaxis.set_major_locator(mticker.MaxNLocator(10))
        ax1.grid(True)
     
     
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.title(stock)
        plt.legend()
        plt.subplots_adjust(left=0.09, bottom=0.20, right=0.94, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('EBAY')
    

    结果为:

    img

    还要注意,我们从前面的教程中删除了大部分ax1的修改。

    第十五章 样式

    在这个 Matplotlib 教程中,我们将讨论样式。 我们用于 Matplotlib 的样式非常相似于用于 HTML 页面的 CSS(层叠样式表)。 正如你在这里可以看到的,我们对图形所做的所有修改都会叠加,而且我们目前只有一个轴域。 我们可以使用for循环,至少使代码量降低,但我们也可以在 Matplotlib 中利用这些样式。

    样式页的想法是将自定义样式写入文件,然后,为了使用这些更改并将其应用于图形,所有你需要做的就是导入样式,然后使用该特定样式。 这样,让我们假设你发现自己总是改变图形的各种元素。 你不必为每个图表编写 25 ~ 200 行自定义代码,只需将其写入一个样式,然后加载该样式,并以两行应用所有这些更改即可! 让我们开始吧。

    from matplotlib import style
    

    接下来,我们指定要使用的样式。 Matplotlib 已经有了几种样式。

    我们可以这样来使用样式:

    style.use('ggplot')
    

    img

    除了标题,标签的颜色是灰色的,轴域的背景是浅灰色,我们可以立即分辨字体是不同的。 我们还注意到,网格实际上是一个白色的实线。 我们的 K 线图保持不变,主要是因为我们在事后定制它。 在样式中加载时,更改会生效,但如果在加载样式后编写新的自定义代码,你的更改也会生效。

    因为我们试图展示样式模块,但是让我们继续,简单绘制几行,并暂且注释掉 K 线图:

    #candlestick_ohlc(ax1, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')
    ax1.plot(date,closep)
    ax1.plot(date,openp)
    

    会生成:

    img

    已经比默认值好多了!

    样式的另一个例子是fivethirtyeight

    img

    你可以这样查看所有的可用样式:

    print(plt.style.available)
    

    我这里它提供了['bmh', 'dark_background', 'ggplot', 'fivethirtyeight', 'grayscale']

    让我们尝试dark_background

    style.use('dark_background')
    
    1

    img

    现在,如果你想制作自己的风格呢? 首先,你需要找到样式目录。 为了实现它,如果你知道它在哪里,你可以前往你的 matplotlib 目录,或者你可以找到该目录。 如果你不知道如何找到该目录,你可以执行以下操作:

    print(plt.__file__)
    

    这至少会告诉你pyplot模块的位置。

    在 matplotlib 目录中,你需要寻找mpl-data。 然后在那里,你需要寻找stylelib。 在 Windows 上 ,我的完整路径是:C:Python34Libsite-packagesmatplotlibmpl-datastylelib

    那里应该显示了所有可用的.mplstyle文件。 你可以编辑、复制或重命名它们,然后在那里修改为你想要的东西。 然后,无论你用什么来命名.mplstyle文件,都要放在style.use中。

    第十六章 实时图表

    在这篇 Matplotlib 教程中,我们将介绍如何创建实时更新图表,可以在数据源更新时更新其图表。 你可能希望将此用于绘制股票实时定价数据,或者可以将传感器连接到计算机,并且显示传感器实时数据。 为此,我们使用 Matplotlib 的动画功能。

    最开始:

    import matplotlib.pyplot as plt
    import matplotlib.animation as animation
    from matplotlib import style
    

    这里,唯一的新增导入是matplotlib.animation as animation。 这是一个模块,允许我们在显示之后对图形进行动画处理。

    接下来,我们添加一些你熟悉的代码,如果你一直关注这个系列:

    style.use('fivethirtyeight')
     
    fig = plt.figure()
    ax1 = fig.add_subplot(1,1,1)
    

    现在我们编写动画函数:

    def animate(i):
        graph_data = open('example.txt','r').read()
        lines = graph_data.split('
    ')
        xs = []
        ys = []
        for line in lines:
            if len(line) > 1:
                x, y = line.split(',')
                xs.append(x)
                ys.append(y)
        ax1.clear()
        ax1.plot(xs, ys)
    

    我们在这里做的是构建数据,然后绘制它。 注意我们这里不调用plt.show()。 我们从一个示例文件读取数据,其内容如下:

    1,5
    2,3
    3,4
    4,7
    5,4
    6,3
    7,5
    8,7
    9,4
    10,4
    

    我们打开上面的文件,然后存储每一行,用逗号分割成xsys,我们将要绘制它。 然后:

    ani = animation.FuncAnimation(fig, animate, interval=1000)
    plt.show()
    

    我们运行动画,将动画放到图表中(fig),运行animate的动画函数,最后我们设置了 1000 的间隔,即 1000 毫秒或 1 秒。

    运行此图表的结果应该像往常一样生成图表。 然后,你应该能够使用新的坐标更新example.txt文件。 这样做会生成一个自动更新的图表,如下:

    img

    第十七章 注解和文本

    在本教程中,我们将讨论如何向 Matplotlib 图形添加文本。 我们可以通过两种方式来实现。 一种是将文本放置在图表上的某个位置。 另一个是专门注解图表上的绘图,来引起注意。

    这里的起始代码是教程 15,它在这里:

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
    from matplotlib import style
     
    import numpy as np
    import urllib
    import datetime as dt
     
    style.use('fivethirtyeight')
    print(plt.style.available)
     
    print(plt.__file__)
     
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((1,1), (0,0))
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1m/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        x = 0
        y = len(date)
        ohlc = []
     
        while x < y:
            append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
            ohlc.append(append_me)
            x+=1
     
     
        candlestick_ohlc(ax1, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')
     
        for label in ax1.xaxis.get_ticklabels():
            label.set_rotation(45)
     
        ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax1.xaxis.set_major_locator(mticker.MaxNLocator(10))
        ax1.grid(True)
     
     
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.title(stock)
        plt.subplots_adjust(left=0.09, bottom=0.20, right=0.94, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('ebay')
    

    所以这里是 Yahoo Finance API 的 eBay 的 OHLC K 线图。 这里我们要讲解的第一件事是向图形添加文本。

    font_dict = {'family':'serif',
                 'color':'darkred',
                 'size':15}
    ax1.text(date[10], closep[1],'Text Example', fontdict=font_dict)
    

    在这里,我们需要做一些事情。 首先,我们使用ax1.text添加文本。 我们使用我们的数据,以坐标形式给出此文本的位置。 首先给出文本的坐标,然后给出要放置的实际文本。 接下来,我们使用fontdict参数添加一个数据字典,来使用所用的字体。 在我们的字体字典中,我们将字体更改为serif,颜色为『深红色』,然后将字体大小更改为 15。这将全部应用于我们的图表上的文本,如下所示:

    img

    太棒了,接下来我们可以做的是,注解某个特定的绘图。 我们希望这样做来给出更多的信息。 在 eBay 的例子中,也许我们想解释某个具体绘图,或给出一些关于发生了什么的信息。 在股价的例子中,也许有一些发生的新闻会影响价格。 你可以注解新闻来自哪里,这将有助于解释定价变化。

    ax1.annotate('Bad News!',(date[9],highp[9]),
                 xytext=(0.8, 0.9), textcoords='axes fraction',
                 arrowprops = dict(facecolor='grey',color='grey'))
    

    这里,我们用ax1.annotate来注解。 我们首先传递我们想要注解的文本,然后传递我们让这个注解指向的坐标。 我们这样做,是因为当我们注释时,我们可以绘制线条和指向特定点的箭头。 接下来,我们指定xytext的位置。 它可以是像我们用于文本放置的坐标位置,但是让我们展示另一个例子。 它可以为轴域小数,所以我们使用 0.8 和 0.9。 这意味着文本的位置在x轴的80%和y轴的90%处。 这样,如果我们移动图表,文本将保持在相同位置。

    执行它,会生成:

    img

    根据你学习这个教程的时间,所指向的点可能有所不同,这只是一个注解的例子,其中有一些合理的想法,即为什么我们需要注解一些东西。

    当图表启动时,请尝试单击平移按钮(蓝色十字),然后移动图表。 你会看到文本保持不动,但箭头跟随移动并继续指向我们想要的具体的点。 这很酷吧!

    最后一个图表的完整代码:

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
    from matplotlib import style
     
    import numpy as np
    import urllib
    import datetime as dt
     
    style.use('fivethirtyeight')
    print(plt.style.available)
     
    print(plt.__file__)
     
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((1,1), (0,0))
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1m/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        x = 0
        y = len(date)
        ohlc = []
     
        while x < y:
            append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
            ohlc.append(append_me)
            x+=1
     
     
        candlestick_ohlc(ax1, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')
     
        for label in ax1.xaxis.get_ticklabels():
            label.set_rotation(45)
     
        ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax1.xaxis.set_major_locator(mticker.MaxNLocator(10))
        ax1.grid(True)
        ax1.annotate('Bad News!',(date[9],highp[9]),
                     xytext=(0.8, 0.9), textcoords='axes fraction',
                     arrowprops = dict(facecolor='grey',color='grey'))
     
    ##    # Text placement example:
    ##    font_dict = {'family':'serif',
    ##                 'color':'darkred',
    ##                 'size':15}
    ##    ax1.text(date[10], closep[1],'Text Example', fontdict=font_dict)
     
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.title(stock)
        #plt.legend()
        plt.subplots_adjust(left=0.09, bottom=0.20, right=0.94, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('ebay')
    

    现在,使用注解,我们可以做一些其他事情,如注解股票图表的最后价格。 这就是我们接下来要做的。

    第十八章 注解股票图表的最后价格

    在这个 Matplotlib 教程中,我们将展示如何跟踪股票的最后价格的示例,通过将其注解到轴域的右侧,就像许多图表应用程序会做的那样。

    虽然人们喜欢在他们的实时图表中看到历史价格,他们也想看到最新的价格。 大多数应用程序做的是,在价格的y轴高度处注释最后价格,然后突出显示它,并在价格变化时,在框中将其略微移动。 使用我们最近学习的注解教程,我们可以添加一个bbox

    我们的核心代码是:

    bbox_props = dict(boxstyle='round',fc='w', ec='k',lw=1)
     
    ax1.annotate(str(closep[-1]), (date[-1], closep[-1]),
                 xytext = (date[-1]+4, closep[-1]), bbox=bbox_props)
    

    我们使用ax1.annotate来放置最后价格的字符串值。 我们不在这里使用它,但我们将要注解的点指定为图上最后一个点。 接下来,我们使用xytext将我们的文本放置到特定位置。 我们将它的y坐标指定为最后一个点的y坐标,x坐标指定为最后一个点的x坐标,再加上几个点。我们这样做是为了将它移出图表。 将文本放在图形外面就足够了,但现在它只是一些浮动文本。

    我们使用bbox参数在文本周围创建一个框。 我们使用bbox_props创建一个属性字典,包含盒子样式,然后是白色(w)前景色,黑色(k)边框颜色并且线宽为 1。 更多框样式请参阅 matplotlib 注解文档

    最后,这个注解向右移动,需要我们使用subplots_adjust来创建一些新空间:

    plt.subplots_adjust(left=0.11, bottom=0.24, right=0.87, top=0.90, wspace=0.2, hspace=0)
    

    这里的完整代码如下:

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
    from matplotlib import style
     
    import numpy as np
    import urllib
    import datetime as dt
     
    style.use('fivethirtyeight')
    print(plt.style.available)
     
    print(plt.__file__)
     
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((1,1), (0,0))
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1m/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        x = 0
        y = len(date)
        ohlc = []
     
        while x < y:
            append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
            ohlc.append(append_me)
            x+=1
     
     
        candlestick_ohlc(ax1, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')
     
        for label in ax1.xaxis.get_ticklabels():
            label.set_rotation(45)
     
        ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax1.xaxis.set_major_locator(mticker.MaxNLocator(10))
        ax1.grid(True)
     
        bbox_props = dict(boxstyle='round',fc='w', ec='k',lw=1)
     
        ax1.annotate(str(closep[-1]), (date[-1], closep[-1]),
                     xytext = (date[-1]+3, closep[-1]), bbox=bbox_props)
     
     
    ##    # Annotation example with arrow
    ##    ax1.annotate('Bad News!',(date[11],highp[11]),
    ##                 xytext=(0.8, 0.9), textcoords='axes fraction',
    ##                 arrowprops = dict(facecolor='grey',color='grey'))
    ##
    ##    
    ##    # Font dict example
    ##    font_dict = {'family':'serif',
    ##                 'color':'darkred',
    ##                 'size':15}
    ##    # Hard coded text
    ##    ax1.text(date[10], closep[1],'Text Example', fontdict=font_dict)
     
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.title(stock)
        #plt.legend()
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.87, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('EBAY')
    

    结果为:

    img

    第十九章 子图

    在这个 Matplotlib 教程中,我们将讨论子图。 有两种处理子图的主要方法,用于在同一图上创建多个图表。 现在,我们将从一个干净的代码开始。 如果你一直关注这个教程,那么请确保保留旧的代码,或者你可以随时重新查看上一个教程的代码。

    首先,让我们使用样式,创建我们的图表,然后创建一个随机创建示例绘图的函数:

    import random
    import matplotlib.pyplot as plt
    from matplotlib import style
     
    style.use('fivethirtyeight')
     
    fig = plt.figure()
     
    def create_plots():
        xs = []
        ys = []
     
        for i in range(10):
            x = i
            y = random.randrange(10)
     
            xs.append(x)
            ys.append(y)
        return xs, ys
    

    现在,我们开始使用add_subplot方法创建子图:

    ax1 = fig.add_subplot(221)
    ax2 = fig.add_subplot(222)
    ax3 = fig.add_subplot(212)
    

    它的工作原理是使用 3 个数字,即:行数(numRows)、列数(numCols)和绘图编号(plotNum)。

    所以,221 表示两行两列的第一个位置。222 是两行两列的第二个位置。最后,212 是两行一列的第二个位置。

    2x2:
     
    +-----+-----+
    |  1  |  2  |
    +-----+-----+
    |  3  |  4  |
    +-----+-----+
     
    2x1:
     
    +-----------+
    |     1     |
    +-----------+
    |     2     |
    +-----------+
     
    译者注:<code>221</code>是缩写形式,仅在行数乘列数小于 10 时有效,否则要写成<code>2,2,1</code>。
    

    此代码结果为:

    img

    这就是add_subplot。 尝试一些你认为可能很有趣的配置,然后尝试使用add_subplot创建它们,直到你感到满意。

    接下来,让我们介绍另一种方法,它是subplot2grid

    删除或注释掉其他轴域定义,然后添加:

    ax1 = plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)
    ax2 = plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)
    ax3 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)
    

    所以,add_subplot不能让我们使一个绘图覆盖多个位置。 但是这个新的subplot2grid可以。 所以,subplot2grid的工作方式是首先传递一个元组,它是网格形状。 我们传递了(6,1),这意味着整个图表分为六行一列。 下一个元组是左上角的起始点。 对于ax1,这是0,0,因此它起始于顶部。 接下来,我们可以选择指定rowspancolspan。 这是轴域所占的行数和列数。

    6x1:
     
              colspan=1
    (0,0)   +-----------+
            |    ax1    | rowspan=1
    (1,0)   +-----------+
            |           |
            |    ax2    | rowspan=4
            |           |
            |           |
    (5,0)   +-----------+
            |    ax3    | rowspan=1
            +-----------+
    

    结果为:

    img

    显然,我们在这里有一些重叠的问题,我们可以调整子图来处理它。

    再次,尝试构思各种配置的子图,使用subplot2grid制作出来,直到你感到满意!

    我们将继续使用subplot2grid,将它应用到我们已经逐步建立的代码中,我们将在下一个教程中继续。

    第二十章 将子图应用于我们的图表

    在这个 Matplotlib 教程中,我们将处理我们以前教程的代码,并实现上一个教程中的子图配置。 我们的起始代码是这样:

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
    from matplotlib import style
     
    import numpy as np
    import urllib
    import datetime as dt
     
    style.use('fivethirtyeight')
    print(plt.style.available)
     
    print(plt.__file__)
     
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((1,1), (0,0))
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1m/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        x = 0
        y = len(date)
        ohlc = []
     
        while x < y:
            append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
            ohlc.append(append_me)
            x+=1
     
     
        candlestick_ohlc(ax1, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')
     
        for label in ax1.xaxis.get_ticklabels():
            label.set_rotation(45)
     
        ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax1.xaxis.set_major_locator(mticker.MaxNLocator(10))
        ax1.grid(True)
     
        bbox_props = dict(boxstyle='round',fc='w', ec='k',lw=1)
     
        ax1.annotate(str(closep[-1]), (date[-1], closep[-1]),
                     xytext = (date[-1]+4, closep[-1]), bbox=bbox_props)
     
     
    ##    # Annotation example with arrow
    ##    ax1.annotate('Bad News!',(date[11],highp[11]),
    ##                 xytext=(0.8, 0.9), textcoords='axes fraction',
    ##                 arrowprops = dict(facecolor='grey',color='grey'))
    ##
    ##    
    ##    # Font dict example
    ##    font_dict = {'family':'serif',
    ##                 'color':'darkred',
    ##                 'size':15}
    ##    # Hard coded text
    ##    ax1.text(date[10], closep[1],'Text Example', fontdict=font_dict)
     
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.title(stock)
        #plt.legend()
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('EBAY')
    

    一个主要的改动是修改轴域的定义:

    ax1 = plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)
    plt.title(stock)
    ax2 = plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)
    plt.xlabel('Date')
    plt.ylabel('Price')
    ax3 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)
    

    现在,ax2是我们实际上在绘制的股票价格数据。 顶部和底部图表将作为指标信息。

    在我们绘制数据的代码中,我们需要将ax1更改为ax2

    candlestick_ohlc(ax2, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')
     
    for label in ax2.xaxis.get_ticklabels():
        label.set_rotation(45)
     
    ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    ax2.xaxis.set_major_locator(mticker.MaxNLocator(10))
    ax2.grid(True)
     
    bbox_props = dict(boxstyle='round',fc='w', ec='k',lw=1)
     
    ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),
                 xytext = (date[-1]+4, closep[-1]), bbox=bbox_props)
    

    更改之后,代码为:

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
    from matplotlib import style
     
    import numpy as np
    import urllib
    import datetime as dt
     
    style.use('fivethirtyeight')
    print(plt.style.available)
     
    print(plt.__file__)
     
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)
        plt.title(stock)
        ax2 = plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)
        plt.xlabel('Date')
        plt.ylabel('Price')
        ax3 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)
     
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1m/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        x = 0
        y = len(date)
        ohlc = []
     
        while x < y:
            append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
            ohlc.append(append_me)
            x+=1
     
     
        candlestick_ohlc(ax2, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')
     
        for label in ax2.xaxis.get_ticklabels():
            label.set_rotation(45)
     
        ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax2.xaxis.set_major_locator(mticker.MaxNLocator(10))
        ax2.grid(True)
     
        bbox_props = dict(boxstyle='round',fc='w', ec='k',lw=1)
     
        ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),
                     xytext = (date[-1]+4, closep[-1]), bbox=bbox_props)
     
     
    ##    # Annotation example with arrow
    ##    ax1.annotate('Bad News!',(date[11],highp[11]),
    ##                 xytext=(0.8, 0.9), textcoords='axes fraction',
    ##                 arrowprops = dict(facecolor='grey',color='grey'))
    ##
    ##    
    ##    # Font dict example
    ##    font_dict = {'family':'serif',
    ##                 'color':'darkred',
    ##                 'size':15}
    ##    # Hard coded text
    ##    ax1.text(date[10], closep[1],'Text Example', fontdict=font_dict)
     
     
        #
        #plt.legend()
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('EBAY')
    

    结果为:

    img

    第二十一章 更多指标数据

    在这篇 Matplotlib 教程中,我们介绍了添加一些简单的函数来计算数据,以便我们填充我们的轴域。 一个是简单的移动均值,另一个是简单的价格 HML 计算。

    这些新函数是:

    def moving_average(values, window):
        weights = np.repeat(1.0, window)/window
        smas = np.convolve(values, weights, 'valid')
        return smas
     
    def high_minus_low(highs, lows):
        return highs-lows
    

    你不需要太过专注于理解移动均值的工作原理,我们只是对样本数据来计算它,以便可以学习更多自定义 Matplotlib 的东西。

    我们还想在脚本顶部为移动均值定义一些值:

    MA1 = 10
    MA2 = 30
    

    下面,在我们的graph_data函数中:

    ma1 = moving_average(closep,MA1)
    ma2 = moving_average(closep,MA2)
    start = len(date[MA2-1:])
     
    h_l = list(map(high_minus_low, highp, lowp))
    

    在这里,我们计算两个移动均值和 HML。

    我们还定义了一个『起始』点。 我们这样做是因为我们希望我们的数据排成一行。 例如,20 天的移动均值需要 20 个数据点。 这意味着我们不能在第 5 天真正计算 20 天的移动均值。 因此,当我们计算移动均值时,我们会失去一些数据。 为了处理这种数据的减法,我们使用起始变量来计算应该有多少数据。 这里,我们可以安全地使用[-start:]绘制移动均值,并且如果我们希望的话,对所有绘图进行上述步骤来排列数据。

    接下来,我们可以在ax1上绘制 HML,通过这样:

    ax1.plot_date(date,h_l,'-')
    

    最后我们可以通过这样向ax3添加移动均值:

    ax3.plot(date[-start:], ma1[-start:])
    ax3.plot(date[-start:], ma2[-start:])
    

    我们的完整代码,包括增加我们所用的时间范围:

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
    from matplotlib import style
     
    import numpy as np
    import urllib
    import datetime as dt
     
    style.use('fivethirtyeight')
    print(plt.style.available)
     
    print(plt.__file__)
     
    MA1 = 10
    MA2 = 30
     
    def moving_average(values, window):
        weights = np.repeat(1.0, window)/window
        smas = np.convolve(values, weights, 'valid')
        return smas
     
    def high_minus_low(highs, lows):
        return highs-lows
     
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)
        plt.title(stock)
        ax2 = plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)
        plt.xlabel('Date')
        plt.ylabel('Price')
        ax3 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)
     
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        x = 0
        y = len(date)
        ohlc = []
     
        while x < y:
            append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
            ohlc.append(append_me)
            x+=1
     
        ma1 = moving_average(closep,MA1)
        ma2 = moving_average(closep,MA2)
        start = len(date[MA2-1:])
     
        h_l = list(map(high_minus_low, highp, lowp))
     
        ax1.plot_date(date,h_l,'-')
     
     
        candlestick_ohlc(ax2, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')
     
        for label in ax2.xaxis.get_ticklabels():
            label.set_rotation(45)
     
        ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax2.xaxis.set_major_locator(mticker.MaxNLocator(10))
        ax2.grid(True)
     
        bbox_props = dict(boxstyle='round',fc='w', ec='k',lw=1)
     
        ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),
                     xytext = (date[-1]+4, closep[-1]), bbox=bbox_props)
     
     
    ##    # Annotation example with arrow
    ##    ax2.annotate('Bad News!',(date[11],highp[11]),
    ##                 xytext=(0.8, 0.9), textcoords='axes fraction',
    ##                 arrowprops = dict(facecolor='grey',color='grey'))
    ##
    ##    
    ##    # Font dict example
    ##    font_dict = {'family':'serif',
    ##                 'color':'darkred',
    ##                 'size':15}
    ##    # Hard coded text 
    ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)
     
     
     
        ax3.plot(date[-start:], ma1[-start:])
        ax3.plot(date[-start:], ma2[-start:])
     
     
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('EBAY')
    

    img

    第二十二章 自定义填充、修剪和清除

    欢迎阅读另一个 Matplotlib 教程! 在本教程中,我们将清除图表,然后再做一些自定义。

    我们当前的代码是:

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
    from matplotlib import style
     
    import numpy as np
    import urllib
    import datetime as dt
     
    style.use('fivethirtyeight')
    print(plt.style.available)
     
    print(plt.__file__)
     
    MA1 = 10
    MA2 = 30
     
    def moving_average(values, window):
        weights = np.repeat(1.0, window)/window
        smas = np.convolve(values, weights, 'valid')
        return smas
     
    def high_minus_low(highs, lows):
        return highs-lows
     
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)
        plt.title(stock)
        ax2 = plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)
        plt.xlabel('Date')
        plt.ylabel('Price')
        ax3 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)
     
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        x = 0
        y = len(date)
        ohlc = []
     
        while x < y:
            append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
            ohlc.append(append_me)
            x+=1
     
        ma1 = moving_average(closep,MA1)
        ma2 = moving_average(closep,MA2)
        start = len(date[MA2-1:])
     
        h_l = list(map(high_minus_low, highp, lowp))
     
        ax1.plot_date(date,h_l,'-')
     
     
        candlestick_ohlc(ax2, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')
     
        for label in ax2.xaxis.get_ticklabels():
            label.set_rotation(45)
     
        ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax2.xaxis.set_major_locator(mticker.MaxNLocator(10))
        ax2.grid(True)
     
        bbox_props = dict(boxstyle='round',fc='w', ec='k',lw=1)
     
        ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),
                     xytext = (date[-1]+4, closep[-1]), bbox=bbox_props)
     
     
    ##    # Annotation example with arrow
    ##    ax2.annotate('Bad News!',(date[11],highp[11]),
    ##                 xytext=(0.8, 0.9), textcoords='axes fraction',
    ##                 arrowprops = dict(facecolor='grey',color='grey'))
    ##
    ##    
    ##    # Font dict example
    ##    font_dict = {'family':'serif',
    ##                 'color':'darkred',
    ##                 'size':15}
    ##    # Hard coded text 
    ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)
     
     
     
        ax3.plot(date[-start:], ma1[-start:])
        ax3.plot(date[-start:], ma2[-start:])
     
     
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('EBAY')
    

    现在我认为向我们的移动均值添加自定义填充是一个很好的主意。 移动均值通常用于说明价格趋势。 这个想法是,你可以计算一个快速和一个慢速的移动均值。 一般来说,移动均值用于使价格变得『平滑』。 他们总是『滞后』于价格,但是我们的想法是计算不同的速度。 移动均值越大就越『慢』。 所以这个想法是,如果『较快』的移动均值超过『较慢』的均值,那么价格就会上升,这是一件好事。 如果较快的 MA 从较慢的 MA 下方穿过,则这是下降趋势并且通常被视为坏事。 我的想法是在快速和慢速 MA 之间填充,『上升』趋势为绿色,然后下降趋势为红色。 方法如下:

    ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],
                     where=(ma1[-start:] < ma2[-start:]),
                     facecolor='r', edgecolor='r', alpha=0.5)
     
    ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],
                     where=(ma1[-start:] > ma2[-start:]),
                     facecolor='g', edgecolor='g', alpha=0.5)
    

    下面,我们会碰到一些我们可解决的问题:

    ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    ax3.xaxis.set_major_locator(mticker.MaxNLocator(10))
     
    for label in ax3.xaxis.get_ticklabels():
        label.set_rotation(45)
     
    plt.setp(ax1.get_xticklabels(), visible=False)
    plt.setp(ax2.get_xticklabels(), visible=False)
    

    这里,我们剪切和粘贴ax2日期格式,然后我们将x刻度标签设置为false,去掉它们!

    我们还可以通过在轴域定义中执行以下操作,为每个轴域提供自定义标签:

    fig = plt.figure()
    ax1 = plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)
    plt.title(stock)
    ax2 = plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)
    plt.xlabel('Date')
    plt.ylabel('Price')
    ax3 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)
    

    接下来,我们可以看到,我们y刻度有许多数字,经常互相覆盖。 我们也看到轴之间互相重叠。 我们可以这样:

    ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=5, prune='lower'))
    

    所以,这里发生的是,我们通过首先将nbins设置为 5 来修改我们的y轴对象。这意味着我们显示的标签最多为 5 个。然后我们还可以『修剪』标签,因此,在我们这里, 我们修剪底部标签,这会使它消失,所以现在不会有任何文本重叠。 我们仍然可能打算修剪ax2的顶部标签,但这里是我们目前为止的源代码:

    当前的源码:

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
    from matplotlib import style
     
    import numpy as np
    import urllib
    import datetime as dt
     
    style.use('fivethirtyeight')
    print(plt.style.available)
     
    print(plt.__file__)
     
    MA1 = 10
    MA2 = 30
     
    def moving_average(values, window):
        weights = np.repeat(1.0, window)/window
        smas = np.convolve(values, weights, 'valid')
        return smas
     
    def high_minus_low(highs, lows):
        return highs-lows
     
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)
        plt.title(stock)
        plt.ylabel('H-L')
        ax2 = plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)
        plt.ylabel('Price')
        ax3 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)
        plt.ylabel('MAvgs')
     
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        x = 0
        y = len(date)
        ohlc = []
     
        while x < y:
            append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
            ohlc.append(append_me)
            x+=1
     
        ma1 = moving_average(closep,MA1)
        ma2 = moving_average(closep,MA2)
        start = len(date[MA2-1:])
     
        h_l = list(map(high_minus_low, highp, lowp))
     
     
        ax1.plot_date(date,h_l,'-')
        ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=5, prune='lower'))
     
     
        candlestick_ohlc(ax2, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')
     
     
     
     
        ax2.grid(True)
     
        bbox_props = dict(boxstyle='round',fc='w', ec='k',lw=1)
     
        ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),
                     xytext = (date[-1]+4, closep[-1]), bbox=bbox_props)
     
     
    ##    # Annotation example with arrow
    ##    ax2.annotate('Bad News!',(date[11],highp[11]),
    ##                 xytext=(0.8, 0.9), textcoords='axes fraction',
    ##                 arrowprops = dict(facecolor='grey',color='grey'))
    ##
    ##    
    ##    # Font dict example
    ##    font_dict = {'family':'serif',
    ##                 'color':'darkred',
    ##                 'size':15}
    ##    # Hard coded text 
    ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)
     
     
     
        ax3.plot(date[-start:], ma1[-start:], linewidth=1)
        ax3.plot(date[-start:], ma2[-start:], linewidth=1)
     
        ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],
                         where=(ma1[-start:] < ma2[-start:]),
                         facecolor='r', edgecolor='r', alpha=0.5)
     
        ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],
                         where=(ma1[-start:] > ma2[-start:]),
                         facecolor='g', edgecolor='g', alpha=0.5)
     
        ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax3.xaxis.set_major_locator(mticker.MaxNLocator(10))
     
        for label in ax3.xaxis.get_ticklabels():
            label.set_rotation(45)
     
        plt.setp(ax1.get_xticklabels(), visible=False)
        plt.setp(ax2.get_xticklabels(), visible=False)
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('EBAY')
    

    img

    看起来好了一些,但是仍然有一些东西需要清除。

    第二十三章 共享 X 轴

    在这个 Matplotlib 数据可视化教程中,我们将讨论sharex选项,它允许我们在图表之间共享x轴。将sharex看做『复制 x』也许更好。

    在我们开始之前,首先我们要做些修剪并在另一个轴上设置最大刻度数,如下所示:

    ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper'))
    

    以及

    ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper'))
    

    现在,让我们共享所有轴域之间的x轴。 为此,我们需要将其添加到轴域定义中:

    fig = plt.figure()
    ax1 = plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)
    plt.title(stock)
    plt.ylabel('H-L')
    ax2 = plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1)
    plt.ylabel('Price')
    ax3 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1)
    plt.ylabel('MAvgs')
    

    上面,对于ax2ax3,我们添加一个新的参数,称为sharex,然后我们说,我们要与ax1共享x轴。

    使用这种方式,我们可以加载图表,然后我们可以放大到一个特定的点,结果将是这样:

    img

    所以这意味着所有轴域沿着它们的x轴一起移动。 这很酷吧!

    接下来,让我们将[-start:]应用到所有数据,所以所有轴域都起始于相同地方。 我们最终的代码为:

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
    from matplotlib import style
     
    import numpy as np
    import urllib
    import datetime as dt
     
    style.use('fivethirtyeight')
    print(plt.style.available)
     
    print(plt.__file__)
     
    MA1 = 10
    MA2 = 30
     
    def moving_average(values, window):
        weights = np.repeat(1.0, window)/window
        smas = np.convolve(values, weights, 'valid')
        return smas
     
    def high_minus_low(highs, lows):
        return highs-lows
     
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)
        plt.title(stock)
        plt.ylabel('H-L')
        ax2 = plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1)
        plt.ylabel('Price')
        ax3 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1)
        plt.ylabel('MAvgs')
     
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        x = 0
        y = len(date)
        ohlc = []
     
        while x < y:
            append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
            ohlc.append(append_me)
            x+=1
     
        ma1 = moving_average(closep,MA1)
        ma2 = moving_average(closep,MA2)
        start = len(date[MA2-1:])
     
        h_l = list(map(high_minus_low, highp, lowp))
     
     
        ax1.plot_date(date[-start:],h_l[-start:],'-')
        ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='lower'))
     
     
        candlestick_ohlc(ax2, ohlc[-start:], width=0.4, colorup='#77d879', colordown='#db3f3f')
     
     
     
        ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper'))
        ax2.grid(True)
     
        bbox_props = dict(boxstyle='round',fc='w', ec='k',lw=1)
     
        ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),
                     xytext = (date[-1]+4, closep[-1]), bbox=bbox_props)
     
     
    ##    # Annotation example with arrow
    ##    ax2.annotate('Bad News!',(date[11],highp[11]),
    ##                 xytext=(0.8, 0.9), textcoords='axes fraction',
    ##                 arrowprops = dict(facecolor='grey',color='grey'))
    ##
    ##    
    ##    # Font dict example
    ##    font_dict = {'family':'serif',
    ##                 'color':'darkred',
    ##                 'size':15}
    ##    # Hard coded text 
    ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)
     
     
     
        ax3.plot(date[-start:], ma1[-start:], linewidth=1)
        ax3.plot(date[-start:], ma2[-start:], linewidth=1)
     
        ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],
                         where=(ma1[-start:] < ma2[-start:]),
                         facecolor='r', edgecolor='r', alpha=0.5)
     
        ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],
                         where=(ma1[-start:] > ma2[-start:]),
                         facecolor='g', edgecolor='g', alpha=0.5)
     
        ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax3.xaxis.set_major_locator(mticker.MaxNLocator(10))
        ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper'))
     
        for label in ax3.xaxis.get_ticklabels():
            label.set_rotation(45)
     
     
     
        plt.setp(ax1.get_xticklabels(), visible=False)
        plt.setp(ax2.get_xticklabels(), visible=False)
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('EBAY')
    

    下面我们会讨论如何创建多个y轴。

    第二十四章 多个 Y 轴

    在这篇 Matplotlib 教程中,我们将介绍如何在同一子图上使用多个 Y 轴。 在我们的例子中,我们有兴趣在同一个图表及同一个子图上绘制股票价格和交易量。

    为此,首先我们需要定义一个新的轴域,但是这个轴域是ax2仅带有x轴的『双生子』。

    这足以创建轴域了。我们叫它ax2v,因为这个轴域是ax2加交易量。

    现在,我们在轴域上定义绘图,我们将添加:

    ax2v.fill_between(date[-start:],0, volume[-start:], facecolor='#0079a3', alpha=0.4)
    

    我们在 0 和当前交易量之间填充,给予它蓝色的前景色,然后给予它一个透明度。 我们想要应用幽冥毒,以防交易量最终覆盖其它东西,所以我们仍然可以看到这两个元素。

    所以,到现在为止,我们的代码为:

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
    from matplotlib import style
     
    import numpy as np
    import urllib
    import datetime as dt
     
    style.use('fivethirtyeight')
    print(plt.style.available)
     
    print(plt.__file__)
     
    MA1 = 10
    MA2 = 30
     
    def moving_average(values, window):
        weights = np.repeat(1.0, window)/window
        smas = np.convolve(values, weights, 'valid')
        return smas
     
    def high_minus_low(highs, lows):
        return highs-lows
     
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)
        plt.title(stock)
        plt.ylabel('H-L')
        ax2 = plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1)
        plt.ylabel('Price')
        ax2v = ax2.twinx()
     
        ax3 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1)
        plt.ylabel('MAvgs')
     
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        x = 0
        y = len(date)
        ohlc = []
     
        while x < y:
            append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
            ohlc.append(append_me)
            x+=1
     
        ma1 = moving_average(closep,MA1)
        ma2 = moving_average(closep,MA2)
        start = len(date[MA2-1:])
     
        h_l = list(map(high_minus_low, highp, lowp))
     
     
        ax1.plot_date(date[-start:],h_l[-start:],'-')
        ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='lower'))
     
     
        candlestick_ohlc(ax2, ohlc[-start:], width=0.4, colorup='#77d879', colordown='#db3f3f')
     
     
     
        ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper'))
        ax2.grid(True)
     
        bbox_props = dict(boxstyle='round',fc='w', ec='k',lw=1)
     
        ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),
                     xytext = (date[-1]+4, closep[-1]), bbox=bbox_props)
     
    ##    # Annotation example with arrow
    ##    ax2.annotate('Bad News!',(date[11],highp[11]),
    ##                 xytext=(0.8, 0.9), textcoords='axes fraction',
    ##                 arrowprops = dict(facecolor='grey',color='grey'))
    ##
    ##    
    ##    # Font dict example
    ##    font_dict = {'family':'serif',
    ##                 'color':'darkred',
    ##                 'size':15}
    ##    # Hard coded text 
    ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)
        ax2v.fill_between(date[-start:],0, volume[-start:], facecolor='#0079a3', alpha=0.4)
     
     
        ax3.plot(date[-start:], ma1[-start:], linewidth=1)
        ax3.plot(date[-start:], ma2[-start:], linewidth=1)
     
        ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],
                         where=(ma1[-start:] < ma2[-start:]),
                         facecolor='r', edgecolor='r', alpha=0.5)
     
        ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],
                         where=(ma1[-start:] > ma2[-start:]),
                         facecolor='g', edgecolor='g', alpha=0.5)
     
        ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax3.xaxis.set_major_locator(mticker.MaxNLocator(10))
        ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper'))
     
        for label in ax3.xaxis.get_ticklabels():
            label.set_rotation(45)
     
     
     
        plt.setp(ax1.get_xticklabels(), visible=False)
        plt.setp(ax2.get_xticklabels(), visible=False)
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
    graph_data('GOOG')
    

    会生成:

    img

    太棒了,到目前为止还不错。 接下来,我们可能要删除新y轴上的标签,然后我们也可能不想让交易量占用太多空间。 没问题:

    首先:

    ax2v.axes.yaxis.set_ticklabels([])
    

    上面将y刻度标签设置为一个空列表,所以不会有任何标签了。

    译者注:所以将标签删除之后,添加新轴的意义是什么?直接在原轴域上绘图就可以了。

    接下来,我们可能要将网格设置为false,使轴域上不会有双网格:

    ax2v.grid(False)
    

    最后,为了处理交易量占用很多空间,我们可以做以下操作:

    ax2v.set_ylim(0, 3*volume.max())
    

    所以这设置y轴显示范围从 0 到交易量的最大值的 3 倍。 这意味着,在最高点,交易量最多可占据图形的33%。 所以,增加volume.max的倍数越多,空间就越小/越少。

    现在,我们的图表为:

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
    from matplotlib import style
     
    import numpy as np
    import urllib
    import datetime as dt
     
    style.use('fivethirtyeight')
    print(plt.style.available)
     
    print(plt.__file__)
     
    MA1 = 10
    MA2 = 30
     
    def moving_average(values, window):
        weights = np.repeat(1.0, window)/window
        smas = np.convolve(values, weights, 'valid')
        return smas
     
    def high_minus_low(highs, lows):
        return highs-lows
     
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure()
        ax1 = plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)
        plt.title(stock)
        plt.ylabel('H-L')
        ax2 = plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1)
        plt.ylabel('Price')
        ax2v = ax2.twinx()
     
        ax3 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1)
        plt.ylabel('MAvgs')
     
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        x = 0
        y = len(date)
        ohlc = []
     
        while x < y:
            append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
            ohlc.append(append_me)
            x+=1
     
        ma1 = moving_average(closep,MA1)
        ma2 = moving_average(closep,MA2)
        start = len(date[MA2-1:])
     
        h_l = list(map(high_minus_low, highp, lowp))
     
     
        ax1.plot_date(date[-start:],h_l[-start:],'-')
        ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='lower'))
     
     
        candlestick_ohlc(ax2, ohlc[-start:], width=0.4, colorup='#77d879', colordown='#db3f3f')
     
     
     
        ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper'))
        ax2.grid(True)
     
        bbox_props = dict(boxstyle='round',fc='w', ec='k',lw=1)
     
        ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),
                     xytext = (date[-1]+5, closep[-1]), bbox=bbox_props)
     
    ##    # Annotation example with arrow
    ##    ax2.annotate('Bad News!',(date[11],highp[11]),
    ##                 xytext=(0.8, 0.9), textcoords='axes fraction',
    ##                 arrowprops = dict(facecolor='grey',color='grey'))
    ##
    ##    
    ##    # Font dict example
    ##    font_dict = {'family':'serif',
    ##                 'color':'darkred',
    ##                 'size':15}
    ##    # Hard coded text 
    ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)
     
        ax2v.fill_between(date[-start:],0, volume[-start:], facecolor='#0079a3', alpha=0.4)
        ax2v.axes.yaxis.set_ticklabels([])
        ax2v.grid(False)
        ax2v.set_ylim(0, 3*volume.max())
     
     
     
        ax3.plot(date[-start:], ma1[-start:], linewidth=1)
        ax3.plot(date[-start:], ma2[-start:], linewidth=1)
     
        ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],
                         where=(ma1[-start:] < ma2[-start:]),
                         facecolor='r', edgecolor='r', alpha=0.5)
     
        ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],
                         where=(ma1[-start:] > ma2[-start:]),
                         facecolor='g', edgecolor='g', alpha=0.5)
     
        ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax3.xaxis.set_major_locator(mticker.MaxNLocator(10))
        ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper'))
     
        for label in ax3.xaxis.get_ticklabels():
            label.set_rotation(45)
     
     
     
        plt.setp(ax1.get_xticklabels(), visible=False)
        plt.setp(ax2.get_xticklabels(), visible=False)
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)
        plt.show()
     
     
     
    graph_data('GOOG')
    

    到这里,我们差不多完成了。 这里唯一的缺陷是一个好的图例。 一些线条是显而易见的,但人们可能会好奇移动均值的参数是什么,我们这里是 10 和 30。 添加自定义图例是下一个教程中涉及的内容。

    第二十五章 自定义图例

    在这篇 Matplotlib 教程中,我们将讨论自定义图例。 我们已经介绍了添加图例的基础知识

    图例的主要问题通常是图例阻碍了数据的展示。 这里有几个选项。 一个选项是将图例放在轴域外,但是我们在这里有多个子图,这是非常困难的。 相反,我们将使图例稍微小一点,然后应用一个透明度。

    首先,为了创建一个图例,我们需要向我们的数据添加我们想要显示在图例上的标签。

    ax1.plot_date(date[-start:],h_l[-start:],'-', label='H-L')
    ...
    ax2v.plot([],[], color='#0079a3', alpha=0.4, label='Volume')
    ...
    ax3.plot(date[-start:], ma1[-start:], linewidth=1, label=(str(MA1)+'MA'))
    ax3.plot(date[-start:], ma2[-start:], linewidth=1, label=(str(MA2)+'MA'))
    

    请注意,我们通过创建空行为交易量添加了标签。 请记住,我们不能对任何填充应用标签,所以这就是我们添加这个空行的原因。

    现在,我们可以在右下角添加图例,通过在plt.show()之前执行以下操作:

    ax1.legend()
    ax2v.legend()
    ax3.legend()
    

    会生成:

    img

    所以,我们可以看到,图例还是占用了一些位置。 让我们更改位置,大小并添加透明度:

    ax1.legend()
    leg = ax1.legend(loc=9, ncol=2,prop={'size':11})
    leg.get_frame().set_alpha(0.4)
     
    ax2v.legend()
    leg = ax2v.legend(loc=9, ncol=2,prop={'size':11})
    leg.get_frame().set_alpha(0.4)
     
    ax3.legend()
    leg = ax3.legend(loc=9, ncol=2,prop={'size':11})
    leg.get_frame().set_alpha(0.4)
    

    所有的图例位于位置 9(上中间)。 有很多地方可放置图例,我们可以为参数传入不同的位置号码,来看看它们都位于哪里。 ncol参数允许我们指定图例中的列数。 这里只有一列,如果图例中有 2 个项目,他们将堆叠在一列中。 最后,我们将尺寸规定为更小。 之后,我们对整个图例应用0.4的透明度。

    现在我们的结果为:

    img

    完整的代码为:

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.ticker as mticker
    from matplotlib.finance import candlestick_ohlc
    from matplotlib import style
     
    import numpy as np
    import urllib
    import datetime as dt
     
    style.use('fivethirtyeight')
    print(plt.style.available)
     
    print(plt.__file__)
     
    MA1 = 10
    MA2 = 30
     
    def moving_average(values, window):
        weights = np.repeat(1.0, window)/window
        smas = np.convolve(values, weights, 'valid')
        return smas
     
    def high_minus_low(highs, lows):
        return highs-lows
     
     
    def bytespdate2num(fmt, encoding='utf-8'):
        strconverter = mdates.strpdate2num(fmt)
        def bytesconverter(b):
            s = b.decode(encoding)
            return strconverter(s)
        return bytesconverter
     
     
    def graph_data(stock):
     
        fig = plt.figure(facecolor='#f0f0f0')
        ax1 = plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)
        plt.title(stock)
        plt.ylabel('H-L')
        ax2 = plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1)
        plt.ylabel('Price')
        ax2v = ax2.twinx()
     
        ax3 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1)
        plt.ylabel('MAvgs')
     
     
        stock_price_url = 'http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'
        source_code = urllib.request.urlopen(stock_price_url).read().decode()
        stock_data = []
        split_source = source_code.split('
    ')
        for line in split_source:
            split_line = line.split(',')
            if len(split_line) == 6:
                if 'values' not in line and 'labels' not in line:
                    stock_data.append(line)
     
     
        date, closep, highp, lowp, openp, volume = np.loadtxt(stock_data,
                                                              delimiter=',',
                                                              unpack=True,
                                                              converters={0: bytespdate2num('%Y%m%d')})
     
        x = 0
        y = len(date)
        ohlc = []
     
        while x < y:
            append_me = date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]
            ohlc.append(append_me)
            x+=1
     
        ma1 = moving_average(closep,MA1)
        ma2 = moving_average(closep,MA2)
        start = len(date[MA2-1:])
     
        h_l = list(map(high_minus_low, highp, lowp))
     
     
        ax1.plot_date(date[-start:],h_l[-start:],'-', label='H-L')
        ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='lower'))
     
     
        candlestick_ohlc(ax2, ohlc[-start:], width=0.4, colorup='#77d879', colordown='#db3f3f')
     
     
     
        ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper'))
        ax2.grid(True)
     
        bbox_props = dict(boxstyle='round',fc='w', ec='k',lw=1)
     
        ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),
                     xytext = (date[-1]+4, closep[-1]), bbox=bbox_props)
     
    ##    # Annotation example with arrow
    ##    ax2.annotate('Bad News!',(date[11],highp[11]),
    ##                 xytext=(0.8, 0.9), textcoords='axes fraction',
    ##                 arrowprops = dict(facecolor='grey',color='grey'))
    ##
    ##    
    ##    # Font dict example
    ##    font_dict = {'family':'serif',
    ##                 'color':'darkred',
    ##                 'size':15}
    ##    # Hard coded text 
    ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)
     
        ax2v.plot([],[], color='#0079a3', alpha=0.4, label='Volume')
        ax2v.fill_between(date[-start:],0, volume[-start:], facecolor='#0079a3', alpha=0.4)
        ax2v.axes.yaxis.set_ticklabels([])
        ax2v.grid(False)
        ax2v.set_ylim(0, 3*volume.max())
     
     
     
        ax3.plot(date[-start:], ma1[-start:], linewidth=1, label=(str(MA1)+'MA'))
        ax3.plot(date[-start:], ma2[-start:], linewidth=1, label=(str(MA2)+'MA'))
     
        ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],
                         where=(ma1[-start:] < ma2[-start:]),
                         facecolor='r', edgecolor='r', alpha=0.5)
     
        ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],
                         where=(ma1[-start:] > ma2[-start:]),
                         facecolor='g', edgecolor='g', alpha=0.5)
     
        ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax3.xaxis.set_major_locator(mticker.MaxNLocator(10))
        ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper'))
     
        for label in ax3.xaxis.get_ticklabels():
            label.set_rotation(45)
     
     
     
        plt.setp(ax1.get_xticklabels(), visible=False)
        plt.setp(ax2.get_xticklabels(), visible=False)
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)
     
        ax1.legend()
        leg = ax1.legend(loc=9, ncol=2,prop={'size':11})
        leg.get_frame().set_alpha(0.4)
     
        ax2v.legend()
        leg = ax2v.legend(loc=9, ncol=2,prop={'size':11})
        leg.get_frame().set_alpha(0.4)
     
        ax3.legend()
        leg = ax3.legend(loc=9, ncol=2,prop={'size':11})
        leg.get_frame().set_alpha(0.4)
     
        plt.show()
        fig.savefig('google.png', facecolor=fig.get_facecolor())
     
     
    graph_data('GOOG')
    

    现在我们可以看到图例,但也看到了图例下的任何信息。 还要注意额外函数fig.savefig。 这是自动保存图形的图像的方式。 我们还可以设置所保存的图形的前景色,使背景不是白色的,如我们的例子所示。

    这就是目前为止,我想要显示的典型 Matplotlib 图表。 接下来,我们将涉及Basemap,它是一个 Matplotlib 扩展,用于绘制地理位置,然后我打算讲解 Matplotlib 中的 3D 图形。

    第二十六章 Basemap 地理绘图

    在这个 Matplotlib 教程中,我们将涉及地理绘图模块BasemapBasemap是 Matplotlib 的扩展。

    为了使用Basemap,我们首先需要安装它。 为了获得Basemap,你可以从这里获取:http://matplotlib.org/basemap/users/download.html,或者你可以访问http://www.lfd.uci.edu/~gohlke/pythonlibs/。

    如果你在安装Basemap时遇到问题,请查看pip安装教程

    一旦你安装了Basemap,你就可以创建地图了。 首先,让我们投影一个简单的地图。 为此,我们需要导入Basemappyplot,创建投影,至少绘制某种轮廓或数据,然后我们可以显示图形。

    from mpl_toolkits.basemap import Basemap
    import matplotlib.pyplot as plt
     
    m = Basemap(projection='mill')
    m.drawcoastlines()
    plt.show()
    

    上面的代码结果如下:

    img

    这是使用 Miller 投影完成的,这只是许多Basemap投影选项之一。

    第二十七章 Basemap 自定义

    在这篇 Matplotlib 教程中,我们继续使用Basemap地理绘图扩展。 我们将展示一些我们可用的自定义选项。

    首先,从上一个教程中获取我们的起始代码:

    from mpl_toolkits.basemap import Basemap
    import matplotlib.pyplot as plt
     
    m = Basemap(projection='mill')
    m.drawcoastlines()
    plt.show()
    

    我们可以从放大到特定区域来开始:

    from mpl_toolkits.basemap import Basemap
    import matplotlib.pyplot as plt
     
    m = Basemap(projection='mill',
                llcrnrlat = -40,
                llcrnrlon = -40,
                urcrnrlat = 50,
                urcrnrlon = 75)
    m.drawcoastlines()
    plt.show()
    

    这里的参数是:

    • llcrnrlat – 左下角的纬度
    • llcrnrlon – 左下角的经度
    • urcrnrlat – 右上角的纬度
    • urcrnrlon – 右上角的经度

    此外,坐标需要转换,其中西经和南纬坐标是负值,北纬和东经坐标是正值。

    使用这些坐标,Basemap会选择它们之间的区域。

    img

    下面,我们要使用一些东西,类似:

    m.drawcountries(linewidth=2)
    

    这会画出国家,并使用线宽为 2 的线条生成分界线。

    另一个选项是:

    m.drawstates(color='b')
    

    这会用蓝色线条画出州。

    你也可以执行:

    m.drawcounties(color='darkred')
    

    这会画出国家。

    所以,我们的代码是:

    from mpl_toolkits.basemap import Basemap
    import matplotlib.pyplot as plt
     
    m = Basemap(projection='mill',
                llcrnrlat = -90,
                llcrnrlon = -180,
                urcrnrlat = 90,
                urcrnrlon = 180)
     
    m.drawcoastlines()
    m.drawcountries(linewidth=2)
    m.drawstates(color='b')
    m.drawcounties(color='darkred')
     
    plt.title('Basemap Tutorial')
    plt.show()
    

    img

    很难说,但我们定义了美国的区县的线条。 我们可以使用放大镜放大Basemap图形,就像其他图形那样,会生成:

    img

    另一个有用的选项是Basemap调用中的『分辨率』选项。

    m = Basemap(projection='mill',
                llcrnrlat = -90,
                llcrnrlon = -180,
                urcrnrlat = 90,
                urcrnrlon = 180,
                resolution='l')
    

    分辨率的选项为:

    • c – 粗糙
    • l – 低
    • h – 高
    • f – 完整

    对于更高的分辨率,你应该放大到很大,否则这可能只是浪费。

    另一个选项是使用etopo()绘制地形,如:

    m.etopo()
    

    使用drawcountries方法绘制此图形会生成:

    img

    最后,有一个蓝色的大理石版本,你可以调用:

    m.bluemarble()
    

    会生成:

    img

    目前为止的代码:

    from mpl_toolkits.basemap import Basemap
    import matplotlib.pyplot as plt
     
    m = Basemap(projection='mill',
                llcrnrlat = -90,
                llcrnrlon = -180,
                urcrnrlat = 90,
                urcrnrlon = 180,
                resolution='l')
     
    m.drawcoastlines()
    m.drawcountries(linewidth=2)
    ##m.drawstates(color='b')
    ##m.drawcounties(color='darkred')
    #m.fillcontinents()
    #m.etopo()
    m.bluemarble()
     
    plt.title('Basemap Tutorial')
    plt.show()
    ​````
     
     
     
     
     
    <div class="se-preview-section-delimiter"></div>
     
    # 第二十八章 在 Basemap 中绘制坐标
     
    欢迎阅读另一个 Matplotlib Basemap 教程。 在本教程中,我们将介绍如何绘制单个坐标,以及如何在地理区域中连接这些坐标。
     
    首先,我们将从一些基本的起始数据开始:
     
     
     
     
     
    <div class="se-preview-section-delimiter"></div>
     
    ​```py
    from mpl_toolkits.basemap import Basemap
    import matplotlib.pyplot as plt
     
    m = Basemap(projection='mill',
                llcrnrlat = 25,
                llcrnrlon = -130,
                urcrnrlat = 50,
                urcrnrlon = -60,
                resolution='l')
     
    m.drawcoastlines()
    m.drawcountries(linewidth=2)
    m.drawstates(color='b')
    

    接下来,我们可以绘制坐标,从获得它们的实际坐标开始。 记住,南纬和西经坐标需要转换为负值。 例如,纽约市是北纬40.7127西经74.0059。 我们可以在我们的程序中定义这些坐标,如:

    NYClat, NYClon = 40.7127, -74.0059
    

    之后我们将这些转换为要绘制的xy坐标。

    xpt, ypt = m(NYClon, NYClat)
    

    注意这里,我们现在已经将坐标顺序翻转为lon, lat(纬度,经度)。 坐标通常以lat, lon顺序给出。 然而,在图形中,lat, long转换为y, x,我们显然不需要。 在某些时候,你必须翻转它们。 不要忘记这部分!

    最后,我们可以绘制如下的坐标:

    m.plot(xpt, ypt, 'c*', markersize=15)
    

    这个图表上有一个青色的星,大小为 15。更多标记类型请参阅:Matplotlib 标记文档

    接下来,让我们再画一个位置,洛杉矶,加利福尼亚:

    LAlat, LAlon = 34.05, -118.25
    xpt, ypt = m(LAlon, LAlat)
    m.plot(xpt, ypt, 'g^', markersize=15)
    

    这次我们画出一个绿色三角,执行代码会生成:

    img

    如果我们想连接这些图块怎么办?原来,我们可以像其它 Matplotlib 图表那样实现它。

    首先,我们将那些xptypt坐标保存到列表,类似这样的东西:

    xs = []
    ys = []
     
    NYClat, NYClon = 40.7127, -74.0059
    xpt, ypt = m(NYClon, NYClat)
    xs.append(xpt)
    ys.append(ypt)
    m.plot(xpt, ypt, 'c*', markersize=15)
     
    LAlat, LAlon = 34.05, -118.25
    xpt, ypt = m(LAlon, LAlat)
    xs.append(xpt)
    ys.append(ypt)
    m.plot(xpt, ypt, 'g^', markersize=15)
     
    m.plot(xs, ys, color='r', linewidth=3, label='Flight 98')
    

    会生成:

    img

    太棒了。有时我们需要以圆弧连接图上的两个坐标。如何实现呢?

    m.drawgreatcircle(NYClon, NYClat, LAlon, LAlat, color='c', linewidth=3, label='Arc')
    

    我们的完整代码为:

    from mpl_toolkits.basemap import Basemap
    import matplotlib.pyplot as plt
     
    m = Basemap(projection='mill',
                llcrnrlat = 25,
                llcrnrlon = -130,
                urcrnrlat = 50,
                urcrnrlon = -60,
                resolution='l')
     
    m.drawcoastlines()
    m.drawcountries(linewidth=2)
    m.drawstates(color='b')
    #m.drawcounties(color='darkred')
    #m.fillcontinents()
    #m.etopo()
    #m.bluemarble()
     
    xs = []
    ys = []
     
    NYClat, NYClon = 40.7127, -74.0059
    xpt, ypt = m(NYClon, NYClat)
    xs.append(xpt)
    ys.append(ypt)
    m.plot(xpt, ypt, 'c*', markersize=15)
     
    LAlat, LAlon = 34.05, -118.25
    xpt, ypt = m(LAlon, LAlat)
    xs.append(xpt)
    ys.append(ypt)
    m.plot(xpt, ypt, 'g^', markersize=15)
     
    m.plot(xs, ys, color='r', linewidth=3, label='Flight 98')
    m.drawgreatcircle(NYClon, NYClat, LAlon, LAlat, color='c', linewidth=3, label='Arc')
     
     
    plt.legend(loc=4)
    plt.title('Basemap Tutorial')
    plt.show()
    

    结果为:

    img

    这就是Basemap的全部了,下一章关于 Matplotlib 的 3D 绘图。

    第二十九章 3D 绘图

    您好,欢迎阅读 Matplotlib 教程中的 3D 绘图。 Matplotlib 已经内置了三维图形,所以我们不需要再下载任何东西。 首先,我们需要引入一些完整的模块:

    from mpl_toolkits.mplot3d import axes3d
    import matplotlib.pyplot as plt
    

    使用axes3d是因为它需要不同种类的轴域,以便在三维中实际绘制一些东西。 下面:

    fig = plt.figure()
    ax1 = fig.add_subplot(111, projection='3d')
    

    在这里,我们像通常一样定义图形,然后我们将ax1定义为通常的子图,只是这次使用 3D 投影。 我们需要这样做,以便提醒 Matplotlib 我们要提供三维数据。

    现在让我们创建一些 3D 数据:

    x = [1,2,3,4,5,6,7,8,9,10]
    y = [5,6,7,8,2,5,6,3,7,2]
    z = [1,2,6,3,2,7,3,3,7,2]
    

    接下来,我们绘制它。 首先,让我们展示一个简单的线框示例:

    ax1.plot_wireframe(x,y,z)
    

    最后:

    ax1.set_xlabel('x axis')
    ax1.set_ylabel('y axis')
    ax1.set_zlabel('z axis')
     
    plt.show()
    

    我们完整的代码是:

    from mpl_toolkits.mplot3d import axes3d
    import matplotlib.pyplot as plt
    from matplotlib import style
     
    style.use('fivethirtyeight')
     
    fig = plt.figure()
    ax1 = fig.add_subplot(111, projection='3d')
     
    x = [1,2,3,4,5,6,7,8,9,10]
    y = [5,6,7,8,2,5,6,3,7,2]
    z = [1,2,6,3,2,7,3,3,7,2]
     
    ax1.plot_wireframe(x,y,z)
     
    ax1.set_xlabel('x axis')
    ax1.set_ylabel('y axis')
    ax1.set_zlabel('z axis')
     
    plt.show()
    

    结果为(包括所用的样式):

    img

    这些 3D 图形可以进行交互。 首先,您可以使用鼠标左键单击并拖动来移动图形。 您还可以使用鼠标右键单击并拖动来放大或缩小。

    第三十章 3D 散点图

    欢迎阅读另一个 3D Matplotlib 教程,会涉及如何绘制三维散点图。

    绘制 3D 散点图非常类似于通常的散点图以及 3D 线框图。

    一个简单示例:

    from mpl_toolkits.mplot3d import axes3d
    import matplotlib.pyplot as plt
    from matplotlib import style
     
    style.use('ggplot')
     
    fig = plt.figure()
    ax1 = fig.add_subplot(111, projection='3d')
     
    x = [1,2,3,4,5,6,7,8,9,10]
    y = [5,6,7,8,2,5,6,3,7,2]
    z = [1,2,6,3,2,7,3,3,7,2]
     
    x2 = [-1,-2,-3,-4,-5,-6,-7,-8,-9,-10]
    y2 = [-5,-6,-7,-8,-2,-5,-6,-3,-7,-2]
    z2 = [1,2,6,3,2,7,3,3,7,2]
     
    ax1.scatter(x, y, z, c='g', marker='o')
    ax1.scatter(x2, y2, z2, c ='r', marker='o')
     
    ax1.set_xlabel('x axis')
    ax1.set_ylabel('y axis')
    ax1.set_zlabel('z axis')
     
    plt.show()
    

    结果为:

    img

    要记住你可以修改这些绘图的大小和标记,就像通常的散点图那样。

    第三十一章 3D 条形图

    在这个 Matplotlib 教程中,我们要介绍 3D 条形图。 3D 条形图是非常独特的,因为它允许我们绘制多于 3 个维度。 不,你不能超过第三个维度来绘制,但你可以绘制多于 3 个维度。

    对于条形图,你需要拥有条形的起点,条形的高度和宽度。 但对于 3D 条形图,你还有另一个选项,就是条形的深度。 大多数情况下,条形图从轴上的条形平面开始,但是你也可以通过打破此约束来添加另一个维度。 然而,我们会让它非常简单:

    from mpl_toolkits.mplot3d import axes3d
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib import style
    style.use('ggplot')
     
    fig = plt.figure()
    ax1 = fig.add_subplot(111, projection='3d')
     
    x3 = [1,2,3,4,5,6,7,8,9,10]
    y3 = [5,6,7,8,2,5,6,3,7,2]
    z3 = np.zeros(10)
     
    dx = np.ones(10)
    dy = np.ones(10)
    dz = [1,2,3,4,5,6,7,8,9,10]
     
    ax1.bar3d(x3, y3, z3, dx, dy, dz)
     
     
    ax1.set_xlabel('x axis')
    ax1.set_ylabel('y axis')
    ax1.set_zlabel('z axis')
     
    plt.show()
    

    注意这里,我们必须定义xyz,然后是 3 个维度的宽度、高度和深度。 这会生成:

    img

    第三十二章 总结

    欢迎阅读最后的 Matplotlib 教程。 在这里我们将整理整个系列,并显示一个稍微更复杂的 3D 线框图:

    from mpl_toolkits.mplot3d import axes3d
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib import style
    style.use('ggplot')
     
    fig = plt.figure()
    ax1 = fig.add_subplot(111, projection='3d')
     
    x, y, z = axes3d.get_test_data()
     
    print(axes3d.__file__)
    ax1.plot_wireframe(x,y,z, rstride = 3, cstride = 3)
     
    ax1.set_xlabel('x axis')
    ax1.set_ylabel('y axis')
    ax1.set_zlabel('z axis')
     
    plt.show()
    

    img

    如果你从一开始就关注这个教程的话,那么你已经学会了 Matplotlib 提供的大部分内容。 你可能不相信,但Matplotlib 仍然可以做很多其他的事情! 请继续学习,你可以随时访问 Matplotlib.org,并查看示例和图库页面。

    如果你发现自己大量使用 Matplotlib,请考虑捐助给 John Hunter Memorial 基金

    空间曲面的画法

    # 二次抛物面 z = x^2 + y^2
     
    x = np.linspace(-10, 10, 101)
    y = x
    x, y = np.meshgrid(x, y)
    z = x ** 2 + y ** 2
    ax = plot.subplot(111, projection='3d')
    ax.plot_wireframe(x, y, z)
    plot.show()
    

    img

    # 半径为 1 的球
     
    t = np.linspace(0, np.pi * 2, 100)
    s = np.linspace(0, np.pi, 100)
    t, s = np.meshgrid(t, s)
    x = np.cos(t) * np.sin(s)
    y = np.sin(t) * np.sin(s)
    z = np.cos(s)
    ax = plot.subplot(111, projection='3d')
    ax.plot_wireframe(x, y, z)
    plot.show()
    

    img

  • 相关阅读:
    monkeyrunner之夜神模拟器的安装与使用(二)
    monkeyrunner之安卓开发环境搭建(一)
    MySQL 返回未包含在group by中的列
    MySQL数据库初体验
    MongoDB安装
    关于数据库你必须知道的事~
    PostgreSQL中的MVCC 事务隔离
    深入浅出MySQL之索引为什么要下推?
    Java集合篇:Map集合的几种遍历方式及性能测试
    Oracle11g:数据库恢复总结
  • 原文地址:https://www.cnblogs.com/remixnameless/p/13215679.html
Copyright © 2011-2022 走看看