NumPy apply_over_axes()

The apply_over_axes() method applies a function repeatedly over multiple axes of an array, one axis at a time, passing the result of each call into the next.

import numpy as np

# create a 3D array
arr = np.array([
    [[1, 2, 3],
     [4, 5, 6]],
    
    [[7, 8, 9],
     [10, 11, 12]]
])

# define a function that sums x along the given axis
# apply_over_axes calls it as axis_sum(x, axis)
def axis_sum(x, axis):
    return np.sum(x, axis=axis)

# apply axis_sum over the first axis, then over the third axis
result = np.apply_over_axes(axis_sum, arr, axes=(0, 2))
print(result)

''' Output:
[[[30]
  [48]]]
'''
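To see where these numbers come from, the sketch below traces the same computation by hand, one axis at a time. It assumes the behavior described in the NumPy documentation: after each call, a reduced axis is put back with length 1 (done here with np.expand_dims) before the next axis is processed. The names step1 and step2 are only for illustration.

```python
import numpy as np

# same 3D array as above, shape (2, 2, 3)
arr = np.array([
    [[1, 2, 3],
     [4, 5, 6]],
    [[7, 8, 9],
     [10, 11, 12]]
])

# step 1: summing over axis 0 drops that axis, giving shape (2, 3)
step1 = np.sum(arr, axis=0)
# the reduced axis is re-inserted with length 1, giving shape (1, 2, 3)
step1 = np.expand_dims(step1, 0)

# step 2: sum the intermediate result over axis 2, giving shape (1, 2)
step2 = np.sum(step1, axis=2)
# the reduced axis is re-inserted again, giving shape (1, 2, 1)
step2 = np.expand_dims(step2, 2)

print(step2.shape)    # (1, 2, 1)
print(step2.ravel())  # [30 48]
print(np.array_equal(step2, np.apply_over_axes(np.sum, arr, axes=(0, 2))))  # True
```

Each output value is the sum of one slice along axis 1: 1+2+3+7+8+9 = 30 and 4+5+6+10+11+12 = 48.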

apply_over_axes() Syntax

The syntax of apply_over_axes() is:

numpy.apply_over_axes(func, a, axes)

apply_over_axes() Arguments

The apply_over_axes() method takes the following arguments:

  • func - the function to apply; it is called as func(a, axis)
  • a - the input array
  • axes - the axis, or tuple of axes, over which func is applied, one axis at a time in the order given
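As a quick check of how axes is interpreted, the sketch below passes a single int, a one-element tuple, and a negative axis. The array shape (2, 3, 4) is an arbitrary choice for illustration.

```python
import numpy as np

arr = np.arange(24).reshape(2, 3, 4)

# a single int is accepted and treated like a one-element tuple
r_int = np.apply_over_axes(np.sum, arr, 1)
r_tuple = np.apply_over_axes(np.sum, arr, (1,))

# negative axes count from the end, so -1 refers to axis 2 here
r_neg = np.apply_over_axes(np.sum, arr, (0, -1))
r_pos = np.apply_over_axes(np.sum, arr, (0, 2))

print(r_int.shape)                     # (2, 1, 4)
print(np.array_equal(r_int, r_tuple))  # True
print(r_neg.shape)                     # (1, 3, 1)
print(np.array_equal(r_neg, r_pos))    # True
```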

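Note: func must accept two positional arguments, an array and an integer axis, and must return an array with either the same number of dimensions as its input or one dimension fewer. If a dimension is dropped, apply_over_axes() re-inserts it with length 1 before moving on to the next axis.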
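The rule in this note can be sketched as a short loop. apply_over_axes_sketch is a hypothetical name, and the function is a simplified illustration of the documented behavior, not NumPy's actual source; it is compared against np.apply_over_axes at the end.

```python
import numpy as np

def apply_over_axes_sketch(func, a, axes):
    """Hypothetical, simplified re-implementation for illustration only."""
    val = np.asarray(a)
    # accept a single int as well as a sequence of axes
    if np.ndim(axes) == 0:
        axes = (axes,)
    for axis in axes:
        # normalize negative axes against the (unchanging) number of dimensions
        if axis < 0:
            axis += val.ndim
        res = func(val, axis)  # func receives (array, axis) positionally
        if res.ndim != val.ndim:
            # func dropped the axis: put it back as a length-1 dimension
            res = np.expand_dims(res, axis)
        if res.ndim != val.ndim:
            raise ValueError("func must return a.ndim or a.ndim - 1 dimensions")
        val = res
    return val

arr = np.arange(8).reshape(2, 2, 2)
print(apply_over_axes_sketch(np.sum, arr, (0, 2)).tolist())  # [[[10], [18]]]
print(np.array_equal(apply_over_axes_sketch(np.sum, arr, (0, 2)),
                     np.apply_over_axes(np.sum, arr, (0, 2))))  # True
```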

apply_over_axes() Return Value

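The apply_over_axes() method returns the array produced by the last call to func. It has the same number of dimensions as a; any axis that func removed is re-inserted with length 1.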
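Because the result keeps the input's number of dimensions, it broadcasts directly against the original array. The sketch below uses np.mean (an illustrative choice, not one taken from the examples) to subtract the mean of each slice along axis 1, which is one common reason this shape is convenient.

```python
import numpy as np

arr = np.arange(12, dtype=float).reshape(2, 2, 3)

# mean over axes 0 and 2; the result keeps ndim == 3
means = np.apply_over_axes(np.mean, arr, (0, 2))
print(means.shape)    # (1, 2, 1)
print(means.ravel())  # [4. 7.]

# the reduced axes are kept with length 1, so the result broadcasts
# against the original array, e.g. to center each slice along axis 1
centered = arr - means
print(centered.shape)  # (2, 2, 3)
```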


Example 1: Apply a Function Along Multiple Axes

import numpy as np

# create a 3D array
arr = np.arange(8).reshape(2, 2, 2)
print('Original Array:\n', arr)

# sum over axes 0 and 1
# this adds together all elements that share the same index along axis 2
result = np.apply_over_axes(np.sum, arr, axes=(0, 1))
print('Sum along axes (0, 1):\n', result)

# sum over axes 0 and 2
# this adds together all elements that share the same index along axis 1
result = np.apply_over_axes(np.sum, arr, axes=(0, 2))
print('Sum along axes (0, 2):\n', result)

Output

Original Array:
 [[[0 1]
  [2 3]]

 [[4 5]
  [6 7]]]
Sum along axes (0, 1):
 [[[12 16]]]
Sum along axes (0, 2):
 [[[10]
  [18]]]
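For a reduction like np.sum, applying it over the axes one at a time gives the same result as a single call with a tuple of axes and keepdims=True. The sketch below checks this for the array in Example 1. The equivalence depends on the function: for something like np.median, reducing one axis at a time generally does not give the same answer as reducing all axes together.

```python
import numpy as np

arr = np.arange(8).reshape(2, 2, 2)

# sequential reduction with apply_over_axes
via_apply = np.apply_over_axes(np.sum, arr, axes=(0, 1))
# single reduction over both axes, keeping them with length 1
via_keepdims = np.sum(arr, axis=(0, 1), keepdims=True)

print(via_keepdims)                             # [[[12 16]]]
print(np.array_equal(via_apply, via_keepdims))  # True
```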

Example 2: Apply a lambda Function to an Array

func does not have to be a named function. A lambda works too, as long as it accepts an array and an axis and returns an array.

import numpy as np

# create a 2D array
arr = np.array([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])

# apply a lambda that sums its input along the given axis
# axes=1, so the elements in each row of the 2D array are summed
result = np.apply_over_axes(lambda a, axis: np.sum(a, axis=axis), arr, axes=1)
print(result)

Output

[[ 6]
 [15]
 [24]]
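A lambda may also keep the reduced axis itself, for example with keepdims=True, because apply_over_axes() only re-inserts an axis when the returned array has exactly one dimension fewer. A function that drops more than one dimension is rejected with a ValueError. Both lambdas below are illustrative.

```python
import numpy as np

arr = np.array([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])

# this lambda keeps the reduced axis itself, so its output already
# has arr.ndim dimensions and is used as-is
kept = np.apply_over_axes(lambda a, axis: np.sum(a, axis=axis, keepdims=True), arr, 1)
print(kept.ravel())  # [ 6 15 24]

# this lambda sums every element, returning a 0-d result; re-inserting
# one axis gives only 1 dimension instead of 2, so a ValueError is raised
try:
    np.apply_over_axes(lambda a, axis: np.sum(a), arr, 0)
except ValueError as exc:
    print("ValueError:", exc)
```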

Recommended Reading

NumPy apply_along_axis()