Parallelization across devices

import os
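# simulate 8 devices on a CPU host; XLA_FLAGS must be set before importing JAX
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp

# split the key so that v1 and v2 hold independent random values
key1, key2 = jax.random.split(jax.random.PRNGKey(1234))
v1 = jax.random.normal(key1, shape=(10_000_000, 3))
v2 = jax.random.normal(key2, shape=(10_000_000, 3))

def dot(v1, v2):
    # vdot flattens both inputs and returns a single scalar
    return jnp.vdot(v1, v2)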
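The code above defines dot but does not yet spread the work across devices. The following is a minimal sketch of one way to do that. It assumes the approach is to reshape each array into one chunk per device, run dot on every chunk with jax.pmap, and sum the partial results. The array sizes are reduced so the sketch runs quickly, and the names n_devices, v1_sharded, v2_sharded and partial_dots are illustrative, not from the original text.

```python
import os
# Assumption: simulate 8 CPU devices; this only takes effect if it runs
# before JAX initializes its backends.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp

def dot(v1, v2):
    # vdot flattens both inputs and returns a single scalar
    return jnp.vdot(v1, v2)

n_devices = jax.local_device_count()

key1, key2 = jax.random.split(jax.random.PRNGKey(1234))
# Smaller than in the text; the row count divides evenly across devices.
n_rows = 1_000 * n_devices
v1 = jax.random.normal(key1, shape=(n_rows, 3))
v2 = jax.random.normal(key2, shape=(n_rows, 3))

# Reshape to (n_devices, rows_per_device, 3): pmap maps over the leading
# axis, so each device receives one chunk.
v1_sharded = v1.reshape(n_devices, -1, 3)
v2_sharded = v2.reshape(n_devices, -1, 3)

# One partial dot product per device, shape (n_devices,)
partial_dots = jax.pmap(dot)(v1_sharded, v2_sharded)

# Combine the partial results into the full dot product
total = partial_dots.sum()

print(partial_dots.shape)
print(jnp.allclose(total, dot(v1, v2), rtol=1e-4, atol=1e-2))
```

Summing the partial results outside the mapped function is one design choice. Calling jax.lax.psum inside it instead would leave the full total on every device.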

Collective operations

After running a vmap'ed or pmap'ed function, one of the output axes is the batched/parallelized axis. If this axis is named (using the axis_name option when the vmap'ed or pmap'ed function is created), then collective operations called inside the function (e.g. psum, pmax, pmin, pmean, ppermute, pshuffle, pswapaxes from jax.lax) can reference that axis by name to aggregate (i.e. collect) the distributed results.
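Here is a minimal sketch of a named axis in use. It assumes a hypothetical normalize function that divides each mapped value by the total across the axis named 'i'. The same function runs under pmap (one value per available device) and under vmap, which accepts the same axis_name argument.

```python
import jax
import jax.numpy as jnp

def normalize(x):
    # psum refers to the mapped axis by its name, 'i'; every position along
    # that axis receives the same total (an all-reduce)
    return x / jax.lax.psum(x, axis_name='i')

n_devices = jax.local_device_count()

# pmap: one scalar per device, mapped axis named 'i'
per_device = jax.pmap(normalize, axis_name='i')(jnp.arange(1.0, n_devices + 1.0))

# vmap: the same collective works on a named batch axis on a single device
batched = jax.vmap(normalize, axis_name='i')(jnp.arange(1.0, 5.0))

print(per_device.sum())
print(batched)
```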

Distributed computing patterns

  • Broadcast: This pattern is used to distribute the same data from one processing unit to all processing units.
    • In JAX, using None in in_axes indicates that an argument has no extra axis to map over and should be broadcast unchanged. pmap() also accepts a static_broadcasted_argnums parameter that specifies which positional arguments to treat as static (compile-time constants) and broadcast to all devices.
  • Scatter: This pattern also distributes data from one processing unit to all processing units. Unlike Broadcast, which sends the same message to each processing unit, Scatter splits the message and delivers one part of it to each processing unit.
    • This is what pmap() and vmap() do when they map over an axis: each device (or batch element) receives one slice of the input along that axis.
  • Reduce: This pattern is the inverse of Broadcast. It collects data from various computing devices and combines it into a global result with some given function (e.g., summation).
    • All-reduce: A special case of Reduce in which the result of the Reduce operation is distributed to all processing units. The psum()/pmax()/pmin()/pmean() functions perform all-reduce operations.
  • Gather: This pattern collects data from all processing units onto a single processing unit.
    • All-gather: This pattern collects data from all processing units and stores the collected data on every processing unit (jax.lax.all_gather() in JAX).
  • All-to-all: This pattern, also called a total exchange, is used when each processing unit sends a distinct message to every other processing unit (jax.lax.all_to_all() in JAX).
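Several of these patterns can be sketched in JAX as follows. This is an illustrative example, not a definitive implementation. The data is tiny, the functions scale_rows, power and reduce_and_gather are hypothetical names, and the code runs with however many devices jax.local_device_count() reports (a single CPU device also works). The plain Gather pattern is not shown.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# Scatter: pmap maps over the leading axis, so device i receives xs[i]
xs = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

# Broadcast: in_axes=None gives every device the same, unsplit `scale`
def scale_rows(x, scale):
    return x * scale

scaled = jax.pmap(scale_rows, in_axes=(0, None))(xs, jnp.array(2.0))

# Broadcast of a static (compile-time constant) Python value
def power(x, p):
    return x ** p

squared = jax.pmap(power, static_broadcasted_argnums=(1,))(xs, 2)

# All-reduce and all-gather over the named axis 'i'
def reduce_and_gather(x):
    total = jax.lax.psum(x, axis_name='i')           # same sum on every device
    gathered = jax.lax.all_gather(x, axis_name='i')  # every device gets all rows
    return total, gathered

total, gathered = jax.pmap(reduce_and_gather, axis_name='i')(xs)

# All-to-all: device i sends element j of its row to device j; with
# split_axis=0 and concat_axis=0 this transposes an (n, n) array
m = jnp.arange(n * n, dtype=jnp.float32).reshape(n, n)
transposed = jax.pmap(
    lambda x: jax.lax.all_to_all(x, 'i', 0, 0), axis_name='i'
)(m)

print(total.shape, gathered.shape)  # shapes (n, 4) and (n, n, 4)
```

Because p is static, each distinct value of p triggers a new compilation. Static arguments therefore suit small configuration values rather than data.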