pip install jupyter
pip install jupyterlab-vim
pip install jupyter-collaboration
python -m ipykernel install --user --name "ENV_NAME" --display-name "ENV_DISPLAY_NAME"
JAX with Equinox is my preferred deep learning framework. Because JAX uses XLA to compile models into kernels optimized for the underlying accelerator, I can focus on writing novel model architectures without worrying about how to scale up training. The only concern to keep in mind is the inherent parallelizability of the model: RNNs, for example, are fundamentally limited in the speed-up they can achieve because each step needs the previous step's output before it can process the next input token.
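As a rough illustration of the point about parallelizability, here is a minimal sketch in plain JAX (no Equinox). The function names, the toy update rule, and the 0.5 recurrent weight are made up for this example. Both functions are compiled by XLA via jax.jit, but only the pointwise one is free of dependencies between time steps.

```python
import jax
import jax.numpy as jnp


def rnn_step(h, x):
    # One recurrent update: the new state depends on the previous state h.
    # The 0.5 weight is an arbitrary toy value, not a trained parameter.
    h_next = jnp.tanh(0.5 * h + x)
    return h_next, h_next


@jax.jit
def run_rnn(xs):
    # lax.scan walks the sequence in order; the step body is compiled once,
    # but step t cannot start before step t-1 has produced its state.
    h0 = jnp.zeros(xs.shape[1:])
    _, hs = jax.lax.scan(rnn_step, h0, xs)
    return hs


@jax.jit
def run_pointwise(xs):
    # No dependence between time steps, so all of them can run in parallel.
    return jnp.tanh(xs)


xs = jnp.ones((4, 3))      # 4 time steps, 3 features
hs = run_rnn(xs)           # sequential over the first axis
ys = run_pointwise(xs)     # parallel over every element
print(hs.shape, ys.shape)
```

The scan version still benefits from compilation, but its run time grows with sequence length in a way the pointwise version's need not.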
To use published datasets, I use the datasets library from Hugging Face.
Comment out verify_mac_libraries_dont_reference_chkstack() in jaxlib/tools/build_wheel.py. See https://github.com/jax-ml/jax/issues/3867 for details.
Then run the following command to build JAX. The important flags are --target_cpu_features native, which compiles for the instruction set of the host CPU instead of assuming AVX support, and --bazel_options='--copt=-Wno-gnu-offsetof-extensions', which silences Clang's gnu-offsetof-extensions diagnostic on macOS.
python build/build.py build --wheels=jaxlib --verbose --python_version=3.11 --target_cpu_features native --bazel_options='--copt=-Wno-gnu-offsetof-extensions'
Install jax using pip:
pip install jax==JAXLIB_VERSION
# e.g. pip install jax==0.4.26
Install jaxlib using the compiled wheel:
pip install WHEEL_FILE
# e.g. pip install jaxlib-0.4.36.dev20241214-cp311-cp311-macosx_10_14_x86_64.whl
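Since the jax install step pins jax to JAXLIB_VERSION, a quick sanity check after installing is to compare the two installed versions. The sketch below uses only the standard library. The helper names are hypothetical, and comparing only the numeric release part (ignoring suffixes such as .dev20241214) is an assumption about what counts as a match.

```python
import re
from importlib import metadata


def installed_version(package):
    # Return the installed version string, or None if the package is absent.
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def release_part(version):
    # Keep the leading numeric release: "0.4.36.dev20241214" -> "0.4.36".
    match = re.match(r"\d+(?:\.\d+)*", version)
    return match.group(0) if match else ""


def versions_match(jax_version, jaxlib_version):
    # Treat a missing package as a mismatch.
    if jax_version is None or jaxlib_version is None:
        return False
    return release_part(jax_version) == release_part(jaxlib_version)


jax_version = installed_version("jax")
jaxlib_version = installed_version("jaxlib")
print("jax:", jax_version, "jaxlib:", jaxlib_version)
print("release versions match:", versions_match(jax_version, jaxlib_version))
```

Running this inside the environment where the wheel was installed shows whether the two packages ended up on the same release.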