Deep learning infrastructure

Development

Notebooks

pipx install notebook --include-deps # notebook 7.3.2
pipx inject notebook jupyterlab-vim
pipx inject notebook jupyter-collaboration

# for new projects, create a virtual environment and register it as a Jupyter kernel
python -m ipykernel install --user --name "ENV_NAME" --display-name "ENV_DISPLAY_NAME"

# set up configuration for jupyter notebooks
jupyter notebook --generate-config
vim ~/.jupyter/jupyter_notebook_config.py
# in the config file, disable password and token authentication:
c.ServerApp.password = ''
c.ServerApp.token = ''

Software libraries

JAX with Equinox is my preferred deep learning stack. Because JAX uses XLA to compile models into optimized kernels for the underlying accelerator, I can focus on designing model architectures without worrying about how to scale up training. JAX code is device-agnostic, so I can develop and run small-scale training locally, then scale up training on the cloud with the same code.
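As a rough illustration of the device-agnostic point, here is a minimal sketch in plain JAX (not Equinox, to keep it dependency-free). The tiny model, the shapes, and the names `predict`, `loss`, and `grad_fn` are made up for this example; nothing in the code names a device, so it should run unchanged on whatever backend JAX picks up.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    """A single linear layer followed by tanh (hypothetical toy model)."""
    w, b = params
    return jnp.tanh(x @ w + b)


def loss(params, x, y):
    """Mean squared error between predictions and targets."""
    return jnp.mean((predict(params, x) - y) ** 2)


# jax.jit traces the function and hands it to XLA, which compiles it
# for the default backend of the current machine.
grad_fn = jax.jit(jax.grad(loss))

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = (jax.random.normal(key_w, (3, 2)), jnp.zeros(2))
x = jax.random.normal(key_x, (8, 3))
y = jnp.zeros((8, 2))

grads = grad_fn(params, x, y)
print(jax.default_backend())           # e.g. "cpu" locally, "gpu" or "tpu" on the cloud
print(grads[0].shape, grads[1].shape)  # gradients mirror the parameter shapes
```

The same script can be used for a quick local run and for a cloud run; only the backend reported by `jax.default_backend()` changes.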

For published datasets, I use the datasets library from HuggingFace.

Scripts

Local (small-scale) training

Large-scale training

Technical notes

Building (and installing) JAX for macOS Rosetta

  1. Comment out verify_mac_libraries_dont_reference_chkstack() in jaxlib/tools/build_wheel.py. See https://github.com/jax-ml/jax/issues/3867 for details.

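  2. Then run the following command to build the jaxlib wheel. The important flags are --target_cpu_features native, which targets the CPU features of the build environment and thereby disables AVX support, and --bazel_options='--copt=-Wno-gnu-offsetof-extensions', which silences the compiler warning about GNU offsetof extensions on macOS.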

python build/build.py build --wheels=jaxlib --verbose --python_version=3.11 --target_cpu_features native --bazel_options='--copt=-Wno-gnu-offsetof-extensions'
  3. Install the matching versions of jax and equinox using pip:
pip install jax==JAXLIB_VERSION
pip install equinox==EQUINOX_COMPATIBLE_VERSION
# e.g. pip install jax==0.4.36 equinox==0.11.10
  4. Install jaxlib using the compiled wheel:
pip install WHEEL_FILE
# e.g. pip install jaxlib-0.4.36.dev20241214-cp311-cp311-macosx_10_14_x86_64.whl
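After the steps above, it can be worth checking that the installed jax and jaxlib versions line up. The sketch below uses only the standard library. It assumes, based on my understanding of jax's import-time check, that a jaxlib whose numeric version is newer than jax's is rejected; the helper names `numeric_prefix`, `jaxlib_not_newer`, and `installed_version` are hypothetical.

```python
import re
from importlib import metadata


def numeric_prefix(version):
    """Return the leading dotted-number part of a version string as ints."""
    match = re.match(r"[0-9]+(?:\.[0-9]+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def jaxlib_not_newer(jax_version, jaxlib_version):
    """Assumed rule: jaxlib must not be newer than jax (dev suffixes ignored)."""
    return numeric_prefix(jaxlib_version) <= numeric_prefix(jax_version)


def installed_version(package):
    """Return the installed version of a package, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


jax_v, jaxlib_v = installed_version("jax"), installed_version("jaxlib")
if jax_v is None or jaxlib_v is None:
    print("jax and/or jaxlib is not installed in this environment")
elif jaxlib_not_newer(jax_v, jaxlib_v):
    print(f"ok: jax {jax_v} with jaxlib {jaxlib_v}")
else:
    print(f"mismatch: jaxlib {jaxlib_v} is newer than jax {jax_v}")
```

Under this rule, a wheel such as jaxlib-0.4.36.dev20241214 pairs with jax 0.4.36 but not with an older jax release.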

Using the HuggingFace datasets library