split_func_call

ivy.split_func_call(func: Callable, inputs: Iterable[Union[ivy.Array, ivy.NativeArray, ivy.Container]], mode: str, max_chunk_size: Optional[int] = None, chunk_size: Optional[int] = None, input_axes: Union[int, Iterable[int]] = 0, output_axes: Optional[Union[int, Iterable[int]]] = None) → Iterable[Union[ivy.Array, ivy.NativeArray, ivy.Container]]

Call a function by splitting its inputs along a given axis and calling the function on each chunk, rather than feeding the entire input arrays at once. This can reduce memory usage on the device the arrays are stored on.

Parameters
  • func (callable) – The function to be called.

  • inputs (sequence of arrays) – A list of inputs to pass into the function.

  • mode (str) – The mode by which to unify the return values. Must be one of [ concat | mean | sum ].

  • max_chunk_size (int, optional) – The maximum size of each of the chunks to be fed into the function.

  • chunk_size (int, optional) – The size of each of the chunks to be fed into the function. Specifying this argument overrides the global split factor. Default is None.

  • input_axes (int or sequence of ints, optional) – The axes along which to split each of the inputs before passing them to the function. Default is 0.

  • output_axes (int or sequence of ints, optional) – The axes along which to concatenate each of the returned outputs. Default is the same as the first input axis.

Returns

The return from the function, following input splitting and re-concatenation.
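
The chunked-call idea described above can be illustrated with a minimal NumPy-only sketch. This is not Ivy's implementation: the helper name split_func_call_sketch is hypothetical, it handles a function with a single output and a single split axis, and the size-weighted averaging used for mean is an assumption about how chunk results might reasonably be combined.

```python
import numpy as np


def split_func_call_sketch(func, inputs, mode, chunk_size, input_axis=0, output_axis=None):
    """Hypothetical illustration of chunked function calls (not Ivy's code)."""
    if mode not in ("concat", "mean", "sum"):
        raise ValueError("mode must be one of 'concat', 'mean' or 'sum'")
    if output_axis is None:
        # Mirror the documented default: reuse the input axis.
        output_axis = input_axis

    dim_size = inputs[0].shape[input_axis]
    # Indices at which to cut; the last chunk may be smaller than chunk_size.
    split_points = list(range(chunk_size, dim_size, chunk_size))
    chunked_inputs = [np.split(x, split_points, axis=input_axis) for x in inputs]
    chunk_sizes = [c.shape[input_axis] for c in chunked_inputs[0]]

    # Call the function once per chunk instead of once on the full arrays.
    results = [func(*chunks) for chunks in zip(*chunked_inputs)]

    if mode == "concat":
        return np.concatenate(results, axis=output_axis)
    if mode == "sum":
        return sum(results)
    # mode == "mean": weight each chunk by its size (an assumption made here
    # so that a smaller final chunk does not bias the average).
    return sum(r * n for r, n in zip(results, chunk_sizes)) / sum(chunk_sizes)


x = np.arange(10.0).reshape(5, 2)

# Element-wise function: chunk outputs are concatenated back along axis 0.
doubled = split_func_call_sketch(lambda a: a * 2, [x], "concat", chunk_size=2)

# Reductions over the split axis: chunk outputs are summed or averaged.
col_sum = split_func_call_sketch(lambda a: a.sum(axis=0), [x], "sum", chunk_size=2)
col_mean = split_func_call_sketch(lambda a: a.mean(axis=0), [x], "mean", chunk_size=2)
```

In this sketch the five rows are processed as chunks of two, two and one rows, so peak memory inside func scales with the chunk size rather than with the full input.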


Supported Frameworks:

JAX, TensorFlow, PyTorch, MXNet, NumPy