
Deep learning infrastructure

Development

Notebooks

pip install jupyter
pip install jupyterlab-vim          # Vim keybindings for JupyterLab
pip install jupyter-collaboration   # real-time collaborative editing in JupyterLab
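
To make the environment show up as a selectable kernel inside Jupyter, register it with ipykernel (substituting your own environment and display names):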

python -m ipykernel install --user --name "ENV_NAME" --display-name "ENV_DISPLAY_NAME"

Software libraries

JAX with Equinox is my preferred deep learning framework. Because JAX uses XLA to compile programs down to optimized kernels for the underlying hardware accelerator, I can focus on writing novel model architectures without worrying about how to scale up training. The one thing to keep in mind is the inherent parallelizability of the model: RNNs, for example, are fundamentally limited in the speed-up they can achieve because each step needs the previous output before it can process the next input token.
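
As a rough sketch of this workflow (the module and layer sizes below are illustrative, not taken from any real project), an Equinox model is an ordinary Python class whose forward pass can be handed to XLA with eqx.filter_jit and batched with jax.vmap:

import jax
import jax.numpy as jnp
import equinox as eqx

class MLP(eqx.Module):
    hidden: eqx.nn.Linear
    out: eqx.nn.Linear

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.hidden = eqx.nn.Linear(32, 64, key=k1)
        self.out = eqx.nn.Linear(64, 10, key=k2)

    def __call__(self, x):
        return self.out(jax.nn.relu(self.hidden(x)))

model = MLP(jax.random.PRNGKey(0))

@eqx.filter_jit  # XLA compiles the batched forward pass for the available accelerator
def forward(model, xs):
    return jax.vmap(model)(xs)  # map the single-example model over the batch dimension

print(forward(model, jnp.ones((8, 32))).shape)  # (8, 10)

The same code runs unchanged on CPU, GPU, or TPU; XLA emits kernels for whichever backend is available.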

To use published datasets, I use the datasets library from HuggingFace.
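
For instance (the dataset name and split here are just an example), loading a published dataset from the Hub is a single call:

from datasets import load_dataset

# Downloads the dataset from the HuggingFace Hub on first use and caches it locally.
dataset = load_dataset("imdb", split="train")
print(dataset[0]["text"][:80])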

Scripts

Local (small-scale) training

Large-scale training

Technical notes

Building (and installing) JAX for macOS Rosetta

  1. Comment out verify_mac_libraries_dont_reference_chkstack() in jaxlib/tools/build_wheel.py. See https://github.com/jax-ml/jax/issues/3867 for details.

  2. Then run the following command to build the jaxlib wheel. The important flags are --target_cpu_features native, which avoids building with AVX instructions that Rosetta does not support, and --bazel_options='--copt=-Wno-gnu-offsetof-extensions', which suppresses Clang warnings about GNU offsetof extensions used in the build.

python build/build.py build --wheels=jaxlib --verbose --python_version=3.11 --target_cpu_features native --bazel_options='--copt=-Wno-gnu-offsetof-extensions'

  3. Install the corresponding version of jax using pip:

pip install jax==JAXLIB_VERSION
# e.g. pip install jax==0.4.26

  4. Install jaxlib using the compiled wheel:

pip install WHEEL_FILE
# e.g. pip install jaxlib-0.4.36.dev20241214-cp311-cp311-macosx_10_14_x86_64.whl
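
After both installs, a quick sanity check (not part of the original notes) is to confirm that JAX imports and runs a small computation:

python -c "import jax; print(jax.devices())"
python -c "import jax.numpy as jnp; print(jnp.arange(3) + 1)"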

Using the HuggingFace datasets library