DevMapper(fn, ret_fn, queue_class, worker_class, dev_strs, timeout=None, constant=None, unique=None)
__init__(fn, ret_fn, queue_class, worker_class, dev_strs, timeout=None, constant=None, unique=None)
Device Mapper base class.
fn (callable) – The function which the device mapper parallelises across devices.
ret_fn (callable) – The function that receives the ivy.MultiDevIter as input and produces a single-device output.
queue_class (class) – The class to use for creating queues.
worker_class (class) – The class to use for creating parallel workers.
dev_strs (sequence of str) – A list of devices on which to parallelise the function.
timeout (float, optional) – The timeout for getting items from the queues. Default is the global queue timeout.
constant (dict of any, optional) – A dict of keyword arguments which are the same for each process. Default is None.
unique (dict of iterables of any, optional) – A dict of keyword argument sequences which are unique for each process. Default is None.
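The split between `constant` and `unique` can be pictured with a short Python sketch. It is not Ivy's implementation: `worker_kwargs_per_device` is a hypothetical helper, and it assumes that the i-th entry of each `unique` sequence belongs to the i-th device in `dev_strs`, while every entry of `constant` is shared by all workers.

```python
def worker_kwargs_per_device(dev_strs, constant=None, unique=None):
    """Hypothetical helper: the keyword arguments each per-device worker gets.

    Assumption: the i-th entry of every sequence in `unique` belongs to the
    i-th device in `dev_strs`; `constant` is shared by all workers.
    """
    constant = dict(constant or {})
    # Materialise each iterable so it can be indexed by device position.
    unique = {k: list(v) for k, v in (unique or {}).items()}
    for k, v in unique.items():
        if len(v) != len(dev_strs):
            raise ValueError(
                f"unique[{k!r}] has {len(v)} entries for {len(dev_strs)} devices"
            )
    # Shared constants first, then this device's slice of each unique sequence.
    return {
        dev: {**constant, **{k: v[i] for k, v in unique.items()}}
        for i, dev in enumerate(dev_strs)
    }


print(worker_kwargs_per_device(
    ["gpu:0", "gpu:1"],
    constant={"lr": 0.1},
    unique={"seed": [0, 1]},
))
# {'gpu:0': {'lr': 0.1, 'seed': 0}, 'gpu:1': {'lr': 0.1, 'seed': 1}}
```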
map(used_dev_strs=None, split_factors=None, **kwargs)
Map the function fn over the MultiDevice keyword arguments, running one call per device in parallel with CUDA-safe multiprocessing.
used_dev_strs (sequence of str, optional) – The devices used in the current mapping pass. Default is all dev_strs.
split_factors (dict of floats, optional) – The updated split factor sf for each device, with 0 < sf < 1. Default is None.
kwargs (dict of any) – The MultiDevice keyword arguments to map the function over.
The results of the function, returned as a MultiDevice instance.
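A minimal, thread-based sketch of the mapping pattern follows. It is a hypothetical stand-in, not Ivy's code: `ToyDevMapper` hard-codes `queue.Queue` and `threading.Thread` in place of `queue_class` and `worker_class`, models each MultiDevice argument as a plain dict keyed by device string, returns a plain dict instead of a MultiDevice instance, and omits `ret_fn` and `split_factors`.

```python
import queue
import threading


class ToyDevMapper:
    """Hypothetical thread-based device mapper (not Ivy's implementation).

    Each device string gets one worker thread plus its own input and output
    queue. Multi-device arguments are modelled as dicts keyed by device.
    """

    def __init__(self, fn, dev_strs, timeout=None, constant=None, unique=None):
        self._fn = fn
        self._dev_strs = list(dev_strs)
        self._timeout = timeout
        constant = dict(constant or {})
        unique = {k: list(v) for k, v in (unique or {}).items()}
        self._in_qs = {d: queue.Queue() for d in self._dev_strs}
        self._out_qs = {d: queue.Queue() for d in self._dev_strs}
        self._workers = []
        for i, dev in enumerate(self._dev_strs):
            # Assumption: the i-th unique value goes to the i-th device.
            worker_kwargs = {**constant, **{k: v[i] for k, v in unique.items()}}
            worker = threading.Thread(
                target=self._loop, args=(dev, worker_kwargs), daemon=True
            )
            worker.start()
            self._workers.append(worker)

    def _loop(self, dev, worker_kwargs):
        while True:
            kwargs = self._in_qs[dev].get()
            if kwargs is None:  # sentinel value: stop this worker
                return
            self._out_qs[dev].put(self._fn(**worker_kwargs, **kwargs))

    def map(self, used_dev_strs=None, **kwargs):
        used = list(used_dev_strs or self._dev_strs)
        # Send each used device only its own entry of every multi-device kwarg.
        for dev in used:
            self._in_qs[dev].put({k: v[dev] for k, v in kwargs.items()})
        # Wait (up to timeout seconds) for each device's result.
        return {dev: self._out_qs[dev].get(timeout=self._timeout) for dev in used}

    def close(self):
        for dev in self._dev_strs:
            self._in_qs[dev].put(None)
        for worker in self._workers:
            worker.join()


mapper = ToyDevMapper(
    fn=lambda x, scale, offset: x * scale + offset,
    dev_strs=["cpu:0", "cpu:1"],
    timeout=5.0,
    constant={"scale": 10},
    unique={"offset": [1, 2]},
)
print(mapper.map(x={"cpu:0": 1, "cpu:1": 2}))  # {'cpu:0': 11, 'cpu:1': 22}
print(mapper.map(used_dev_strs=["cpu:1"], x={"cpu:1": 5}))  # {'cpu:1': 52}
mapper.close()
```

Passing `used_dev_strs` sends work only to the listed devices, so the second call returns a single entry. Threads keep the sketch short; a CUDA-safe implementation would instead use process-based workers and queues, since CUDA state generally cannot be shared across forked processes.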