Parallelization across devices
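import os
# Must be set before JAX is imported: emulates 8 devices on a CPU host.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
import jax
import jax.numpy as jnp
rng_key = jax.random.PRNGKey(1234)
# Split the key so that v1 and v2 are independent samples, not identical arrays.
k1, k2 = jax.random.split(rng_key)
v1 = jax.random.normal(k1, shape=(10_000_000, 3))
v2 = jax.random.normal(k2, shape=(10_000_000, 3))
def dot(v1, v2):
    # For a single pair of 3-vectors this is the ordinary dot product.
    return jnp.vdot(v1, v2)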
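As a minimal sketch of how a function like dot could be spread across devices (an illustration under assumptions, not necessarily how these notes intended to continue): vmap() vectorizes the per-pair dot product, and pmap() runs one chunk of the data on each device. The array size is reduced to keep the example light, and the names n_dev, dot_batched, dot_parallel, v1_sharded, and v2_sharded are hypothetical.

```python
import os
# Must be set before JAX is first imported; emulates 8 devices on a CPU host.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp

def dot(v1, v2):
    return jnp.vdot(v1, v2)

n_dev = jax.local_device_count()
n = 1000 * n_dev  # keep the number of rows divisible by the device count

k1, k2 = jax.random.split(jax.random.PRNGKey(1234))
v1 = jax.random.normal(k1, shape=(n, 3))
v2 = jax.random.normal(k2, shape=(n, 3))

# vmap: apply dot to every pair of rows (maps over axis 0 of both inputs).
dot_batched = jax.vmap(dot)

# pmap: map over the leading axis, one chunk per device;
# inside each chunk the vmapped function handles the rows.
dot_parallel = jax.pmap(dot_batched)

# Reshape so that the leading axis equals the number of devices.
v1_sharded = v1.reshape((n_dev, n // n_dev, 3))
v2_sharded = v2.reshape((n_dev, n // n_dev, 3))

out = dot_parallel(v1_sharded, v2_sharded)  # shape (n_dev, n // n_dev)
result = out.reshape(n)                     # one dot product per input row
```

The leading axis passed to pmap() cannot be larger than the number of available devices, which is why the data is reshaped using jax.local_device_count() rather than a hard-coded 8.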
Collective operations
After running a vmap'ed or pmap'ed operation, one of the output axes will be the batched/parallelized axis. If this axis is named (using the axis_name option when initially creating the vmap'ed or pmap'ed function), then collective operations (e.g. sum, max, min, mean, permute, shuffle, swapaxes; in jax.lax these are psum, pmax, pmin, pmean, ppermute, pshuffle, pswapaxes) called inside the mapped function can reference that axis by name to aggregate (i.e. collect) the distributed results.
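A small sketch of a named axis in use, assuming a hypothetical normalize function and the axis name 'i' (neither appears in the notes above): the same function works under both vmap() and pmap(), because jax.lax.psum only needs the axis name to know what to sum over.

```python
import os
# Must be set before JAX is first imported; emulates 8 devices on a CPU host.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp

def normalize(x):
    # psum adds up x over every position along the axis named 'i',
    # so each element is divided by the total across the mapped axis.
    return x / jax.lax.psum(x, axis_name='i')

# Named axis with vmap: the collective runs over the batch axis.
xs = jnp.arange(1.0, 5.0)
normalized_v = jax.vmap(normalize, axis_name='i')(xs)

# Named axis with pmap: the collective runs across devices.
n_dev = jax.local_device_count()
ys = jnp.arange(1.0, n_dev + 1.0)
normalized_p = jax.pmap(normalize, axis_name='i')(ys)
```

Without axis_name, the call to psum inside normalize would fail, since there would be no named axis for it to refer to.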
Distributed computing patterns
- Broadcast: This pattern is used to distribute data from one processing unit to all processing units.
  - In JAX, using None in in_axes indicates that an argument doesn't have an extra axis to map over and should be broadcast. With pmap(), you can also use the special static_broadcasted_argnums parameter to specify which positional arguments to treat as static (compile-time constant) and broadcast to all devices.
- Scatter: This pattern is used to distribute data from one processing unit to all the processing units, but unlike Broadcast, which sends the same message to each processing unit, Scatter splits the message and delivers one part of it to each processing unit.
  - This is exactly what is done when we map over an axis with pmap() or vmap().
- Reduce: This pattern is the inverse of Broadcast and is used to collect data from various computing devices and combine them into a global result with some given function (e.g., summation).
- All-reduce: This is a special case of the Reduce pattern in which the result of the Reduce operation must be distributed to all processing units. The psum()/pmax()/pmin()/pmean() functions perform all-reduce operations.
- Gather: This pattern is used to store data from all processing units on a single processing unit.
- All-gather: This pattern is used to collect data from all processing units and to store the collected data on all processing units.
- All-to-all: This pattern, also called a total exchange, is used when each processing unit sends a message to every other processing unit.
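Several of the patterns above can be sketched in one pmap() call. This is a minimal illustration under assumptions, not a definitive mapping of every pattern: the function per_device, the axis name 'd', and the arrays x and w are hypothetical. It shows Scatter (in_axes=0), Broadcast (in_axes=None), All-reduce (jax.lax.psum), and All-gather (jax.lax.all_gather); Gather and All-to-all are not shown.

```python
import os
# Must be set before JAX is first imported; emulates 8 devices on a CPU host.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# Scatter: x is split along axis 0, one row per device.
x = jnp.arange(n_dev * 2.0).reshape(n_dev, 2)
# Broadcast: w has no mapped axis, so every device receives the same copy.
w = jnp.array([10.0, 1.0])

def per_device(x_row, w):
    local = jnp.dot(x_row, w)                            # per-device partial result
    total = jax.lax.psum(local, axis_name='d')           # All-reduce: sum, visible on every device
    everyone = jax.lax.all_gather(local, axis_name='d')  # All-gather: all partial results on every device
    return local, total, everyone

local, total, everyone = jax.pmap(
    per_device, axis_name='d', in_axes=(0, None)
)(x, w)
```

Here local has shape (n_dev,), total repeats the same sum once per device, and everyone has shape (n_dev, n_dev) because each device holds the full vector of partial results.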