evorl.distributed.sharding
Module Contents
Functions
- shmap_map – Sequential execution on different GPUs.
- Vmap on different GPUs.
- Pytree version of
API
- evorl.distributed.sharding.shmap_map(fn: collections.abc.Callable, mesh, in_specs, out_specs, **kwargs)
Sequential execution on different GPUs.
- Parameters:
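fn – function to be executed; only positional arguments are supported
mesh – JAX device mesh (jax.sharding.Mesh) the inputs are distributed over
in_specs – partition specs for the positional inputs of fn
out_specs – partition specs for the outputs of fn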
- Returns:
A wrapped function.
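The behaviour described above can be sketched in plain JAX. This is a hypothetical re-creation, not the evorl implementation. It assumes that "sequential execution" means wrapping fn in jax.lax.map inside shard_map, so each device walks over its local shard one element at a time. It also assumes that **kwargs is forwarded to shard_map. The name shmap_map_sketch and the mesh axis name "d" are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map as public API
    from jax import shard_map
except ImportError:  # older JAX keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map


def shmap_map_sketch(fn, mesh, in_specs, out_specs, **kwargs):
    """Hypothetical stand-in for shmap_map (not the evorl source).

    Each device receives its local shard of every positional argument and
    applies `fn` one element at a time along the leading axis via
    jax.lax.map, i.e. sequentially rather than vectorized.
    """

    def per_device(*args):
        # lax.map scans over axis 0 of every leaf in the `args` tuple;
        # each step gets a tuple of per-element slices, unpacked into fn.
        return jax.lax.map(lambda xs: fn(*xs), args)

    # Assumption: extra keyword arguments are forwarded to shard_map.
    return shard_map(per_device, mesh=mesh, in_specs=in_specs,
                     out_specs=out_specs, **kwargs)


# A 1-D mesh over all visible devices (a single CPU device also works).
devices = jax.devices()
mesh = Mesh(np.array(devices), ("d",))
n = 4 * len(devices)  # leading axis must divide evenly across devices

# One positional argument, sharded along "d".
x = jnp.arange(n, dtype=jnp.float32)
f = jax.jit(shmap_map_sketch(lambda v: v * v + 1.0, mesh,
                             in_specs=(P("d"),), out_specs=P("d")))
y = f(x)

# Two positional arguments, both sharded along "d".
a = jnp.ones(n, dtype=jnp.float32)
b = jnp.arange(n, dtype=jnp.float32)
g = jax.jit(shmap_map_sketch(lambda u, w: 2.0 * u + w, mesh,
                             in_specs=(P("d"), P("d")), out_specs=P("d")))
z = g(a, b)
```

Replacing jax.lax.map with jax.vmap in per_device would give a vectorized per-device variant. The sequential lax.map form typically trades throughput for lower peak memory, since only one element's intermediates are live per step on each device.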