A summary of how axis works in Python NumPy arrays

I have been teaching myself NumPy for a long time, but I never really understood how axis works on its arrays. I usually just tried axis=0 and axis=1 to see what happened, and then used whichever one worked in my code. For example:

First, import numpy and create an array:

>>> import numpy as np
>>> a = np.array([[1,2],[10,20]])
>>> a
array([[ 1,  2],
       [10, 20]])

Look at the output of mean when axis=0:

>>> a.mean(axis=0)
array([  5.5,  11. ])

And the output of mean when axis=1:

>>> a.mean(axis=1)
array([  1.5,  15. ])
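To see where these two results come from, they can be rebuilt by hand with plain indexing. This is only a sketch; the names by_axis0 and by_axis1 are mine, not part of NumPy.

```python
import numpy as np

a = np.array([[1, 2], [10, 20]])

# axis=0 runs over the rows a[0] and a[1], leaving one value per column.
by_axis0 = (a[0] + a[1]) / 2

# axis=1 runs over the columns a[:, 0] and a[:, 1], leaving one value per row.
by_axis1 = (a[:, 0] + a[:, 1]) / 2

print(np.array_equal(by_axis0, a.mean(axis=0)))  # True
print(np.array_equal(by_axis1, a.mean(axis=1)))  # True
```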

It seems that the rule is: when axis=0 the average is taken column by column, and when axis=1 it is taken row by row. But this rule is hard to apply to higher-dimensional arrays:

>>> b = np.array([[[1,2,3],[4,5,6],[7,8,9]]])
>>> b
array([[[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]])
>>> b.shape
(1, 3, 3)
>>> b.mean(axis=0)
array([[ 1.,  2.,  3.],
       [ 4.,  5.,  6.],
       [ 7.,  8.,  9.]])
>>> b.mean(axis=1)
array([[ 4.,  5.,  6.]])

As you can see, for an array of shape (1,3,3), axis=1 does not average by row, and axis=0 returns an array of shape (3,3) whose values are the same as those of the original array.
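Both results for b can be checked with ordinary slicing. The following is a rough sketch, and the variable name manual_axis1 is just something I made up for the illustration.

```python
import numpy as np

b = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]])

# Axis 0 has length 1, so averaging over it just returns the single (3, 3) block.
print(np.array_equal(b.mean(axis=0), b[0]))  # True

# Axis 1 has length 3: average the three slices b[:, 0], b[:, 1] and b[:, 2].
manual_axis1 = (b[:, 0] + b[:, 1] + b[:, 2]) / 3
print(np.array_equal(manual_axis1, b.mean(axis=1)))  # True
```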
From the above we can conclude:

  • The meaning of axis depends on the shape of the array.
  • If axis=0, the average is taken over the outermost level of the array (the first entry of the shape); if axis=1, it is taken over the second level (the second entry of the shape), and so on.
  • From this we can deduce: since the array b above is three-dimensional with shape (1,3,3), axis=2 is also valid, and the result is the average of each of the 3 innermost rows:

>>> b.mean(axis=2)
array([[ 2.,  5.,  8.]])
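The conclusions above boil down to one rule: the axis you pass is the entry of the shape that disappears from the result. Here is a small sketch of that rule. The array c is a hypothetical example of mine, not from the text above; I chose three different lengths so it is obvious which dimension is gone.

```python
import numpy as np

# Hypothetical array with shape (2, 3, 4); every dimension has a different length.
c = np.arange(24).reshape(2, 3, 4)

for axis in range(c.ndim):
    reduced = c.mean(axis=axis)
    # Expected shape: the input shape with the entry at position `axis` removed.
    expected_shape = c.shape[:axis] + c.shape[axis + 1:]
    print(axis, reduced.shape, expected_shape)
```

Each line printed should show the same shape twice: (3, 4) for axis=0, (2, 4) for axis=1 and (2, 3) for axis=2.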

But b has only three dimensions, and the entries of the shape are indexed from 0 like everything else in Python, so the largest valid axis is 2. Using axis=3 raises an error:

>>> b.mean(axis=3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/jinjunjie/anaconda/lib/python2.7/site-packages/numpy/core/", line 56, in _mean
    rcount = _count_reduce_items(arr, axis)
  File "/Users/jinjunjie/anaconda/lib/python2.7/site-packages/numpy/core/", line 50, in _count_reduce_items
    items *= arr.shape[ax]
IndexError: tuple index out of range
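The traceback above comes from an old NumPy on Python 2.7. As far as I know, newer NumPy releases report the same mistake as an AxisError, which is a subclass of both IndexError and ValueError, so catching IndexError should work in either case. The sketch below checks the axis against b.ndim before calling mean; safe_mean is a hypothetical helper of my own, not a NumPy function.

```python
import numpy as np

b = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]])

def safe_mean(arr, axis):
    """Hypothetical helper: mean along axis, or None if axis is out of range."""
    # Valid axes run from -arr.ndim to arr.ndim - 1 (negative values count from the end).
    if not -arr.ndim <= axis < arr.ndim:
        return None
    return arr.mean(axis=axis)

print(b.ndim)            # 3
print(safe_mean(b, 3))   # None
print(np.array_equal(safe_mean(b, -1), b.mean(axis=2)))  # True

# Without the check, the out-of-range axis raises an exception.
try:
    b.mean(axis=3)
except IndexError as exc:
    print("axis=3 rejected:", type(exc).__name__)
```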

Readers who are comfortable with English can find a deeper explanation in this and this.
