split_func_call

ivy.split_func_call(func: Callable, inputs: List[Union[ivy.NativeArray, ivy.Container]], chunk_size: int, input_axes: Union[int, List[int]] = 0, output_axes: Optional[Union[int, List[int]]] = None, mean: bool = False) → List[Union[ivy.NativeArray, ivy.Container]]

Call a function by splitting its inputs along a given axis and calling the function on each chunk, rather than on the entire input arrays at once. This can reduce peak memory usage on the device the arrays are stored on.
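The split, call and re-concatenate pattern described above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not Ivy's implementation: the helper name `chunked_call` is hypothetical, every input is assumed to have the same length along its split axis, and each output is concatenated back along its output axis, which defaults to the first input axis as documented for `output_axes`.

```python
import numpy as np
from typing import Callable, List, Optional, Sequence, Union


def chunked_call(
    func: Callable,
    inputs: Sequence[np.ndarray],
    chunk_size: int,
    input_axes: Union[int, Sequence[int]] = 0,
    output_axes: Optional[Union[int, Sequence[int]]] = None,
) -> List[np.ndarray]:
    """Hypothetical sketch of chunked evaluation; not Ivy's implementation."""
    # Broadcast a single split axis to every input.
    if isinstance(input_axes, int):
        input_axes = [input_axes] * len(inputs)
    # Assumption: every input has this length along its own split axis.
    dim_size = inputs[0].shape[input_axes[0]]

    per_chunk = []
    for start in range(0, dim_size, chunk_size):
        stop = min(start + chunk_size, dim_size)  # the last chunk may be shorter
        idx = np.arange(start, stop)
        chunk_inputs = [np.take(x, idx, axis=ax) for x, ax in zip(inputs, input_axes)]
        ret = func(*chunk_inputs)
        # Normalise a single output to a 1-tuple so multi-output functions work too.
        per_chunk.append(ret if isinstance(ret, tuple) else (ret,))

    num_outputs = len(per_chunk[0])
    # Default output axis: the first input's split axis.
    if output_axes is None:
        output_axes = [input_axes[0]] * num_outputs
    elif isinstance(output_axes, int):
        output_axes = [output_axes] * num_outputs

    # Re-concatenate each output across chunks along its output axis.
    return [
        np.concatenate([r[i] for r in per_chunk], axis=output_axes[i])
        for i in range(num_outputs)
    ]
```

For example, with `x = np.arange(10.0).reshape(10, 1)`, `chunked_call(lambda a: a * 2, [x], chunk_size=4)` evaluates chunks of 4, 4 and 2 rows and returns a one-element list whose array equals `x * 2`. Only one chunk's intermediate values inside `func` need to exist at a time, which is where the memory saving comes from.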

Parameters
  • func (callable) – The function to be called.

  • inputs (sequence of arrays or containers) – A list of inputs to pass into the function.

  • chunk_size (int) – The size of each chunk to be fed into the function.

  • input_axes (int or sequence of ints, optional) – The axes along which to split each input before passing it to the function. Default is 0.

  • output_axes (int or sequence of ints, optional) – The axes along which to concatenate each of the returned outputs. Default is the same as the first input axis.

  • mean (bool, optional) – Whether to compute a weighted mean of the returns from each chunk. Default is False.
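To see why the mean is weighted, consider a function that returns a per-chunk average, such as a mean loss. When `chunk_size` does not divide the axis length, the last chunk is smaller, so an unweighted mean of the chunk results would over-count it. The sketch below assumes each chunk's return is weighted by the number of elements it covered along the split axis; the helper `chunked_weighted_mean` is hypothetical and handles only a single input and a single output.

```python
import numpy as np
from typing import Callable


def chunked_weighted_mean(func: Callable, x: np.ndarray, chunk_size: int, axis: int = 0) -> np.ndarray:
    """Hypothetical sketch: combine per-chunk returns as a mean weighted by chunk size."""
    dim_size = x.shape[axis]
    weighted_sum = 0.0
    for start in range(0, dim_size, chunk_size):
        stop = min(start + chunk_size, dim_size)
        chunk = np.take(x, np.arange(start, stop), axis=axis)
        # Assumption: each chunk's return is weighted by how many elements it covered.
        weighted_sum = weighted_sum + func(chunk) * (stop - start)
    # Dividing by the total length turns the weighted sum into a weighted mean.
    return weighted_sum / dim_size
```

With `x = np.arange(10.0)` and `chunk_size=4`, the chunk means are 1.5, 5.5 and 8.5; weighting them by 4, 4 and 2 gives 45 / 10 = 4.5, which matches `x.mean()`, whereas the unweighted mean of the three chunk means is about 5.17.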

Returns

The outputs of the function, following input splitting and re-concatenation.


Supported Frameworks:

JAX, TensorFlow, PyTorch, MXNet, NumPy