evorl.distributed.sharding

Module Contents

Functions

shmap_map

Sequential execution across multiple GPUs.

shmap_vmap

Vmap across multiple GPUs.

tree_device_put

Pytree version of jax.device_put.

API

evorl.distributed.sharding.shmap_map(fn: collections.abc.Callable, mesh, in_specs, out_specs, **kwargs)[source]

Sequential execution across multiple GPUs.

Parameters:
  • fn – function to execute; only positional arguments are supported

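  • mesh – JAX device mesh over which fn is sharded

  • in_specs – partition specs describing how the inputs are split across the mesh

  • out_specs – partition specs describing how the outputs are assembled from the mesh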

Returns:

A wrapped function.
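
The following is a minimal sketch of the pattern that the name and signature suggest, written in plain JAX rather than with evorl itself: the batch is split across devices with shard_map, and each device processes its local slice one item at a time with jax.lax.map. The mesh axis name "pop", the toy function fn, and the helper per_device are assumptions made for this illustration; the actual evorl implementation may differ.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

# shard_map moved out of jax.experimental in newer JAX releases.
try:
    from jax import shard_map
except ImportError:
    from jax.experimental.shard_map import shard_map

# One-dimensional mesh over all visible devices (works with a single CPU too).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("pop",))  # "pop" is a hypothetical axis name


def fn(x):
    # Toy per-item computation, positional arguments only.
    return jnp.sum(x**2)


def per_device(xs):
    # xs is the slice of the batch owned by one device;
    # lax.map applies fn to its rows one after another.
    return jax.lax.map(fn, xs)


sequential_sharded = jax.jit(
    shard_map(per_device, mesh=mesh, in_specs=P("pop"), out_specs=P("pop"))
)

# The leading axis must be divisible by the number of devices.
n = 4 * len(devices)
xs = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)
ys = sequential_sharded(xs)
```

Running items sequentially on each device trades speed for a lower peak memory footprint, which can matter when a single call to fn is already large.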

evorl.distributed.sharding.shmap_vmap(fn: collections.abc.Callable, mesh, in_specs, out_specs, **kwargs)[source]

Vmap across multiple GPUs.
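
As a hedged illustration of the idea, the sketch below combines shard_map with jax.vmap in plain JAX: the batch is split across devices, and each device vectorizes fn over its local slice. It does not call evorl; the mesh axis name "pop" and the toy function fn are assumptions for this example only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

# shard_map moved out of jax.experimental in newer JAX releases.
try:
    from jax import shard_map
except ImportError:
    from jax.experimental.shard_map import shard_map

# One-dimensional mesh over all visible devices (works with a single CPU too).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("pop",))  # "pop" is a hypothetical axis name


def fn(x):
    # Toy per-item computation, positional arguments only.
    return jnp.mean(jnp.tanh(x))


# Each device receives a slice of the batch and vectorizes fn over it.
vmapped_sharded = jax.jit(
    shard_map(jax.vmap(fn), mesh=mesh, in_specs=P("pop"), out_specs=P("pop"))
)

# The leading axis must be divisible by the number of devices.
n = 4 * len(devices)
xs = jnp.linspace(-1.0, 1.0, n * 5).reshape(n, 5)
ys = vmapped_sharded(xs)
```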

evorl.distributed.sharding.tree_device_put(tree: chex.ArrayTree, device_or_sharding)[source]

Pytree version of jax.device_put.
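
To make the description concrete, here is a minimal sketch of placing every leaf of a pytree on a device using only public JAX calls. The helper name put_tree_sketch and the example params dictionary are hypothetical and are not part of evorl; the library's own implementation may differ. Note that jax.device_put itself also accepts pytrees.

```python
import jax
import jax.numpy as jnp


def put_tree_sketch(tree, device_or_sharding):
    # Hypothetical helper: move each leaf individually with jax.device_put.
    return jax.tree_util.tree_map(
        lambda leaf: jax.device_put(leaf, device_or_sharding), tree
    )


# Example pytree: a dictionary of arrays.
params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}

device = jax.devices()[0]
placed = put_tree_sketch(params, device)
```

The second argument can be a single device, as here, or a sharding object such as jax.sharding.NamedSharding when the leaves should be distributed over a mesh.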