【转】matplotlib:使用matplotlib进行python绘图(指南)

在python中画图肯定少不了matplotlib库,几行代码就可以将画出基本图来,但是大部分时候还是要去查阅文档或者去stackover flow查找,因为这个库真的太庞大了。最近刚好看到一篇文章,对matplotlib的一些概念解释的很好,也有对应的例子,所以根据文章做个简化版翻译,更详细的内容可以查看原文。

matplotlib对象的层次结果

使用过matplotlib,应该都用过这行代码plt.plot([1,2,3])。这一行隐藏了一个事实:一个plot实际上是嵌套python对象的层次结果,也就是说每个图下面都有一个matplotlib对象的树形结构。

如下图所示,一个Figure对象是matplotlib图像最外层的一个容器,可以包含一个或多个Axes对象(实际绘图的盒状容器)。容易引起混淆的一个原因:Axes实际上是一个独立的plot或者graph,而不是我们所认为的axis/轴的复数。Axes对象之下就是图表的“元素”,例如:刻度线、线条、图例、文本框、刻度、标签等,都是Axes对象可操作的python对象。

文中代码都是在jupyter notebook中运行的,最后的注释内容是输出结果,因为懒得截图了。。。

通过例子来重新认识matplotlib吧:

1
2
3
4
5
6
7
8
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

fig, _ = plt.subplots()
type(fig)

# matplotlib.figure.Figure

例子中从plt.subplots()中创建了两个变量,一个fig是顶级的Figure对象,另一个是暂时不需要用到的“一次性”变量,用下划线表示。使用属性表示法,可以轻松遍历层级结构并查看第一个Axes对象的y轴的第一个刻度

1
2
3
4
one_tick = fig.axes[0].yaxis.get_major_ticks()[0]
one_tick

# <matplotlib.axis.YTick at 0x7f0b569acac8>

例子中,fig(一个Figure类实例)有多个Axes(一个列表,这里取第一个)。每个Axes都有一个yaxis和xaxis,每个都有一个“major ticks”的集合,例子中取第一个。

各个部分可以参考下图:

上图的代码实现

stateful和stateless

在进一步了解matplotlib之前,需要明白stateful(基于状态,状态机)和无状态(面向对象,OO)接口之间的区别。

几乎所有来自pyplot的函数,如plt.plot(),要么隐式地指向现有的Figure对象和Axes对象,要么不存在时直接创建新的Figure和Axes对象。matplotlib文档是这样说的:

【使用pyplot】简单的函数是用来为当前Figure对象的Axes图像添加元素的(线条,图片,文本等)【强调添加】

“plt.plot()是一个隐式跟踪当前数字的状态机界面!”在英语中,这意味着:

  1. 有状态接口使用plt.plot()和其他顶级pyplot函数进行调用。在给定的时间内只有一个图或轴可以操作,不需要明确地引用它。
  2. 直接修改底层对象是面向对象的方法。我们通常通过调用Axes对象的方法来做到这一点,Axes对象是表示绘图本身的对象。

整个过程的流程就像下图所示:

用代码来表示上述流程:

1
2
3
4
5
6
7
8
# plt.plot()获取当前轴的一个流程表示,缩简版
def plot(*args, **kwargs):
ax = plt.gca()
return ax.plot(*args, **kwargs)

def gca(**kwargs):
"""获取当前Figure中的Axes对象"""
return plt.gcf().gca(**kwargs)

调用plt.plot()只是获取当前Figure的当前Axes对象,这也就是有状态接口总是“隐式跟踪”它想要引用的图。

pyplot是一系列函数的集合,这些函数实际上只是matplotlib面向对象接口的包装器。例如,plt.title()在面对对象方法中有相应的setter和getter方法:ax.set_title()和ax.get_title()。调用plt.title()等同于gca().set_title(s, *args, **kwargs),这行代码包含了以下操作:

  1. gca()获取并返回当前Axes
  2. set_title()是一个setter方法,用于设置该Axes对象的标题。这里我们就不需要明确指定任何对象的plt.title()

类似地,其他顶级函数,如plt.grid()plt.legend()plt.ylabels(),都是跟plt.title()一样的操作,先是获取当前Axes,gca()调用当前Axes的一些方法。

理解plt.subplots()

接下来主要是依赖于无状态/面向对象方法,这种方法可定制性更高,图形变得更加复杂更派得上用场。在面向对象使用单个Axes创建对象规定是使用plt.subplots()。这是面向对象使用pyplot创建Figure和Axes的唯一时间。

1
2
3
4
5
# 没有传递任何参数,默认是调用subplots(nrows=1, ncols=1),所以返回一个Figure和一个Axes
fig, ax = plt.subplots()
type(ax)

# matplotlib.axes._subplots.AxesSubplot

我们可以用ax的实例方法来操作绘图,接下来用三个时间序列的堆积面积图来说明

1
2
3
4
5
6
7
8
9
10
11
rng = np.arange(50)
rnd = np.random.randint(0, 10, size=(3, rng.size))
yrs = 1950 + rng

fig, ax = plt.subplots(figsize=(5,3))
ax.stackplot(yrs, rng+rnd, labels=['Eastasia', 'Eurasia', 'Oceania'])
ax.set_title('Combined debt growth over time')
ax.legend(loc='upper left')
ax.set_ylabel('Total debt')
ax.set_xlim(xmin=yrs[0], xmax=yrs[-1])
fig.tight_layout() # tight_layout()作为一个整体应用于图形对象来清理空白填充

如果是一个Figure中包含多个Axes/子图呢?下面用离散均匀分布绘制两个相关数组,跟上一个例子的不同点可以查看代码里的注释

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
x = np.random.randint(low=1, high=11, size=50)
y = x + np.random.randint(1, 5, size=x.size)
data = np.column_stack((x, y)) # column_stack将多个一维array作为column转为多维数组

# 1.创建一个Figure和两个Axes,也可以用fig, axs = plt.subplots(1, 2, figsize=(8,4))
# 2.分别处理两个子图
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
ax1.scatter(x, y, marker='o', c='r', edgecolor='b') # edgecolor设置marker边缘颜色
ax1.set_xlabel('$x$') # 加了$符号,x,y会变成斜体
ax1.set_ylabel('$y$')

# label定义两组数据对应的标签,DataFrame一般不需要,因为有columns
ax2.hist(data, bins=np.arange(data.min(), data.max()), label=('x', 'y'))
ax2.legend(loc=(0.7, 0.8))
ax2.set_title('Frequencies of $x$ and $y$')
ax2.yaxis.tick_right() # 将ax2的y轴tick设置在右边

需要注意的是,多个Axes都是包含在给定的Figure中的。如上面的例子,fig.axes能获取到所有的Axes对象列表

1
2
(fig.axes[0] is ax1, fig.axes[1] is ax2)  # 这里的fig.axes中的axes是小写
# (True, True)

我们还可以创建一个包含2*2Axes对象的网格图形Figure

1
2
3
4
5
6
7
8
9
10
11
12
13
14
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(7, 7))
type(ax) # 可以看到,ax是numpy的数组类型
# numpy.ndarray

ax
""" array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7f0b541b29b0>,
<matplotlib.axes._subplots.AxesSubplot object at 0x7f0b547d2710>],
[<matplotlib.axes._subplots.AxesSubplot object at 0x7f0b548efc88>,
<matplotlib.axes._subplots.AxesSubplot object at 0x7f0b5451e240>]],
dtype=object)
"""

ax.shape
# (2, 2)

ax是数组类型,需要使用flatten函数将数组平展为一维的才能使用

1
2
3
ax1, ax2, ax3, ax4 = ax.flatten()
# or
((ax1, ax2), (ax3, ax4)) = ax

为了展示更多的用法,接下来使用加利福尼亚州的房价数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# 获取数据。原文是通过代码下载解压数据的
housing = np.loadtxt('CaliforniaHousing/cal_housing.data', delimiter=',')
y = housing[:, -1]
pop, age = housing[:, [4, 7]].T # pop地区人口,age是房子的平均屋龄

def add_titlebox(ax, text):
"""在Axes中添加文本"""
ax.text(.55, .8, text,
horizontalalignment='center',
transform=ax.transAxes,
bbox=dict(facecolor='white', alpha=0.6),
fontsize=12.5)
return ax

# 创建带有多个Axes的Figure网格
gridsize = (3, 2)
fig = plt.figure(figsize=(12, 8))
ax1 = plt.subplot2grid(gridsize, (0, 0), colspan=2, rowspan=2) # ax1占用两行两列
ax2 = plt.subplot2grid(gridsize, (2, 0)) # ax2在第三行第一列
ax3 = plt.subplot2grid(gridsize, (2, 1)) # ax3在第三行第二列

# ax1是散点图,并带有colormap,其他两个图为直方图(查看数据分布)
ax1.set_title('Home value as a function of home age and area population', fontsize=14)
# c是color,y的序列映射到cmap,这样y的值与cmap的颜色可以一一对应上,cmap选择颜色变化,跟参数c结合使用
sctr = ax1.scatter(x=age, y=pop, c=y, cmap='RdYlGn')
plt.colorbar(sctr, ax=ax1, format='$%d')
ax2.hist(age, bins='auto') # bins可以是整数,序列或者‘auto’
ax3.hist(pop, bins='auto', log=True) # log标签显示为log形式

add_titlebox(ax2, 'Histogram: home age')
add_titlebox(ax3, 'Histogram: area population (log scl.)')

跟color map不一样的是,colorbar()是由Figure直接调用,而不是Axes对象,并且第一个参数是散点图的返回结果,它的作用就是将y值映射到color map。在y轴上,房价并没有太大的变化(颜色),x轴上则由明显变化,说明屋龄是房屋价值的一个更大的决定因素

屏幕后面的“Figures”

每次调用plt.subplots()或者plt.figure()(只创建Figure,没有Axes),都会在内存里新建一个Figure对象。在终端运行的结果跟原文一样,如果是在jupyter notebook则每一次返回id都不同。

1
2
3
4
5
6
7
8
9
10
>>> fig1, ax1 = plt.subplots()
>>> id(fig1)
4525375272
>>> id(plt.gcf()) # fig1是当前Figure对象
4525375272
>>> fig2, ax2 = plt.subplots()
>>> id(fig2) == id(plt.gcf()) # 当前Figure对象已经更改为fig2
True
>>> plt.get_fignums()
[1, 2]

可以通过plt.figure()获取到存在内存中的Figure

1
2
3
4
5
def get_all_figures():
return [plt.figure(i) for i in plt.get_fignums()]

get_all_figures()
# [<Figure size 640x480 with 1 Axes>, <Figure size 640x480 with 1 Axes>]

使用plt.close()方法则可以关闭对应的figure

1
2
plt.close(num) # 关闭序号为num的Figure
plt.close('all') # 关闭全部

imshow和matshow

除了plt.plot()可以绘图之后,还有其他的方法,如:imshow()和matshow()。后者是前者的封装。只要原始数值数组可以显示为彩色网格,就可以使用。

例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 创建两个数组
x = np.diag(np.arange(2, 12))[::-1]
x[np.diag_indices_from(x[::-1])] = np.arange(2, 12)
x2 = np.arange(x.size).reshape(x.shape)

# 使用字典参数,将所有的轴标签和刻度都不显示
sides = ('left', 'right', 'top', 'bottom')
nolabels = {s: False for s in sides}
nolabels.update({'label%s' % s: False for s in sides})

from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

# 使用上下文管理器来禁用网格,并在每个Axes上使用matshow()
with plt.rc_context(rc={'axes.grid': False}):
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
ax1.matshow(x)
img2 = ax2.matshow(x2, cmap='RdYlGn_r') # 加上_r就是将颜色反转
for ax in (ax1, ax2):
ax.tick_params(axis='both', which='both', **nolabels)
for i, j in zip(*x.nonzero()):
ax1.text(j, i, x[i,j], color='white', ha='center', va='center')

# 将colorbar作为图中的新轴
divider = make_axes_locatable(ax2)
cax = divider.append_axes('right', size='5%', pad=0)
plt.colorbar(img2, cax=cax, ax=[ax1, ax2])
fig.suptitle('Heatmaps with `Axes.matshow`', fontsize=16)

pandas中的绘图

pandas的DataFrame和Series上的plot()是plt.plot()的封装,所以在pandas绘图也是很方便的。如果DataFrame的索引是mdate,pandas调用gcf().autofmt_xdate()来获取当前的图并自动格式化x轴。plt.plot()是基于状态的方法,也就是隐式的知道当前的Figure和Axes,pandas也是遵循这点进行扩展的。

1
2
3
4
5
6
7
8
s = pd.Series(np.arange(5), index=list('abcde'))
ax = s.plot()

type(ax)
# matplotlib.axes._subplots.AxesSubplot

id(plt.gca()) == id(ax)
# True

接下来用一个时间序列的例子来看一下pandas中的画图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import matplotlib.transforms as mtransforms

vix = pd.read_csv('VIXCLS.csv', index_col=0, parse_dates=True, na_values='.',
infer_datetime_format=True,
squeeze=True).dropna()

ma = vix.rolling('90d').mean()
state = pd.cut(ma, bins=[-np.inf, 14, 18, 24, np.inf], labels=range(4))
cmap = plt.get_cmap('RdYlGn_r')
ma.plot(color='black', linewidth=1.5, marker='', figsize=(8, 4), label='VIX 90d MA')
ax = plt.gca()
ax.set_xlabel('')
ax.set_ylabel('90d moving average: CBOE VIX')
ax.set_title('Volatility Regime State')
ax.grid(False)
ax.set_xlim(xmin=ma.index[0], xmax=ma.index[-1])

trans = mtransforms.blended_transform_factory(ax.transData, ax.transAxes)
for i, color in enumerate(cmap([0.2, 0.4, 0.6, 0.8])):
ax.fill_between(ma.index, 0, 1, where=state==i,
facecolor=color, transform=trans)
ax.axhline(vix.mean(), linestyle='dashed', color='xkcd:dark grey',
alpha=0.6, label='Full-period mean', marker='')
ax.legend(loc='upper center')

上面的例子中有一些点需要了解一下:

  1. ma是VIX指数的90天移动平均线,衡量近期股市波动的市场预期。state是将移动平均线分类为不同的state。高VIX被表示市场中恐惧程度加剧。
  2. cmap是color map,matplotlib对象。它实质上是浮点数到RGBA颜色的映射。’_r是颜色反转。
  3. pandas调用ma.plot(),也就是调用了plt.plot()。因此为了继承面向对象,需要通过ax=plt.gca()来显示引用当前的Axes对象。
  4. 最后一段代码是创建了与每个状态对应的颜色填充块。cmap([0.2, 0.4, 0.6, 0.8])是指沿着color map的光谱,在20%,40%,60%,80%的位置获取RGBA序列。使用enumerate()将每个RGBA颜色跟state一一对应。

原文最后还给出了很多不错的链接推荐,值得学习。