The apply_over_axes() function applies a function repeatedly over multiple axes of an array, one axis at a time.
import numpy as np

# create a 3D array of shape (2, 2, 3)
arr = np.array([
    [[1, 2, 3],
     [4, 5, 6]],
    [[7, 8, 9],
     [10, 11, 12]]
])

# define a function to compute the sum along a given axis
def axis_sum(x, axis=0):
    # compute the sum along the specified axis
    return np.sum(x, axis=axis)

# apply axis_sum over the first axis, then over the third axis
result = np.apply_over_axes(axis_sum, arr, axes=(0, 2))

print(result)

'''
Output:
[[[30]
  [48]]]
'''
apply_over_axes() Syntax
The syntax of apply_over_axes() is:
numpy.apply_over_axes(func, a, axes)
apply_over_axes() Arguments
The apply_over_axes() function takes the following arguments:
func - the function to apply
a - the input array
axes - the axes over which func is applied, one after another
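Note: func is called as func(a, axis), so it must accept two arguments: an input array and a single axis. It should return an array with either the same number of dimensions as its input or one fewer.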
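To make the calling convention in the note concrete, here is a simplified sketch of the loop that apply_over_axes() performs. The helper name apply_over_axes_sketch is hypothetical and exists only for illustration. It is not NumPy's actual implementation, and it assumes non-negative axes.

```python
import numpy as np

def apply_over_axes_sketch(func, a, axes):
    """Hypothetical, simplified stand-in for np.apply_over_axes (illustration only)."""
    result = np.asarray(a)
    ndim = result.ndim
    # accept a single int or a sequence of ints (assumed non-negative here)
    for axis in np.atleast_1d(axes):
        axis = int(axis)
        out = func(result, axis)  # func receives the array and ONE axis
        if out.ndim == ndim - 1:
            # func dropped the axis, so put it back with length 1
            out = np.expand_dims(out, axis)
        elif out.ndim != ndim:
            raise ValueError("func must keep the number of dimensions or reduce it by one")
        result = out
    return result

arr = np.arange(24).reshape(2, 3, 4)
sketch = apply_over_axes_sketch(np.sum, arr, (0, 2))
builtin = np.apply_over_axes(np.sum, arr, (0, 2))

print(sketch.shape)                     # (1, 3, 1)
print(np.array_equal(sketch, builtin))  # True
```

Each call to func sees the result of the previous call, which is why the function only ever needs to handle one axis at a time.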
apply_over_axes() Return Value
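The apply_over_axes() function returns the resulting array. It has the same number of dimensions as the input array, but its shape can differ. With a reducing function such as np.sum, each axis listed in axes ends up with length 1.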
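As a quick check of the returned shape, the sketch below compares apply_over_axes() with calling np.sum directly using a tuple of axes and keepdims=True. For np.sum the two approaches should give the same array. The variable names are illustrative.

```python
import numpy as np

arr = np.arange(24).reshape(2, 3, 4)

# reduce over axes 0 and 2, one axis at a time
via_apply = np.apply_over_axes(np.sum, arr, (0, 2))

# reduce over both axes in a single call, keeping them as length-1 axes
via_keepdims = np.sum(arr, axis=(0, 2), keepdims=True)

print(arr.shape)        # (2, 3, 4)
print(via_apply.shape)  # (1, 3, 1)
print(np.array_equal(via_apply, via_keepdims))  # True
```

When the function is a NumPy reduction that already accepts a tuple of axes, the keepdims form is usually the more direct way to write it.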
Example 1: Apply a Function Along Multiple Axes
import numpy as np
# create a 3D array
arr = np.arange(8).reshape(2, 2, 2)
print('Original Array:\n', arr)
# sum the array over axes 0 and 1
# adds the elements that share the same index along axis 2
result = np.apply_over_axes(np.sum, arr, axes=(0, 1))
print('Sum along axes (0, 1):\n', result)

# sum the array over axes 0 and 2
# adds the elements that share the same index along axis 1
result = np.apply_over_axes(np.sum, arr, axes=(0, 2))
print('Sum along axes (0, 2):\n', result)
Output
Original Array:
 [[[0 1]
  [2 3]]

 [[4 5]
  [6 7]]]
Sum along axes (0, 1):
 [[[12 16]]]
Sum along axes (0, 2):
 [[[10]
  [18]]]
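Because the function is applied one axis at a time, the result is not always the same as applying it to all the chosen axes at once. The sketch below uses np.ptp (maximum minus minimum) on a small illustrative array to show the difference. For sums the two approaches agree, but for a range they do not.

```python
import numpy as np

arr = np.array([[1, 5],
                [2, 9]])

# step 1: ptp along axis 0 gives [[1, 4]]
# step 2: ptp along axis 1 of that intermediate result gives [[3]]
sequential = np.apply_over_axes(np.ptp, arr, (0, 1))

# ptp over the whole array in one step: 9 - 1
joint = np.ptp(arr)

print(sequential)  # [[3]]
print(joint)       # 8
```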
Example 2: Apply a lambda Function to an Array
The function passed to apply_over_axes() can also be a lambda that takes an array and an axis.
import numpy as np
# create a 2D array
arr = np.array([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])

# the lambda computes the sum of an array along the given axis
# apply it along axis 1, which sums the elements of each row
result = np.apply_over_axes(lambda a, axis: np.sum(a, axis=axis), arr, axes=(1,))
print(result)
Output
[[ 6]
 [15]
 [24]]