split_func_call_across_gpus(func: Callable, inputs: Iterable[Union[ivy.Array, ivy.NativeArray, ivy.Container]], dev_strs: Union[int, Iterable[int], Iterable[str]], input_axes: Optional[Union[int, Iterable[int]]] = None, output_axes: Optional[Union[int, Iterable[int]]] = None, concat_output: bool = False) → Iterable[Union[ivy.Array, ivy.NativeArray, ivy.Container]]
Call a function by splitting its inputs along the given axes and calling the function on each chunk on a different device.
func (callable) – The function to be called.
inputs (sequence of arrays or containers) – A list of inputs to pass into the function.
dev_strs (int, sequence of ints or sequence of strs) – The GPU device strings, in the format “gpu:idx”.
input_axes (int or sequence of ints, optional) – The axes along which to split each of the inputs before passing them to the function. Default is 0.
output_axes (int or sequence of ints, optional) – The axes along which to concatenate each of the returned outputs. Default is the same as the first input axis.
concat_output (bool, optional) – Whether to concatenate each return value into a single array. Default is False.
The return from the function, following input splitting and re-concatenation across devices.
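The split, call and concatenate behaviour described above can be illustrated with a device-free sketch. It is not ivy's implementation: it uses NumPy arrays in place of ivy arrays and containers, runs every chunk on the CPU, and swaps `dev_strs` for a hypothetical `num_chunks` argument (one chunk per device). The function name `split_call_concat_sketch` is likewise hypothetical, and returning the list of per-chunk outputs when `concat_output` is False is an assumption of this sketch.

```python
from typing import Callable, Optional, Sequence, Union

import numpy as np


def split_call_concat_sketch(
    func: Callable,
    inputs: Sequence[np.ndarray],
    num_chunks: int,  # hypothetical stand-in for the number of entries in dev_strs
    input_axes: Optional[Union[int, Sequence[int]]] = None,
    output_axes: Optional[Union[int, Sequence[int]]] = None,
    concat_output: bool = False,
):
    # Default input axis is 0; a single int applies to every input.
    if input_axes is None:
        input_axes = 0
    if isinstance(input_axes, int):
        input_axes = [input_axes] * len(inputs)
    input_axes = list(input_axes)
    # Output axes default to the first input axis, as the reference states.
    if output_axes is None:
        output_axes = input_axes[0]
    # Split each input into num_chunks pieces along its own axis.
    split_inputs = [
        np.array_split(x, num_chunks, axis=ax) for x, ax in zip(inputs, input_axes)
    ]
    # Call func once per chunk. A real multi-GPU version would first move
    # chunk i to device i; here everything stays on the CPU.
    chunk_outputs = [
        func(*[pieces[i] for pieces in split_inputs]) for i in range(num_chunks)
    ]
    if not concat_output:
        return chunk_outputs  # assumption: per-chunk outputs as a list
    # Treat a single return value as a 1-tuple so both cases share one path.
    if not isinstance(chunk_outputs[0], tuple):
        chunk_outputs = [(out,) for out in chunk_outputs]
    num_returns = len(chunk_outputs[0])
    if isinstance(output_axes, int):
        output_axes = [output_axes] * num_returns
    # Re-concatenate the j-th return value of every chunk along its output axis.
    concatenated = [
        np.concatenate([out[j] for out in chunk_outputs], axis=output_axes[j])
        for j in range(num_returns)
    ]
    return concatenated[0] if num_returns == 1 else concatenated
```

For example, with `x` of shape (4, 2), `split_call_concat_sketch(lambda a, b: a + b, [x, x], num_chunks=2, concat_output=True)` splits both inputs into two (2, 2) chunks along axis 0, adds each pair, and concatenates the two results back into a (4, 2) array equal to `x + x`. Splitting works here because the function is applied independently to each slice along the split axis; a function that mixes values across that axis (for example, a sum over axis 0) would give different results once split.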