pipx install notebook --include-deps # notebook 7.3.2
pipx inject notebook jupyterlab-vim
pipx inject notebook jupyter-collaboration
# for new projects, create a virtual environment using the desired Python version, then register it as a Jupyter kernel
python -m ipykernel install --user --name "ENV_NAME" --display-name "ENV_DISPLAY_NAME"
# set up configuration for jupyter notebooks
jupyter notebook --generate-config
vim ~/.jupyter/jupyter_notebook_config.py
c.ServerApp.password = ''
c.ServerApp.token = ''
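An empty password and token disable authentication entirely, so only do this on a trusted machine. A couple of other settings I find useful (my own preferences, not required):

```python
c.ServerApp.open_browser = False  # don't auto-open a browser tab on start
c.ServerApp.port = 8888           # pin the port instead of auto-incrementing
```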
JAX with Equinox is my preferred deep learning stack. Because JAX uses XLA to compile models into optimized kernels for whatever accelerator is available, I can focus on designing model architectures without worrying about how to scale up training. JAX code is device-agnostic, so it is easy to develop and run small-scale training locally, then seamlessly scale up on the cloud.
To use published datasets, I use the datasets library from HuggingFace.
Comment out verify_mac_libraries_dont_reference_chkstack() in jaxlib/tools/build_wheel.py. See https://github.com/jax-ml/jax/issues/3867 for details.
Then run the following command to build JAX. The important flags are --target_cpu_features native, which compiles for the host CPU (so AVX instructions are not emitted on Macs that lack them), and --bazel_options='--copt=-Wno-gnu-offsetof-extensions', which silences a Clang warning about GNU offsetof extensions that would otherwise fail the build on macOS.
python build/build.py build --wheels=jaxlib --verbose --python_version=3.11 --target_cpu_features native --bazel_options='--copt=-Wno-gnu-offsetof-extensions'
Install jax and equinox using pip:
pip install jax==JAXLIB_VERSION
pip install equinox==EQUINOX_COMPATIBLE_VERSION
# e.g. pip install jax==0.4.26 equinox==0.11.10
Install jaxlib using the compiled wheel:
pip install WHEEL_FILE
# e.g. pip install jaxlib-0.4.36.dev20241214-cp311-cp311-macosx_10_14_x86_64.whl
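After installing, a quick sanity check that the locally built jaxlib imports and runs:

```python
import jax
import jax.numpy as jnp

print(jax.devices())  # a CPU-only Mac should report a single CpuDevice
x = jnp.ones((2, 2))
print(jnp.dot(x, x))  # [[2. 2.] [2. 2.]]
```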