split_func_call_across_gpus

ivy.split_func_call_across_gpus(func: Callable, inputs: Iterable[Union[ivy.Array, ivy.NativeArray, ivy.Container]], dev_strs: Union[int, Iterable[int], Iterable[str]], input_axes: Optional[Union[int, Iterable[int]]] = None, output_axes: Optional[Union[int, Iterable[int]]] = None, concat_output: bool = False) → Iterable[Union[ivy.Array, ivy.NativeArray, ivy.Container]]

Call a function by splitting its inputs along a given axis and running the function on each chunk on a different device.

Parameters
  • func (callable) – The function to be called.

  • inputs (sequence of arrays or containers) – A list of inputs to pass into the function.

  • dev_strs (int, sequence of ints or sequence of strs) – The GPU device strings, in the format “gpu:idx”.

  • input_axes (int or sequence of ints, optional) – The axes along which to split each of the inputs, before passing to the function. Default is 0.

  • output_axes (int or sequence of ints, optional) – The axes along which to concatenate each of the returned outputs. Default is the same as the first input axis.

  • concat_output (bool, optional) – Whether to concatenate the returned values into a single array. Default is False.

Returns

The return values of the function, after input splitting and re-concatenation across devices.
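
The split, call, and concatenate pattern described above can be illustrated with a minimal NumPy-only sketch. It does not call Ivy or touch any GPU: the helper `split_call_concat` and its `num_chunks` and `output_axis` parameters are hypothetical names introduced here for illustration, and each chunk is simply processed in turn on the CPU where a multi-GPU version would first place it on its own device.

```python
import numpy as np


def split_call_concat(func, inputs, num_chunks, input_axes=0,
                      output_axis=0, concat_output=False):
    """Hypothetical stand-in: split inputs, call func per chunk, optionally concatenate."""
    # A single int axis applies to every input.
    if isinstance(input_axes, int):
        input_axes = [input_axes] * len(inputs)
    # Split each input into num_chunks pieces along its own axis.
    split_inputs = [np.array_split(x, num_chunks, axis=a)
                    for x, a in zip(inputs, input_axes)]
    # Call func once per chunk (assumption: a multi-GPU version would run
    # each of these calls on a different device).
    per_chunk = [func(*chunk_args) for chunk_args in zip(*split_inputs)]
    if not concat_output:
        return per_chunk
    # Join the per-chunk results back into a single array.
    return np.concatenate(per_chunk, axis=output_axis)


x = np.arange(12.0).reshape(4, 3)
y = np.ones((4, 3))

# Without concatenation: one result per chunk.
chunks = split_call_concat(np.add, [x, y], num_chunks=2)

# With concatenation: a single array matching the unsplit call.
joined = split_call_concat(np.add, [x, y], num_chunks=2, concat_output=True)
```

With `concat_output` left at False, the sketch returns one result per chunk. With it set to True, the results are joined along `output_axis`, so for an elementwise function such as `np.add` the result matches calling the function on the unsplit inputs.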


Supported Frameworks:

JAX, TensorFlow, PyTorch, MXNet, NumPy