np.stack(array, axis)
背景
在python的numpy库中,数组的stack堆叠是个很常见的操作,如何堆叠涉及到axis这个参数,本文以np.stack()函数为例,去讲解axis这个参数的解释。
语法
stack(arrays, axis=0, out=None)
Join a sequence of arrays along a new axis.
The `axis` parameter specifies the index of the new axis in the dimensions
of the result. For example, if ``axis=0`` it will be the first dimension
and if ``axis=-1`` it will be the last dimension.
Parameters
----------
arrays : sequence of array_like
Each array must have the same shape.
axis : int, optional
The axis in the result array along which the input arrays are stacked.
从官方文档的解释可以看出,stack()中的axis参数是在维度中加入了一个新轴,也即是stack堆叠的最终结果是返回一个array数组,堆叠后数组的维(轴)数比原始数组的维(轴)数要多一个维(轴),且多的那一个(维)轴上的数值为需要进行stack堆叠的array数组个数。例如要堆叠的数组是二维(轴),shape为(5,4),要堆叠的数组个数为3个,那么返回的结果一定是在二维(轴)上增加一维(轴)成三维(轴),且增加的维(轴)上的数值为3,那么新增的数值为3的维(轴)是增加在第0维(轴)上的(3,5,4)、第1维(轴)的(5,3,4)还是第2维(轴)的(5,4,3),而新增的维(轴)插在不同的位置也正是axis参数的真正含义。
测试
import numpy as np
a = np.array([1,2,3])
b = np.array([10,10,10])
>>> np.stack((a,b),axis=0)
array([[ 1, 2, 3],
[10, 10, 10]])
>>> np.stack((a,b),axis=1)
array([[ 1, 10],
[ 2, 10],
[ 3, 10]])
>>> np.stack((a,b),axis=2)
numpy.AxisError: axis 2 is out of bounds for array of dimension 2
可以看出由于原始数组的形状为(3,)是一维(轴)数组,使用stack堆叠2个数组后必然返回的是新增一维(轴)的二维(轴)数组形状为(2,3)或(3,2),如果axis=2则超出了二维(轴)的设定,因此不可能实现,再来看当axis=0则得到的是形状为(2,3),当axis=1则得到的是形状为(3,2)。由此可见,axis的值正是表示的新维(轴)新增的位置。
规则
不同的axis值得到的是不同形状的数组,那么原始数组中的元素又是如何堆叠成新数组的呢,stack实际上是利用了python广播机制先扩展为设定形状的数组再执行简单堆叠方法(简单堆叠函数vstack,hstack一般不改变原数组的维(轴)数,只对元素进行纵向或横向拼接)。
以np.stack((a,b),axis=0)
为例,数组a是array([1, 2, 3])
的形状是(3,),数组b是array([10, 10, 10])
的形状是(3,),由于axis=0,所以新增的维(轴)出现在第0维(轴)的位置得到形状假设为(x,3)的数组,而数组a和数组b是2个数组进行堆叠,则第0维(轴)上的形状数值x应当为2,所以最终的返回数组形状是(2,3)。注意,新增的维(轴)位置上的数值2并不是替换原来数组形状第0维(轴)位置上的数值3而是将原来第0维(轴)向后挤形成多一层级的第1维(轴)。
再考虑如何由两个形状(3,)的数组堆叠为最终形状(2,3)的数组,由于已知最终形状是(2,3),则原来(3,)的数组通过广播机制将形状扩展为(1,3),则a=array([1,2,3])将扩展为a'=array([[1,2,3]])
,同理b'=array([[10,10,10]])
,广播扩展后的2个数组再沿着axis=0的row方向堆叠(即按纵向堆叠row行数的vstack方法,array[[1,2,3]]沿0轴堆叠array[[10,10,10]])
,因此才得到了array([[1,2,3],[10,10,10]])
。
>>> np.vstack((np.array([[1,2,3]]),np.array([[10,10,10]])))
array([[ 1, 2, 3],
[10, 10, 10]])
同理np.stack((a,b),axis=1)
,最终返回的数组形状是(3,2),因此远来(3,)的数组扩展为(3,1),即a'=array([[1],[2],[3]])
和b’=array([[10],[10],[10]]),然后沿着axis=1的column方向堆叠(即按横向堆叠column列数的hstack方法,array([[1],[2],[3]]
沿1轴堆叠array[[10],[10],[10]])
,因此得到array([[1,10],[2,10],[3,10]])
>>> np.hstack((np.array([[1],[2],[3]]),np.array([[10],[10],[10]])))
array([[ 1, 10],
[ 2, 10],
[ 3, 10]])