NodeMapper

class ivy.NodeMapper(fn, ret_fn, queue_class, worker_class, node_strs, timeout=None, constant=None, unique=None)

Bases: abc.ABC

__init__(fn, ret_fn, queue_class, worker_class, node_strs, timeout=None, constant=None, unique=None)

Node Mapper base class.

Parameters
  • fn (callable) – The function which the node mapper parallelises across nodes.

  • ret_fn (callable) – The function which receives the ivy.MultiNodeIter as input and produces a single-node output.

  • queue_class (class) – The class to use for creating queues.

  • worker_class (class) – The class to use for creating parallel workers.

  • node_strs (sequence of str) – A list of nodes on which to parallelise the function.

  • timeout (float, optional) – The timeout for getting items from the queues. Default is None, in which case the global timeout is used.

  • constant (dict of any, optional) – A dict of keyword arguments which are the same for each process. Default is None.

  • unique (dict of iterables of any, optional) – A dict of keyword argument sequences which are unique for each process. Default is None.
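The difference between constant and unique can be illustrated with a minimal sketch. It assumes that the i-th entry of each unique sequence is paired with the i-th entry of node_strs; that pairing, the helper name build_worker_kwargs, and the node strings "node:0" and "node:1" are hypothetical and are not taken from ivy.

```python
def build_worker_kwargs(node_strs, constant=None, unique=None):
    """Assemble one keyword-argument dict per node (illustrative helper, not part of ivy)."""
    constant = constant or {}
    unique = unique or {}
    per_node = {}
    for i, node_str in enumerate(node_strs):
        # Constant kwargs are shared as-is by every process.
        kwargs = dict(constant)
        # Assumption: the i-th element of each unique sequence belongs to the i-th node.
        kwargs.update({key: seq[i] for key, seq in unique.items()})
        per_node[node_str] = kwargs
    return per_node


worker_kwargs = build_worker_kwargs(
    ["node:0", "node:1"],        # hypothetical node strings
    constant={"lr": 0.1},        # same value for every process
    unique={"seed": [7, 8]},     # one value per process
)
```

Under these assumptions, each node receives the shared lr value together with its own seed.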

map(used_node_strs=None, split_factors=None, **kwargs)

Map the function fn across the MultiNode keyword arguments, running the calls in parallel with CUDA-safe multiprocessing.

Parameters
  • used_node_strs (sequence of str, optional) – The nodes used in the current mapping pass. Default is all node_strs.

  • split_factors (dict of floats, optional) – The updated split factor sf for each node, with 0 < sf < 1. Default is None.

  • kwargs (dict of any) – The MultiNode keyword arguments to map the function to.

Returns

The results of the function, returned as a MultiNode instance.
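The overall flow (build the mapper once, then call map with per-node keyword arguments) can be sketched with a toy stand-in built on the standard library. ToyNodeMapper is hypothetical and is not ivy's implementation: it uses threads instead of CUDA-safe processes, represents a MultiNode item as a plain dict keyed by node string, passes ret_fn a dict rather than an ivy.MultiNodeIter, and ignores split_factors. The node strings are also made up for illustration.

```python
import queue
import threading


class ToyNodeMapper:
    """Hypothetical stand-in that mirrors the documented constructor arguments."""

    def __init__(self, fn, ret_fn, queue_class, worker_class, node_strs,
                 timeout=None, constant=None, unique=None):
        self._fn = fn
        self._ret_fn = ret_fn
        self._queue_class = queue_class
        self._worker_class = worker_class
        self._node_strs = list(node_strs)
        self._timeout = timeout
        self._constant = constant or {}
        self._unique = unique or {}

    def _run(self, out_queue, call_kwargs):
        # Each worker calls fn once and posts its result to its own queue.
        out_queue.put(self._fn(**call_kwargs))

    def map(self, used_node_strs=None, **kwargs):
        nodes = self._node_strs if used_node_strs is None else list(used_node_strs)
        out_queues, workers = {}, []
        for node in nodes:
            idx = self._node_strs.index(node)
            # Assumption: each mapped kwarg is a dict holding one value per node.
            call_kwargs = {key: per_node[node] for key, per_node in kwargs.items()}
            call_kwargs.update(self._constant)
            call_kwargs.update({key: seq[idx] for key, seq in self._unique.items()})
            out_queue = self._queue_class()
            worker = self._worker_class(target=self._run, args=(out_queue, call_kwargs))
            worker.start()
            out_queues[node] = out_queue
            workers.append(worker)
        # Collect one result per node, honouring the timeout (None blocks indefinitely).
        results = {node: q.get(timeout=self._timeout) for node, q in out_queues.items()}
        for worker in workers:
            worker.join()
        return self._ret_fn(results)


def scale(x, factor, offset):
    return x * factor + offset


mapper = ToyNodeMapper(
    fn=scale,
    ret_fn=lambda per_node_results: per_node_results,  # identity, for illustration
    queue_class=queue.Queue,
    worker_class=threading.Thread,
    node_strs=["node:0", "node:1"],                    # hypothetical node strings
    timeout=5.0,
    constant={"factor": 10},
    unique={"offset": [1, 2]},
)
result = mapper.map(x={"node:0": 1, "node:1": 2})
partial = mapper.map(used_node_strs=["node:1"], x={"node:0": 1, "node:1": 2})
```

Passing queue_class and worker_class as arguments is what lets the same mapping logic run on different backends; here threads are swapped in purely to keep the sketch runnable without any framework installed.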


Supported Frameworks:

JAX, TensorFlow, PyTorch, MXNet, NumPy