google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/



Provide wheels for macOS ARM

ericmjl opened this issue

Hi all,

I was digging around to see what would need to happen for JAX to work on Apple Silicon. Since JAX compiles programs with XLA, my guess is that XLA would first need to support Apple Silicon before JAX could run on it. Do you know whether the XLA team plans to make that happen, or whether it isn't being considered at all? (Knowing the answer will mostly help me decide how to set up my development environment.)

Cheers,
Eric

Check out this post #5084 by @hawkinsp (cc @rxwei)

Targeting the M1's ARM CPU shouldn't be difficult.

XLA already supports AArch64 and has done for a long time. I suspect that tensorflow/tensorflow#45404 already did most of the work to adapt XLA to build on the M1 and all that is left is a few small changes to the .bazelrc file that JAX's build.py script generates, analogous to the changes in that TF PR.

That said, I don't have access to any M1 hardware, so this is in the "contributions welcome" category.

Targeting the GPU or the Neural Engine is likely a lot more difficult. For GPU, one would probably need to target Metal (probably doable, but not trivial), and I'm unsure how we could target the Neural Engine at this time.

Hi! I have been trying to get jaxlib working on my Apple M1, and have managed to build it with some minor changes based on tensorflow/tensorflow#45404 (branch at https://github.com/jotsif/jax/tree/jax_for_darwin_arm64).

However, when importing jaxlib in Python 3.9.2, I get the following error:

ImportError: dlopen(/opt/homebrew/lib/python3.9/site-packages/jaxlib/xla_extension.so, 2): Symbol not found: _LLVMInitializeAArch64AsmPrinter

This seems to be a known issue; see tensorflow/tensorflow#45404 (comment) and elixir-nx/nx#217 (comment).

I am investigating further, but any tips would be welcome.

Update: this TF PR will probably fix the issue: tensorflow/tensorflow#47594

@jotsif thanks for that information, that's really helpful!

Confirmed that it works 🎉. Built with Bazel master and https://github.com/freedomtan/tensorflow/tree/bazel_native_build_on_m1:

>>> import numpy as np
>>> import platform
>>> platform.uname()
uname_result(system='Darwin', node='Josefs-MBP-2.lan', release='20.3.0', version='Darwin Kernel Version 20.3.0: Thu Jan 21 00:06:51 PST 2021; root:xnu-7195.81.3~1/RELEASE_ARM64_T8101', machine='arm64')
>>> from jax.lib import xla_client as xc
>>> xops = xc.ops
>>> c = xc.XlaBuilder("simple_scalar")
>>> param_shape = xc.Shape.array_shape(np.dtype(np.float32), ())
>>> x = xops.Parameter(c, 0, param_shape)
>>> y = xops.Sin(x)
>>> computation = c.Build()
>>> cpu_backend = xc.get_local_backend("cpu")
2021-03-07 09:26:16.684549: W external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
>>> compiled_computation = cpu_backend.compile(computation)
>>> host_input = np.array(3.0, dtype=np.float32)
>>> device_input = cpu_backend.buffer_from_pyval(host_input)
>>> device_out = compiled_computation.execute([device_input ,])
>>> device_out[0].to_py()
array(0.14112, dtype=float32)

@mattjj @hawkinsp If you want a PR, I would be happy to create one, but it may make more sense to wait until Bazel has released a working native arm64 build and TensorFlow has the necessary code in master.


Congrats! Do you mind sharing a jax/jaxlib wheel for M1?

Hi @jotsif -

Bazel 4.1 works natively on arm64 (bazelbuild/bazel#13099), and TF has the necessary code in master.

If you have an example branch, I'm happy to help get a PR together now!

@akbir I think we still need to wait for Bazel to actually release 4.1. But that should be soon I think! At that point we can probably just bump the Bazel dependency to 4.1 and hopefully everything should work on Mac ARM.

We can look into releasing Mac ARM wheels as well, although we don't yet have a way to test them (we personally do not have Mac ARM hardware yet), which gives me some pause.

I believe jaxlib at head now builds from source on a Mac M1 and pretty much everything works (*). We still don't have a great way to provide prebuilt wheels, but hopefully this is enough to unblock everyone!

(*) Installing jax is still a bit annoying, mostly because there aren't any prebuilt SciPy wheels for Mac ARM yet, so you'll have to build SciPy yourself to build jaxlib or use jax. I followed these instructions, which worked for me: scipy/scipy#13409 (comment).

Hi @hawkinsp, I tried following this but hit the following error when running build.py (full logs are included at the bottom):

ERROR: Analysis of target '//build:build_wheel' failed; build aborted: Analysis of target '@local_config_cc//:toolchain' failed

I'm unsure why the toolchain isn't working; the error is discussed here: bazelbuild/bazel#13099 (comment).

I've also run the build against Bazel 4.1.0rc5 and still get the same error.

Full error logs:

Bazel binary path: ./bazel-4.1.0rc4-darwin-arm64
Python binary path: /Users/akbirkhan/jax/venv/bin/python
Python version: 3.9
MKL-DNN enabled: yes
Target CPU features: release
CUDA enabled: no
TPU enabled: no
ROCm enabled: no

Building XLA and installing it in the jaxlib source tree...
./bazel-4.1.0rc4-darwin-arm64 run --verbose_failures=true --config=short_logs --config=mkl_open_source_only :build_wheel -- --output_path=/Users/akbirkhan/jax/dist
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'run' from /Users/akbirkhan/jax/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'run' from /Users/akbirkhan/jax/.bazelrc:
  Inherited 'build' options: --repo_env PYTHON_BIN_PATH=/Users/akbirkhan/jax/venv/bin/python --action_env=PYENV_ROOT --python_path=/Users/akbirkhan/jax/venv/bin/python --repo_env TF_NEED_CUDA=0 --action_env TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0 --repo_env TF_NEED_ROCM=0 --action_env TF_ROCM_AMDGPU_TARGETS=gfx803,gfx900,gfx906,gfx1010 --distinct_host_configuration=false -c opt --apple_platform_type=macos --macos_minimum_os=10.9 --announce_rc --define open_source_build=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --spawn_strategy=standalone --strategy=Genrule=standalone --enable_platform_specific_config
INFO: Found applicable config definition build:short_logs in file /Users/akbirkhan/jax/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:mkl_open_source_only in file /Users/akbirkhan/jax/.bazelrc: --define=tensorflow_mkldnn_contraction_kernel=1
INFO: Found applicable config definition build:macos in file /Users/akbirkhan/jax/.bazelrc: --config=posix
INFO: Found applicable config definition build:posix in file /Users/akbirkhan/jax/.bazelrc: --copt=-Wno-sign-compare --define=no_aws_support=true --define=no_gcp_support=true --define=no_hdfs_support=true --cxxopt=-std=c++14 --host_cxxopt=-std=c++14
Loading: 
Loading: 0 packages loaded
WARNING: Download from http://mirror.tensorflow.org/github.com/tensorflow/runtime/archive/3f4cd5e8a34eb2179537b8f71b1484bb0d26701f.tar.gz failed: class com.google.devtools.build.lib.bazel.repository.downloader.UnrecoverableHttpException GET returned 404 Not Found
DEBUG: /private/var/tmp/_bazel_akbirkhan/30d03974cb2cfe2f871935b707b3981a/external/tf_runtime/third_party/cuda/dependencies.bzl:51:10: The following command will download NVIDIA proprietary software. By using the software you agree to comply with the terms of the license agreement that accompanies the software. If you do not agree to the terms of the license agreement, do not use the software.
Analyzing: target //build:build_wheel (0 packages loaded, 0 targets configured)
ERROR: /private/var/tmp/_bazel_akbirkhan/30d03974cb2cfe2f871935b707b3981a/external/local_config_cc/BUILD:48:19: in cc_toolchain_suite rule @local_config_cc//:toolchain: cc_toolchain_suite '@local_config_cc//:toolchain' does not contain a toolchain for cpu 'darwin_arm64'
DEBUG: Rule 'io_bazel_rules_docker' indicated that a canonical reproducible form can be obtained by modifying arguments shallow_since = "1556410077 -0400"
DEBUG: Repository io_bazel_rules_docker instantiated at:
  /Users/akbirkhan/jax/WORKSPACE:34:10: in <toplevel>
  /private/var/tmp/_bazel_akbirkhan/30d03974cb2cfe2f871935b707b3981a/external/org_tensorflow/tensorflow/workspace0.bzl:108:34: in workspace
  /private/var/tmp/_bazel_akbirkhan/30d03974cb2cfe2f871935b707b3981a/external/bazel_toolchains/repositories/repositories.bzl:37:23: in repositories
Repository rule git_repository defined at:
  /private/var/tmp/_bazel_akbirkhan/30d03974cb2cfe2f871935b707b3981a/external/bazel_tools/tools/build_defs/repo/git.bzl:199:33: in <toplevel>
ERROR: Analysis of target '//build:build_wheel' failed; build aborted: Analysis of target '@local_config_cc//:toolchain' failed
INFO: Elapsed time: 0.155s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (0 packages loaded, 0 targets configured)
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully (0 packages loaded, 0 targets configured),
subprocess.CalledProcessError: Command '['./bazel-4.1.0rc4-darwin-arm64', 'run', '--verbose_failures=true', '--config=short_logs', '--config=mkl_open_source_only', ':build_wheel', '--', '--output_path=/Users/akbirkhan/jax/dist']' returned non-zero exit status 1.

I'm getting the same error as @akbir:

ERROR: /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/local_config_cc/BUILD:48:19: in cc_toolchain_suite rule @local_config_cc//:toolchain: cc_toolchain_suite '@local_config_cc//:toolchain' does not contain a toolchain for cpu 'darwin_arm64'
INFO: Repository com_google_absl instantiated at:
  /Users/noah/harvard/2021/network_repair/jax/WORKSPACE:30:10: in <toplevel>
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/tensorflow/workspace2.bzl:1090:28: in workspace
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/tensorflow/workspace2.bzl:56:9: in _initialize_third_party
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/third_party/absl/workspace.bzl:12:20: in repo
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/third_party/repo.bzl:112:21: in tf_http_archive
Repository rule _tf_http_archive defined at:
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/third_party/repo.bzl:65:35: in <toplevel>
Analyzing: target //build:build_wheel (35 packages loaded, 264 targets configured)
INFO: Repository cython instantiated at:
  /Users/noah/harvard/2021/network_repair/jax/WORKSPACE:30:10: in <toplevel>
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/tensorflow/workspace2.bzl:1097:21: in workspace
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/tensorflow/workspace2.bzl:845:20: in _tf_repositories
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/third_party/repo.bzl:112:21: in tf_http_archive
Repository rule _tf_http_archive defined at:
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/org_tensorflow/third_party/repo.bzl:65:35: in <toplevel>
INFO: Repository pocketfft instantiated at:
  /Users/noah/harvard/2021/network_repair/jax/WORKSPACE:24:10: in <toplevel>
  /Users/noah/harvard/2021/network_repair/jax/third_party/pocketfft/workspace.bzl:20:17: in repo
Repository rule http_archive defined at:
  /private/var/tmp/_bazel_noah/2c44e3012980c1c9cbcf3acb09acecec/external/bazel_tools/tools/build_defs/repo/http.bzl:336:31: in <toplevel>
ERROR: Analysis of target '//build:build_wheel' failed; build aborted: Analysis of target '@local_config_cc//:toolchain' failed
INFO: Elapsed time: 10.071s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (35 packages loaded, 264 targets configured)
ERROR: Build failed. Not running target

FYI, for those looking for a quick and dirty workaround: you can install jax and jaxlib using pip in a Miniconda environment running under Rosetta 2. The most recent versions of jax and jaxlib don't work (they fail with an error like "zsh: illegal hardware instruction"), so I ended up using jax==0.2.10 and jaxlib==0.1.60.

@akbir Do you have a working Xcode installation, including the command line tools?
https://jax.readthedocs.io/en/latest/developer.html#building-jaxlib-from-source

I don't think the error you are seeing is related to the Bazel version. We're currently pinning Bazel 4.1.0rc4 because that was the newest version last week. If you like, you can try a different Bazel version, but 4.1.0rc4 worked for me. The easiest way to do that is to install Bazel yourself and pass --bazel_path=/somewhere/bazel to the build.py command line.

@Noahyt jaxlib 0.1.62 and newer on x86 use AVX, which Rosetta does not support (https://github.com/google/jax/blob/master/CHANGELOG.md#jaxlib-0162-march-9-2021). All recent x86 CPUs support AVX and have for a long time. We don't intend to ship wheels without AVX, although if you like you can build a jaxlib from source that does not require AVX. But I don't think we should worry about that too much, since we want a native ARM version, anyway.
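The version cutoff described above can be captured in a small helper. This is a hypothetical sketch, not part of jaxlib: it assumes only what this comment states, namely that prebuilt x86-64 jaxlib wheels from 0.1.62 onward use AVX, which Rosetta 2 cannot execute. The names `parse_version` and `x86_jaxlib_requires_avx` are illustrative.

```python
def parse_version(version: str) -> tuple:
    """Turn a plain release string such as "0.1.62" into (0, 1, 62).

    Only handles dotted integers; pre-release suffixes are not supported.
    """
    return tuple(int(part) for part in version.split("."))


# Cutoff taken from the jaxlib 0.1.62 changelog entry cited above.
AVX_REQUIRED_SINCE = (0, 1, 62)


def x86_jaxlib_requires_avx(jaxlib_version: str) -> bool:
    """Return True if the prebuilt x86-64 jaxlib wheel of this version needs AVX.

    Hypothetical helper: such wheels fail under Rosetta 2 (for example with
    "illegal hardware instruction") because Rosetta does not support AVX.
    """
    return parse_version(jaxlib_version) >= AVX_REQUIRED_SINCE


for v in ["0.1.60", "0.1.62", "0.1.70"]:
    print(v, "needs AVX" if x86_jaxlib_requires_avx(v) else "no AVX requirement")
```

This explains why the jaxlib==0.1.60 pin mentioned earlier works under Rosetta while newer x86 wheels do not; the real fix remains a native ARM build.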

Got this to finally build!! Thank you @hawkinsp

For others: I installed Xcode (the full app, not just the command line tools) and Bazel.

I also updated .bazelversion to 4.1.0 (should this be updated in the repo?)

sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license
bazel sync --configure

python build.py --bazel_path=/somewhere/bazel
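Since the build only succeeded here after the full Xcode app was installed and selected, a quick pre-flight check of the active developer directory may save a failed Bazel run. A minimal sketch, assuming that `xcode-select -p` prints the active developer directory and that a standalone Command Line Tools install lives under `/Library/Developer/CommandLineTools`; the helper names are hypothetical and not part of build.py.

```python
import shutil
import subprocess

# Default location of a standalone Command Line Tools install (no Xcode.app).
CLT_DIR = "/Library/Developer/CommandLineTools"


def active_developer_dir():
    """Return the path printed by `xcode-select -p`, or None if it is unavailable."""
    if shutil.which("xcode-select") is None:
        return None  # not macOS, or the developer tools are not installed
    result = subprocess.run(["xcode-select", "-p"], capture_output=True, text=True)
    if result.returncode != 0:
        return None
    return result.stdout.strip()


def uses_full_xcode(developer_dir):
    """True if the developer dir is inside an Xcode app bundle rather than the CLT."""
    return bool(developer_dir) and developer_dir.rstrip("/").endswith(
        ".app/Contents/Developer"
    )


dev_dir = active_developer_dir()
print("developer dir:", dev_dir)
print("full Xcode selected:", uses_full_xcode(dev_dir))
if dev_dir == CLT_DIR:
    print("Only the Command Line Tools are selected; point xcode-select -s at Xcode.app")
```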

@akbir If you want to send a PR that updates the Bazel version for Mac ARM, that sounds great! The version of bazel is chosen here: https://cs.opensource.google/jax/jax/+/master:build/build.py;drc=dacf31f2020175181014745cdabc240a10031227;l=119

Would love to!

Quick question: why does jax also specify the version here: https://github.com/google/jax/blob/master/.bazelversion ?

@akbir If I remember correctly, that's for folks using Bazelisk (https://github.com/bazelbuild/bazelisk). I don't know if it's possible to specify a separate version for Mac arm via Bazelisk.

We can probably upgrade to 4.1.0 for all platforms, but let's not do that right away. So fixing build.py is probably enough.
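Whichever file ends up holding the version, the check reduces to comparing strings such as `3.7.2`, `4.1.0rc4` and `4.1.0`. A hypothetical helper, not part of build.py or Bazelisk, assuming (from this thread) that 4.1.0 is the minimum for native darwin-arm64 builds and that its release candidates count, since 4.1.0rc4 is what build.py pinned and what worked for @hawkinsp.

```python
import re

_VERSION_RE = re.compile(r"^(\d+)\.(\d+)\.(\d+)(?:rc(\d+))?$")


def parse_bazel_version(text):
    """Parse "4.1.0" or "4.1.0rc4" into ((major, minor, patch), rc or None)."""
    match = _VERSION_RE.match(text.strip())
    if match is None:
        raise ValueError(f"unrecognised Bazel version: {text!r}")
    major, minor, patch, rc = match.groups()
    release = (int(major), int(minor), int(patch))
    return release, (int(rc) if rc is not None else None)


# Assumption drawn from this thread: 4.1.0, including its release candidates,
# is the minimum version for native darwin-arm64 builds.
MIN_DARWIN_ARM64 = (4, 1, 0)


def supports_darwin_arm64(text):
    release, _rc = parse_bazel_version(text)
    return release >= MIN_DARWIN_ARM64


for candidate in ["3.7.2", "4.1.0rc4", "4.1.0"]:
    print(candidate, supports_darwin_arm64(candidate))
```

To check a checkout, the contents of `.bazelversion` can be passed straight in, e.g. `supports_darwin_arm64(open(".bazelversion").read())`.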

Going through the comments it looks like we've done all we could on the JAX side, and the remaining difficulties are in installing SciPy. Is this ready to be closed, or are there any outstanding issues?

The outstanding issue is providing wheels for darwin-arm64, which is blocked by SciPy.

I'm not sure how JAX organises issues, but that's the gist.

It seems from the scipy issue that they are currently targeting 1.7.1 for Mac ARM wheels.

The remaining action item on the JAX side is then (eventually) to build Mac M1 wheels as part of our release process.

Hello, @hawkinsp. Is there a rough timeline for the next release?
Update: I was able to build the wheel.

#7254 allows us to cross-compile Mac ARM wheels on an x86 Mac machine. Unfortunately we have no way to test them, not having any Mac ARM hardware of our own or any way to emulate it, so while we can certainly release the resulting wheels as a service to the community, there will be no guarantees that they will work in any given release. We'd have to rely on the community to try them out and report bugs.

I have manually verified at least the current version of jaxlib works on a (borrowed) Macbook M1.

Because the wheels will be untested by us, for now our plan is to add a warning when JAX is imported on a Mac M1 machine that says something along the lines of "JAX on Mac ARM machines is experimental and community supported, see this github issue in the event of any problems" until such time as we can perform testing either on our own machines or through Github CI.
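An import-time warning of this kind could look roughly like the sketch below. It is a hypothetical illustration, not JAX's actual implementation; the message text mirrors the warning that appears later in this thread, and the platform check assumes a native interpreter reports `Darwin` and `arm64`.

```python
import platform
import warnings

MAC_ARM_MESSAGE = (
    "JAX on Mac ARM machines is experimental and minimally tested. "
    "Please see https://github.com/google/jax/issues/5501 in the event of problems."
)


def warn_if_mac_arm(system=None, machine=None):
    """Emit a UserWarning on macOS/arm64 and return True if one was issued.

    `system` and `machine` default to the running interpreter's values and can
    be overridden for testing.
    """
    system = system or platform.system()
    machine = machine or platform.machine()
    if system == "Darwin" and machine == "arm64":
        warnings.warn(MAC_ARM_MESSAGE, UserWarning, stacklevel=2)
        return True
    return False


# Demonstrate with explicit values so the example behaves the same everywhere.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    issued = warn_if_mac_arm("Darwin", "arm64")
print(issued, len(caught))
```

Taking the platform values as optional parameters keeps the check testable on any machine, which matters when the maintainers themselves have no Mac ARM hardware.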

Hi, I'm using a device with the M1 chip and have been following this thread closely, although my understanding is very basic. I've been trying to import JAX for a few days now. Most recently I switched to @hawkinsp's PR #7254 and tried to build jaxlib with python build/build.py (using Python 3.8.8, having uninstalled Bazel), but it didn't work. This is the output:

     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\


Starting local Bazel server and connecting to it...
Bazel binary path: ./bazel-3.7.2-darwin-x86_64
Python binary path: /opt/homebrew/anaconda3/bin/python
Python version: 3.8
NumPy version: 1.19.5
SciPy version: 1.6.2
MKL-DNN enabled: yes
Target CPU: x86_64
Target CPU features: release
CUDA enabled: no
TPU enabled: no
ROCm enabled: no

Building XLA and installing it in the jaxlib source tree...
./bazel-3.7.2-darwin-x86_64 run --verbose_failures=true --config=short_logs --config=avx_posix --config=mkl_open_source_only :build_wheel -- --output_path=/Users/steve/Documents/code/jax/dist --cpu=x86_64
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'run' from /Users/steve/Documents/code/jax/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'run' from /Users/steve/Documents/code/jax/.bazelrc:
  Inherited 'build' options: --repo_env PYTHON_BIN_PATH=/opt/homebrew/anaconda3/bin/python --action_env=PYENV_ROOT --python_path=/opt/homebrew/anaconda3/bin/python --repo_env TF_NEED_CUDA=0 --action_env TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0 --repo_env TF_NEED_ROCM=0 --action_env TF_ROCM_AMDGPU_TARGETS=gfx803,gfx900,gfx906,gfx1010 -c opt --apple_platform_type=macos --macos_minimum_os=10.9 --announce_rc --define open_source_build=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --spawn_strategy=standalone --strategy=Genrule=standalone --enable_platform_specific_config --distinct_host_configuration=false
INFO: Found applicable config definition build:short_logs in file /Users/steve/Documents/code/jax/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:avx_posix in file /Users/steve/Documents/code/jax/.bazelrc: --copt=-mavx --host_copt=-mavx
INFO: Found applicable config definition build:mkl_open_source_only in file /Users/steve/Documents/code/jax/.bazelrc: --define=tensorflow_mkldnn_contraction_kernel=1
INFO: Found applicable config definition build:macos in file /Users/steve/Documents/code/jax/.bazelrc: --config=posix
INFO: Found applicable config definition build:posix in file /Users/steve/Documents/code/jax/.bazelrc: --copt=-Wno-sign-compare --define=no_aws_support=true --define=no_gcp_support=true --define=no_hdfs_support=true --cxxopt=-std=c++14 --host_cxxopt=-std=c++14
Loading:
Loading: 0 packages loaded
Analyzing: target //build:build_wheel (1 packages loaded, 0 targets configured)
Analyzing: target //build:build_wheel (12 packages loaded, 12 targets configured)
DEBUG: Rule 'io_bazel_rules_docker' indicated that a canonical reproducible form can be obtained by modifying arguments shallow_since = "1596824487 -0400"
DEBUG: Repository io_bazel_rules_docker instantiated at:
  /Users/steve/Documents/code/jax/WORKSPACE:34:10: in <toplevel>
  /private/var/tmp/_bazel_steve/c238ce85f8bfae0489a1104fd56ad209/external/org_tensorflow/tensorflow/workspace0.bzl:108:34: in workspace
  /private/var/tmp/_bazel_steve/c238ce85f8bfae0489a1104fd56ad209/external/bazel_toolchains/repositories/repositories.bzl:35:23: in repositories
Repository rule git_repository defined at:
  /private/var/tmp/_bazel_steve/c238ce85f8bfae0489a1104fd56ad209/external/bazel_tools/tools/build_defs/repo/git.bzl:199:33: in <toplevel>
Analyzing: target //build:build_wheel (14 packages loaded, 12 targets configured)
Analyzing: target //build:build_wheel (27 packages loaded, 144 targets configured)
Analyzing: target //build:build_wheel (42 packages loaded, 272 targets configured)
Analyzing: target //build:build_wheel (97 packages loaded, 1923 targets configured)
Analyzing: target //build:build_wheel (153 packages loaded, 4414 targets configured)
Analyzing: target //build:build_wheel (164 packages loaded, 7566 targets configured)
Analyzing: target //build:build_wheel (164 packages loaded, 7623 targets configured)
INFO: Analyzed target //build:build_wheel (171 packages loaded, 14292 targets configured).
INFO: Found 1 target...
[1 / 828] [Prepa] BazelWorkspaceStatusAction stable-status.txt
[73 / 1,034] Compiling com_google_protobuf/src/google/protobuf/compiler/cpp/cpp_file.cc; 2s local ... (8 actions, 7 running)
[110 / 1,034] Compiling com_google_protobuf/src/google/protobuf/compiler/objectivec/objectivec_helpers.cc; 2s local ... (8 actions, 7 running)
[181 / 1,270] Compiling com_github_grpc_grpc/src/compiler/python_generator.cc; 4s local ... (8 actions, 7 running)
[264 / 1,270] Compiling llvm-project/llvm/lib/Support/ItaniumManglingCanonicalizer.cpp; 0s local ... (8 actions, 7 running)
[333 / 1,281] Compiling llvm-project/mlir/tools/mlir-tblgen/OpFormatGen.cpp; 5s local ... (8 actions, 7 running)
[421 / 1,394] Compiling llvm-project/llvm/lib/MC/WasmObjectWriter.cpp; 2s local ... (8 actions, 7 running)
[506 / 1,445] Compiling llvm-project/llvm/lib/DebugInfo/CodeView/StringsAndChecksums.cpp; 1s local ... (8 actions, 7 running)
ERROR: /private/var/tmp/_bazel_steve/c238ce85f8bfae0489a1104fd56ad209/external/llvm-project/mlir/BUILD:155:18: TdGenerate external/llvm-project/mlir/include/mlir/IR/BuiltinAttributes.h.inc failed (Illegal instruction): mlir-tblgen failed: error executing command
  (cd /private/var/tmp/_bazel_steve/c238ce85f8bfae0489a1104fd56ad209/execroot/__main__ && \
  exec env - \
  bazel-out/darwin-opt-exec-50AE0418/bin/external/llvm-project/mlir/mlir-tblgen --gen-attrdef-decls external/llvm-project/mlir/include/mlir/IR/BuiltinAttributes.td -I external/llvm-project/mlir/include -I bazel-out/darwin-opt/bin/external/llvm-project/mlir/include -I external/llvm-project/ -I bazel-out/darwin-opt/bin/external/llvm-project/ -I external/llvm-project/mlir/include/mlir/IR -I bazel-out/darwin-opt/bin/external/llvm-project/mlir/include/mlir/IR -o bazel-out/darwin-opt/bin/external/llvm-project/mlir/include/mlir/IR/BuiltinAttributes.h.inc)
Execution platform: @local_execution_config_platform//:platform
Target //build:build_wheel failed to build
INFO: Elapsed time: 147.032s, Critical Path: 10.15s
INFO: 476 processes: 9 internal, 467 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully
b''
Traceback (most recent call last):
  File "build/build.py", line 587, in <module>
    main()
  File "build/build.py", line 582, in main
    shell(command)
  File "build/build.py", line 52, in shell
    output = subprocess.check_output(cmd)
  File "/opt/homebrew/anaconda3/lib/python3.8/subprocess.py", line 415, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/opt/homebrew/anaconda3/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['./bazel-3.7.2-darwin-x86_64', 'run', '--verbose_failures=true', '--config=short_logs', '--config=avx_posix', '--config=mkl_open_source_only', ':build_wheel', '--', '--output_path=/Users/steve/Documents/code/jax/dist', '--cpu=x86_64']' returned non-zero exit status 1.

I've been following this guide for building jaxlib from source. One potentially useful detail (I'm not sure what it means): running platform.machine() in Python returns "x86_64", which confuses me because I thought the M1 uses ARM.

Unfortunately, the old version of JAX that does import without complaint (jax==0.2.10 and jaxlib==0.1.60) doesn't support differentiating through discrete Fourier transforms (jnp.fft.fft).

Hope this is helpful! I would really appreciate any guidance or help whatsoever. Thank you @hawkinsp for working on this issue! It's impossible for me to get any work done without JAX.

Hi @dcxSt!

A couple of things I'd note if you want to build JAX on your M1:

  1. Your Python interpreter is running under Rosetta (it thinks it's running on Intel x86 rather than the M1's arm64). I'm not sure how you installed it, but you'll need the native M1 version. This is easy to get with brew. I'm not well versed in Anaconda, but I'm sure there's a channel for it somewhere.

  2. You're also building with an old version of Bazel. Don't use the one JAX's build script provides; @hawkinsp mentions that here.

Alternatively, after #7268, if you have the correct Python interpreter, I think pip install jax should do the trick (I can check later whether this helps).
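On the first point: inside a translated process, platform.machine() reports x86_64 (as observed above), so it cannot by itself tell an Intel Mac from Rosetta on an M1. A minimal sketch, assuming macOS 11 or later, where the `sysctl.proc_translated` key reports 1 for a Rosetta-translated process and 0 for a native one; the key does not exist on Intel Macs, and on non-macOS systems the hypothetical helper simply returns False.

```python
import platform
import shutil
import subprocess


def running_under_rosetta():
    """Return True if this Python process is an x86-64 binary translated by Rosetta 2."""
    if platform.system() != "Darwin" or shutil.which("sysctl") is None:
        return False
    result = subprocess.run(
        ["sysctl", "-n", "sysctl.proc_translated"], capture_output=True, text=True
    )
    # The key is missing on Intel Macs (non-zero exit); treat that as "not translated".
    return result.returncode == 0 and result.stdout.strip() == "1"


print("platform.machine():", platform.machine())
print("translated by Rosetta 2:", running_under_rosetta())
```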

@dcxSt @akbir is correct. You are building on a Python x86-64 interpreter. You have two choices:

  • install an ARM Python interpreter and use that (probably preferable).
  • keep using an x86-64 interpreter, but pass --target_cpu_features=default to build.py to build without AVX.

I recommend using a native version if you can!
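The two options above amount to a decision on the interpreter's architecture. This hypothetical helper only encodes the advice in this comment and assumes the host is an M1 machine: a native arm64 interpreter needs no extra flag, while an x86-64 interpreter (running under Rosetta 2) should pass `--target_cpu_features=default` so that build.py builds without AVX.

```python
import platform


def suggested_build_flags(machine=None):
    """Return extra build/build.py flags for the interpreter's architecture on an M1.

    Hypothetical helper; `machine` defaults to platform.machine() of the running
    interpreter. Not meant for genuine Intel Macs, which support AVX.
    """
    machine = machine or platform.machine()
    if machine == "arm64":
        return []  # native Apple Silicon interpreter: the preferred setup
    if machine == "x86_64":
        # x86-64 interpreter under Rosetta 2: build without AVX.
        return ["--target_cpu_features=default"]
    raise ValueError(f"unexpected architecture: {machine!r}")


cmd = ["python", "build/build.py", *suggested_build_flags("x86_64")]
print(" ".join(cmd))
```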

Hallelujah it works! Thank you @akbir thank you @hawkinsp you are my saviours!

For others, here's what I did:
Install Python for arm64 (first uninstall Anaconda with brew uninstall anaconda):

brew install --cask miniforge
conda init zsh
conda activate
conda install numpy scipy scikit-learn

To build jaxlib from source, clone the jax repository and install Bazel 4:

git clone https://github.com/google/jax 
brew install bazel

Then enter the jax directory (cd jax) and build jaxlib from source. This takes a while to run (~8 minutes):

python build/build.py
pip install dist/*.whl  # installs jaxlib (includes XLA)

Then install jax with either pip install -e . or pip install jax.
And it works!

Edit: I forgot to mention that I installed Xcode before doing any of this (the full app, not just the command line tools).

Thanks to the above instruction, I have also managed to build jaxlib on M1.

But when importing jax, I get the following strange error (importing jaxlib seems to work fine).

Python 3.9.6 | packaged by conda-forge | (default, Jul 11 2021, 03:35:11) 
[Clang 11.1.0 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
/Users/thomas/Documents/jax/jax/lib/__init__.py:31: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/thomas/Documents/jax/jax/__init__.py", line 37, in <module>
    from . import config as _config_module
  File "/Users/thomas/Documents/jax/jax/config.py", line 18, in <module>
    from jax._src.config import config
  File "/Users/thomas/Documents/jax/jax/_src/config.py", line 26, in <module>
    from jax import lib
  File "/Users/thomas/Documents/jax/jax/lib/__init__.py", line 70, in <module>
    from jaxlib import cpu_feature_guard
ImportError: dlopen(/Users/thomas/miniforge3/lib/python3.9/site-packages/jaxlib/cpu_feature_guard.so, 2): no suitable image found.  Did find:
	/Users/thomas/miniforge3/lib/python3.9/site-packages/jaxlib/cpu_feature_guard.so: mach-o, but wrong architecture
	/Users/thomas/miniforge3/lib/python3.9/site-packages/jaxlib/cpu_feature_guard.so: mach-o, but wrong architecture

Any ideas what could be the issue here?
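The "mach-o, but wrong architecture" message means the shared library was compiled for a different CPU than the interpreter loading it. One way to check a file such as cpu_feature_guard.so is to read the CPU type from its Mach-O header. A minimal sketch, assuming a thin (single-architecture) 64-bit little-endian Mach-O file; universal ("fat") binaries are reported as such rather than decoded. The `file` command gives the same answer from a shell.

```python
import struct

MH_MAGIC_64 = 0xFEEDFACF                # 64-bit Mach-O, stored little-endian on x86-64/arm64
FAT_MAGICS = {0xCAFEBABE, 0xCAFEBABF}   # universal binaries (header is big-endian)
CPU_TYPES = {
    0x01000007: "x86_64",  # CPU_TYPE_X86 | CPU_ARCH_ABI64
    0x0100000C: "arm64",   # CPU_TYPE_ARM | CPU_ARCH_ABI64
}


def macho_arch_from_header(header: bytes) -> str:
    """Return the architecture named in the first 8 bytes of a Mach-O file."""
    if len(header) < 8:
        raise ValueError("need at least 8 header bytes")
    (big_endian_magic,) = struct.unpack(">I", header[:4])
    if big_endian_magic in FAT_MAGICS:
        return "universal"
    magic, cputype = struct.unpack("<II", header[:8])
    if magic != MH_MAGIC_64:
        return "not a 64-bit Mach-O file"
    return CPU_TYPES.get(cputype, f"unknown cputype {cputype:#x}")


def macho_arch(path: str) -> str:
    """Read the header of a file on disk, e.g. a jaxlib .so, and report its architecture."""
    with open(path, "rb") as f:
        return macho_arch_from_header(f.read(8))


# Synthetic headers for illustration (real files contain much more after these 8 bytes).
print(macho_arch_from_header(struct.pack("<II", MH_MAGIC_64, 0x0100000C)))  # arm64
print(macho_arch_from_header(struct.pack("<II", MH_MAGIC_64, 0x01000007)))  # x86_64
```

Comparing the result of `macho_arch` on the failing `.so` with `platform.machine()` of the interpreter shows which side is mismatched.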

Strange. I built it exactly as in dcxSt's post above, and the build script reports "arm64" as the target. Perhaps I accidentally ran something in a Rosetta terminal; I will double-check.

Update: I double checked, and everything was run in a regular terminal.

Update 2: It runs now! :) For anyone having similar problems, my issue seems to have been that I had installed Bazel from a Rosetta terminal.

(base) thomas@mbpro jax % python3 build/build.py                                                                

     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\


Starting local Bazel server and connecting to it...
Bazel binary path: /usr/local/bin/bazel
Python binary path: /Users/thomas/miniforge3/bin/python3
Python version: 3.9
NumPy version: 1.21.1
SciPy version: 1.7.0
MKL-DNN enabled: yes
Target CPU: arm64
Target CPU features: default
CUDA enabled: no
TPU enabled: no
ROCm enabled: no

Building XLA and installing it in the jaxlib source tree...
/usr/local/bin/bazel run --verbose_failures=true --config=short_logs --config=mkl_open_source_only :build_wheel -- --output_path=/Users/thomas/Documents/jax/dist --cpu=arm64

The resulting wheel is "jaxlib-0.1.70-cp39-none-macosx_11_0_arm64.whl", so it doesn't look like an x86 one.
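The platform tag at the end of a wheel filename is a quick way to see what a build targeted. A minimal sketch of splitting a wheel filename into its tags, following the PEP 427 naming convention `{name}-{version}(-{build})?-{python}-{abi}-{platform}.whl`; it does not expand compressed tag sets such as `py2.py3`, and the helper name is hypothetical.

```python
from typing import NamedTuple


class WheelTags(NamedTuple):
    name: str
    version: str
    python: str
    abi: str
    platform: str


def parse_wheel_filename(filename: str) -> WheelTags:
    """Split a wheel filename per PEP 427, ignoring any optional build tag."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename!r}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) not in (5, 6):
        raise ValueError(f"unexpected wheel filename: {filename!r}")
    name, version = parts[0], parts[1]
    python_tag, abi_tag, platform_tag = parts[-3:]
    return WheelTags(name, version, python_tag, abi_tag, platform_tag)


tags = parse_wheel_filename("jaxlib-0.1.70-cp39-none-macosx_11_0_arm64.whl")
print(tags.platform)                     # macosx_11_0_arm64
print(tags.platform.endswith("_arm64"))  # True
```

Note that the tag only records what the build was told to target; it does not inspect the compiled extension modules inside the wheel, which is why an arm64-tagged wheel can still contain a library built by an x86 toolchain.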

I tried to install it using @dcxSt's instructions, but I get the following error:

     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\


Extracting Bazel installation...
Starting local Bazel server and connecting to it...
Bazel binary path: /usr/local/bin/bazel
Python binary path: /Users/gerardoduran/miniforge3/envs/pyprobml/bin/python
Python version: 3.9
NumPy version: 1.19.5
SciPy version: 1.7.0
MKL-DNN enabled: yes
Target CPU: arm64
Target CPU features: release
CUDA enabled: no
TPU enabled: no
ROCm enabled: no

Building XLA and installing it in the jaxlib source tree...
/usr/local/bin/bazel run --verbose_failures=true --config=short_logs --config=mkl_open_source_only :build_wheel -- --output_path=/Users/gerardoduran/Documents/repos/external/jax/dist --cpu=arm64
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'run' from /Users/gerardoduran/Documents/repos/external/jax/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'run' from /Users/gerardoduran/Documents/repos/external/jax/.bazelrc:
  Inherited 'build' options: --repo_env PYTHON_BIN_PATH=/Users/gerardoduran/miniforge3/envs/pyprobml/bin/python --action_env=PYENV_ROOT --python_path=/Users/gerardoduran/miniforge3/envs/pyprobml/bin/python --repo_env TF_NEED_CUDA=0 --action_env TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0 --repo_env TF_NEED_ROCM=0 --action_env TF_ROCM_AMDGPU_TARGETS=gfx803,gfx900,gfx906,gfx1010 -c opt --apple_platform_type=macos --macos_minimum_os=10.9 --announce_rc --define open_source_build=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --spawn_strategy=standalone --strategy=Genrule=standalone --enable_platform_specific_config --distinct_host_configuration=false
INFO: Found applicable config definition build:short_logs in file /Users/gerardoduran/Documents/repos/external/jax/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:mkl_open_source_only in file /Users/gerardoduran/Documents/repos/external/jax/.bazelrc: --define=tensorflow_mkldnn_contraction_kernel=1
INFO: Found applicable config definition build:macos in file /Users/gerardoduran/Documents/repos/external/jax/.bazelrc: --config=posix
INFO: Found applicable config definition build:posix in file /Users/gerardoduran/Documents/repos/external/jax/.bazelrc: --copt=-Wno-sign-compare --define=no_aws_support=true --define=no_gcp_support=true --define=no_hdfs_support=true --cxxopt=-std=c++14 --host_cxxopt=-std=c++14
Loading: 
Loading: 0 packages loaded
WARNING: Download from http://mirror.tensorflow.org/github.com/tensorflow/runtime/archive/d29d1ef0a65a8f9c23e1f88067ce4205d3085e87.tar.gz failed: class com.google.devtools.build.lib.bazel.repository.downloader.UnrecoverableHttpException GET returned 404 Not Found
Loading: 0 packages loaded
Analyzing: target //build:build_wheel (1 packages loaded, 0 targets configured)
DEBUG: Rule 'io_bazel_rules_docker' indicated that a canonical reproducible form can be obtained by modifying arguments shallow_since = "1596824487 -0400"
DEBUG: Repository io_bazel_rules_docker instantiated at:
  /Users/gerardoduran/Documents/repos/external/jax/WORKSPACE:34:10: in <toplevel>
  /private/var/tmp/_bazel_gerardoduran/fdbb951abd3afe67a47ab70535de95bf/external/org_tensorflow/tensorflow/workspace0.bzl:108:34: in workspace
  /private/var/tmp/_bazel_gerardoduran/fdbb951abd3afe67a47ab70535de95bf/external/bazel_toolchains/repositories/repositories.bzl:35:23: in repositories
Repository rule git_repository defined at:
  /private/var/tmp/_bazel_gerardoduran/fdbb951abd3afe67a47ab70535de95bf/external/bazel_tools/tools/build_defs/repo/git.bzl:199:33: in <toplevel>
Analyzing: target //build:build_wheel (14 packages loaded, 12 targets configured)
Analyzing: target //build:build_wheel (47 packages loaded, 255 targets configured)
Analyzing: target //build:build_wheel (159 packages loaded, 4927 targets configured)
Analyzing: target //build:build_wheel (164 packages loaded, 6975 targets configured)
WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/74da7ae0601728d7996e37c1f1828096e3d19103.tar.gz failed: class com.google.devtools.build.lib.bazel.repository.downloader.UnrecoverableHttpException GET returned 404 Not Found
INFO: Analyzed target //build:build_wheel (172 packages loaded, 14115 targets configured).

INFO: Found 1 target...
[0 / 15] [Prepa] BazelWorkspaceStatusAction stable-status.txt ... (3 actions, 0 running)
ERROR: /private/var/tmp/_bazel_gerardoduran/fdbb951abd3afe67a47ab70535de95bf/external/com_google_protobuf/BUILD:301:11: Compiling src/google/protobuf/compiler/cpp/cpp_message.cc failed: (Exit 1): cc_wrapper.sh failed: error executing command 
  (cd /private/var/tmp/_bazel_gerardoduran/fdbb951abd3afe67a47ab70535de95bf/execroot/__main__ && \
  exec env - \
    PATH=/Users/gerardoduran/miniforge3/envs/pyprobml/bin:/Users/gerardoduran/miniforge3/condabin:/Users/gerardoduran/google-cloud-sdk/bin:/usr/local/opt/ruby/bin:/Users/gerardoduran/pomo:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin:/Library/TeX/texbin \
    PWD=/proc/self/cwd \
    TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0 \
    TF_ROCM_AMDGPU_TARGETS=gfx803,gfx900,gfx906,gfx1010 \
  external/local_config_cc/cc_wrapper.sh -U_FORTIFY_SOURCE -fstack-protector -Wall -Wthread-safety -Wself-assign -fcolor-diagnostics -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections -fdata-sections '-std=c++11' -MD -MF bazel-out/darwin-opt/bin/external/com_google_protobuf/_objs/protoc_lib/cpp_message.d '-frandom-seed=bazel-out/darwin-opt/bin/external/com_google_protobuf/_objs/protoc_lib/cpp_message.o' -iquote external/com_google_protobuf -iquote bazel-out/darwin-opt/bin/external/com_google_protobuf -iquote external/zlib -iquote bazel-out/darwin-opt/bin/external/zlib -isystem external/com_google_protobuf/src -isystem bazel-out/darwin-opt/bin/external/com_google_protobuf/src -isystem external/zlib -isystem bazel-out/darwin-opt/bin/external/zlib -Wno-sign-compare '-std=c++14' -DHAVE_PTHREAD -DHAVE_ZLIB -Woverloaded-virtual -Wno-sign-compare -Wno-unused-function -Wno-write-strings -no-canonical-prefixes -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -c external/com_google_protobuf/src/google/protobuf/compiler/cpp/cpp_message.cc -o bazel-out/darwin-opt/bin/external/com_google_protobuf/_objs/protoc_lib/cpp_message.o)
Execution platform: @local_execution_config_platform//:platform
external/com_google_protobuf/src/google/protobuf/compiler/cpp/cpp_message.cc:1695:39: error: no member named 'kInlinedType' in 'google::protobuf::internal::FieldMetadata'
      type = internal::FieldMetadata::kInlinedType;
             ~~~~~~~~~~~~~~~~~~~~~~~~~^
external/com_google_protobuf/src/google/protobuf/compiler/cpp/cpp_message.cc:2167:41: error: no member named 'TYPE_STRING_INLINED' in namespace 'google::protobuf::internal'; did you mean 'TYPE_STRING_CORD'?
            processing_type = internal::TYPE_STRING_INLINED;
                              ~~~~~~~~~~^~~~~~~~~~~~~~~~~~~
                                        TYPE_STRING_CORD
/usr/local/include/google/protobuf/generated_message_table_driven.h:72:3: note: 'TYPE_STRING_CORD' declared here
  TYPE_STRING_CORD = 19,
  ^
external/com_google_protobuf/src/google/protobuf/compiler/cpp/cpp_message.cc:2182:41: error: no member named 'TYPE_BYTES_INLINED' in namespace 'google::protobuf::internal'; did you mean 'TYPE_BYTES_CORD'?
            processing_type = internal::TYPE_BYTES_INLINED;
                              ~~~~~~~~~~^~~~~~~~~~~~~~~~~~
                                        TYPE_BYTES_CORD
/usr/local/include/google/protobuf/generated_message_table_driven.h:74:3: note: 'TYPE_BYTES_CORD' declared here
  TYPE_BYTES_CORD = 21,
  ^
3 errors generated.
Target //build:build_wheel failed to build
INFO: Elapsed time: 159.099s, Critical Path: 8.83s
INFO: 97 processes: 55 internal, 42 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully
b''
Traceback (most recent call last):
  File "/Users/gerardoduran/Documents/repos/external/jax/build/build.py", line 604, in <module>
    main()
  File "/Users/gerardoduran/Documents/repos/external/jax/build/build.py", line 599, in main
    shell(command)
  File "/Users/gerardoduran/Documents/repos/external/jax/build/build.py", line 52, in shell
    output = subprocess.check_output(cmd)
  File "/Users/gerardoduran/miniforge3/envs/pyprobml/lib/python3.9/subprocess.py", line 424, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/Users/gerardoduran/miniforge3/envs/pyprobml/lib/python3.9/subprocess.py", line 528, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['/usr/local/bin/bazel', 'run', '--verbose_failures=true', '--config=short_logs', '--config=mkl_open_source_only', ':build_wheel', '--', '--output_path=/Users/gerardoduran/Documents/repos/external/jax/dist', '--cpu=arm64']' returned non-zero exit status 1.

I'm not sure if this error is because of the 404:

WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/74da7ae0601728d7996e37c1f1828096e3d19103.tar.gz failed: class com.google.devtools.build.lib.bazel.repository.downloader.UnrecoverableHttpException GET returned 404 Not Found

Any idea why?

@gerdm Missing Xcode could be the reason. If you haven't installed Xcode, get it from here. Make sure you open it to finish the installation. And then follow the steps in @dcxSt's comment. It worked for me.

@gerdm you should install Xcode (not only the command-line tools), as mentioned in #5501 (comment). Following these steps worked for me.
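A quick way to see which toolchain is active is xcode-select -p, which prints the active developer directory. With full Xcode selected it normally points inside Xcode.app (e.g. /Applications/Xcode.app/Contents/Developer), while the standalone command-line tools live under /Library/Developer/CommandLineTools. The sketch below wraps that check in Python; the helper names are hypothetical, and the "inside an .app bundle" test is a heuristic, not an official rule.

```python
import subprocess
from typing import Optional

# Default install location of the standalone Command Line Tools (no full Xcode).
CLT_DIR = "/Library/Developer/CommandLineTools"


def active_developer_dir() -> Optional[str]:
    """Return the output of `xcode-select -p`, or None if it cannot be run."""
    try:
        result = subprocess.run(
            ["xcode-select", "-p"], capture_output=True, text=True, check=True
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None  # not on macOS, or no developer directory configured
    return result.stdout.strip()


def uses_full_xcode(developer_dir: Optional[str]) -> bool:
    """Heuristic: True if the developer directory lives inside an Xcode .app bundle."""
    if not developer_dir or developer_dir.startswith(CLT_DIR):
        return False
    return ".app/Contents/Developer" in developer_dir


if __name__ == "__main__":
    path = active_developer_dir()
    print(f"xcode-select -p -> {path!r}; full Xcode active: {uses_full_xcode(path)}")
```

If it reports the command-line tools even though Xcode is installed, switching with sudo xcode-select -s /Applications/Xcode.app/Contents/Developer is the usual remedy.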

Hey, I followed dcxSt's advice. I had the same problem as gerdm. I managed to install everything without errors using the following sequence of actions:
I installed Xcode as described here, installed Homebrew and Bazel as described here, then followed dcxSt's guide.

The problem is that import jaxlib works, whereas import jax fails:

 >>> import jax
/Users/egordanilov/Desktop/Science/Git/jax/jax/lib/__init__.py:31: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/egordanilov/Desktop/Science/Git/jax/jax/__init__.py", line 37, in <module>
    from . import config as _config_module
  File "/Users/egordanilov/Desktop/Science/Git/jax/jax/config.py", line 18, in <module>
    from jax._src.config import config
  File "/Users/egordanilov/Desktop/Science/Git/jax/jax/_src/config.py", line 27, in <module>
    from jax import lib
  File "/Users/egordanilov/Desktop/Science/Git/jax/jax/lib/__init__.py", line 70, in <module>
    from jaxlib import cpu_feature_guard
ImportError: dlopen(/usr/local/Caskroom/miniforge/base/lib/python3.9/site-packages/jaxlib/cpu_feature_guard.so, 2): no suitable image found.  Did find:
	/usr/local/Caskroom/miniforge/base/lib/python3.9/site-packages/jaxlib/cpu_feature_guard.so: mach-o, but wrong architecture
	/usr/local/Caskroom/miniforge/base/lib/python3.9/site-packages/jaxlib/cpu_feature_guard.so: mach-o, but wrong architecture

I tried pip install jax==0.2.10, as advised here, but had no success with it.

Does anyone have any solutions in mind?

Hey @egorssed, I ran into the same error as you. Unlike you, I had initially not paid enough attention to how I installed Xcode, Homebrew and Bazel. Once I followed those steps closely, the installation of JAX worked out. In particular, I believe that I had previously used a version of Bazel that was not meant for the ARM64 / M1 architecture.

I also wrote down all steps in detail again, perhaps this makes it easier for others as well: (don't want to take false credits here though: all steps are taken from comments above or pages that are linked to in the comments above)

Edit: Apparently, these steps do not always work flawlessly. Check out the comment below for small modifications of these instructions if you run into problems ;)

  1. Install Xcode through the App Store. Once it is installed, open Xcode and agree to the licence terms.

  2. Install the Rosetta 2 emulator
    Rosetta 2 is Apple's translation layer that lets programs compiled for Intel processors (x86-64) run on Apple's M1 chip (ARM64). You can install it via terminal with:
    % /usr/sbin/softwareupdate --install-rosetta --agree-to-license

  3. Install homebrew
    You will need to install two versions of Homebrew:

  • Homebrew for the native M1 architecture. Install via terminal with:
    % arch -x86_64 /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install.sh)"
    This version will (and must) be installed under /opt/homebrew .
  • Homebrew for the Intel architecture (from my understanding, programs installed with this version of Homebrew will then be translated by Rosetta 2 when you first run them). Install via terminal with:
    % arch -x86_64 zsh
    % cd /usr/local && mkdir homebrew
    % curl -L https://github.com/Homebrew/brew/tarball/master | tar xz --strip 1 -C homebrew
    This version will (and must) be installed under /usr/local .
  4. Install Python for arm64
  • If you already have Anaconda installed, you will have to remove it first. (Maybe there is a way to avoid this, but this is what I did.) The problem with Anaconda is that it does not supply packages built natively for ARM64. We will instead install Miniforge, a minimal conda installer that uses the conda-forge channel by default, and conda-forge provides packages compiled for ARM64. See: https://stackoverflow.com/questions/60532678/what-is-the-difference-between-miniconda-and-miniforge
  • Sanity check: % which brew should return /opt/homebrew/bin/brew
  • Install miniforge via terminal with:
    % brew install --cask miniforge
    % conda init zsh
    % conda activate
  5. Create and activate a new conda environment
    % conda create --name env_jax python=3.8
    % conda activate env_jax
  6. Install numpy, scipy, and scikit-learn
    % conda install numpy scipy scikit-learn
  7. Install Bazel
    % brew install bazel
  8. Compile and install JAX from source
    % git clone https://github.com/google/jax
    % cd jax
    % python build/build.py
    % pip install dist/*.whl
    % pip install -e .
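Several problems in this thread came from a Python or toolchain that was silently running as x86-64 under Rosetta 2. A minimal sketch for checking the interpreter you are about to build with, assuming macOS's documented sysctl key sysctl.proc_translated (1 when the process is translated, 0 when native, missing on Intel Macs); the helper names are hypothetical:

```python
import ctypes
import platform
import sys
from typing import Optional


def rosetta_translated() -> Optional[bool]:
    """True/False if macOS reports Rosetta translation for this process, else None."""
    if sys.platform != "darwin":
        return None
    try:
        libc = ctypes.CDLL(None)  # symbols already loaded in this process (libSystem)
        value = ctypes.c_int(0)
        size = ctypes.c_size_t(ctypes.sizeof(value))
        rc = libc.sysctlbyname(b"sysctl.proc_translated", ctypes.byref(value),
                               ctypes.byref(size), None, ctypes.c_size_t(0))
    except (OSError, AttributeError):
        return None
    if rc != 0:
        return None  # key missing, e.g. on Intel Macs
    return value.value == 1


def describe_interpreter() -> dict:
    return {
        "system": platform.system(),    # "Darwin" on macOS
        "machine": platform.machine(),  # "arm64" when native, "x86_64" under Rosetta
        "rosetta": rosetta_translated(),
        "executable": sys.executable,
    }


def looks_native_m1(info: dict) -> bool:
    return info["system"] == "Darwin" and info["machine"] == "arm64" and not info["rosetta"]


if __name__ == "__main__":
    info = describe_interpreter()
    print(info, "native arm64:", looks_native_m1(info))
```

In a correctly set up Miniforge environment on an M1, platform.machine() should report arm64; if it reports x86_64, the Python being used is an Intel build running under Rosetta.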

@soerenab, many thanks to you!
After painstakingly resisting the urge to bang my head against the wall, I finally succeeded in installing JAX!

However, I still had to make some corrections to your installation steps.

3. Install Homebrew
I got a "Fetch error" with the command
arch -x86_64 /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install.sh)"
I recall that some related scripts have been deprecated, so I used this one instead:
arch -x86_64 /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"

4. Install Python for arm64

There is also the usual problem of conda init zsh doing nothing and conda activate complaining that the shell is not configured.

The solution is to source the activate script explicitly before using conda:
source /usr/local/Caskroom/miniforge/base/bin/activate

7. Install Bazel
I also had some problems installing Bazel, but apparently this step is not needed at all:
when you build JAX, build.py downloads Bazel on its own.

% python3 build/build.py
...
Downloading bazel from: https://github.com/bazelbuild/bazel/releases/download/4.1.0/bazel-4.1.0-darwin-arm64
bazel-4.1.0-darwin-arm64 [########################################] 100%
...
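As the log shows, build.py fetches a Bazel binary matching the host (here bazel-4.1.0-darwin-arm64). Assuming that choice follows the platform Python reports (this has not been checked against build.py's source), the file name you should expect can be predicted with a sketch like this; expected_bazel_binary is a hypothetical helper:

```python
import platform
from typing import Optional

# Hypothetical mapping from Python's platform strings to Bazel release file names.
_OS = {"Darwin": "darwin", "Linux": "linux"}
_ARCH = {"x86_64": "x86_64", "arm64": "arm64", "aarch64": "arm64"}


def expected_bazel_binary(version: str, system: Optional[str] = None,
                          machine: Optional[str] = None) -> str:
    """Return e.g. 'bazel-4.1.0-darwin-arm64' for the given (or current) platform.

    Raises KeyError for platforms this sketch does not cover (e.g. Windows,
    whose release binaries also carry an .exe suffix).
    """
    system = system or platform.system()
    machine = machine or platform.machine()
    return f"bazel-{version}-{_OS[system]}-{_ARCH[machine]}"


if __name__ == "__main__":
    print(expected_bazel_binary("4.1.0"))
```

From a Rosetta-translated Python, platform.machine() reports x86_64, so a script selecting Bazel this way would fetch the darwin-x86_64 binary; that is one plausible route to the non-ARM64 Bazel mentioned earlier in the thread.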

Friends, just wanted to let you know that @xhochy has opened this PR: conda-forge/jaxlib-feedstock#60.

Please consider it community-supported and not officially-by-Google, which means also please go easy on @xhochy. The best thing you could do here is to offer to help in whatever way he deems necessary - that PR was a tour-de-force which was not easy at all. (I'm still studying the PR and don't fully understand everything that went in there myself!)

I spent some quality time with this thread getting jaxlib compiled on my new M1 mac but got it there eventually - thanks all!

Since I'm using jax for a few different projects, I decided to get the jaxlib v0.1.70 wheel building on GitHub Actions targeting arm64 (here's the workflow). I wanted to share this here in case it saves anyone time (vs. building from source) in the short term (hopefully there are official wheels soonish :D). It seems to be working for me when installed in a miniforge arm64 env, but I haven't put it through its paces too aggressively. You should be able to install it using:

python -m pip install "jax[cpu]" -f "https://dfm.io/custom-wheels/jaxlib/index.html"

Please feel free to open an issue on that repo (https://github.com/dfm/custom-wheels). Hope this is useful and sorry for the noise otherwise!

@dfm that's impressive! Can I ask, did you have some special trick for guaranteeing that you could build on arm64 hardware, or is that not necessary? I had always assumed that we'd need arm64 hardware to compile, and that Actions always gives us x86 hardware, but maybe I'm incorrect here!

@ericmjl: no, cross-compiling is supported, so you should be able to build with a recent Xcode on x64, and that's what I'm doing here. Testing remains an issue!

@dfm that's pretty cool. Thanks for educating me!

JAX has now released a first version of jaxlib built for the Mac ARM architecture (v0.1.71).

(Preferred way) Upgrade jax to the latest version (pip install -U jax) and then run pip install "jax[cpu]" -f https://storage.googleapis.com/jax-releases/jax_releases.html.

OR

You can install it via pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html

Please try it out and let us know if there are any issues :)

Note: Make sure your Python version is 3.9, because that's the only supported Python version for this wheel.

Note also: we don't actually test these wheels. See above. We're building them as a convenience to the community but we're relying on y'all to test them for the moment...
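The 3.9-only restriction is encoded in the wheel's file name: PEP 427 names wheels {name}-{version}-{python tag}-{abi tag}-{platform tag}.whl, so a file such as jaxlib-0.1.72-cp39-none-macosx_11_0_arm64.whl only fits CPython 3.9 on arm64 macOS. A rough sketch that checks this by hand (hypothetical helpers; pip's real logic, via the packaging library, also handles build tags, compressed tag sets and macOS version ranges, all of which this ignores):

```python
import platform
import sys
from typing import NamedTuple, Tuple


class WheelTags(NamedTuple):
    name: str
    version: str
    python: str    # e.g. "cp39" means CPython 3.9
    abi: str       # e.g. "none"
    platform: str  # e.g. "macosx_11_0_arm64"


def parse_wheel_filename(filename: str) -> WheelTags:
    """Split a PEP 427 wheel name; assumes no build tag and single tags."""
    stem = filename[: -len(".whl")] if filename.endswith(".whl") else filename
    name, version, python, abi, plat = stem.split("-")
    return WheelTags(name, version, python, abi, plat)


def wheel_matches(tags: WheelTags, py: Tuple[int, int], system: str, machine: str) -> bool:
    """Rough check of the Python tag and the macOS platform tag only."""
    if tags.python != f"cp{py[0]}{py[1]}":
        return False
    if tags.platform.startswith("macosx_"):
        # The macOS version part (11_0 here) is ignored by this sketch.
        return system == "Darwin" and tags.platform.endswith("_" + machine)
    return False  # non-macOS wheels are out of scope here


if __name__ == "__main__":
    tags = parse_wheel_filename("jaxlib-0.1.72-cp39-none-macosx_11_0_arm64.whl")
    print(wheel_matches(tags, sys.version_info[:2], platform.system(), platform.machine()))
```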

Hi @yashk2810!

I just followed your installation instructions (the Preferred way), but it fails to run on my end

In [1]: import jax
/Users/gerardoduran/miniforge3/lib/python3.9/site-packages/jax/lib/__init__.py:31: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "

In [2]: jax.numpy.sqrt(2)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
[1]    38628 abort      ipython

Do you know what might've caused this problem?

@gerdm Nope, we'll have to debug that one. The way the wheels are built is by cross-compilation, as described above, so perhaps someone can see if they get the same behavior when building jaxlib 0.1.71 cross-compiled (or natively on arm64)? It appears it was working for 0.1.70, since folks reported that above.

We only have intermittent access to arm64 hardware, so we appreciate community help here!

Looks like this problem was fixed in the rust repo like this: https://github.com/rust-lang/rust/pull/73086/files

@yashk2810 Interesting. We just ask LLVM what the features of the local machine are: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc;drc=2876b0ebfd0080910d9eefb36d413cedb4a6ab8e;l=58 so it's not immediately obvious what we should do here.

I can confirm that I reproduce the same error as @gerdm when building from source on my M1. I haven't tried cross compiling yet. The same build environment worked for jaxlib v0.1.70 so it looks like an issue was introduced between versions. I'm not super familiar with the whole bazel build process, so I'm not sure how helpful I can be for tracking it down, but happy to help if I can!

@dfm Well that's good news and bad news. It means our wheel build works (good!). But something obviously broke in JAX (really in XLA or LLVM, most likely... bad!)

I built this wheel from TF head and jax head: https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.72-cp39-none-macosx_11_0_arm64.whl

Can you try this out and see if the error goes away?

Thank you!

@yashk2810,

I uninstalled jaxlib and installed the wheel you shared, but I get the following error

pip install --force-reinstall ~/Downloads/jaxlib-0.1.72-cp39-none-macosx_11_0_arm64.whl
# ...
Successfully installed absl-py-0.13.0 flatbuffers-2.0 jaxlib-0.1.72 numpy-1.21.2 scipy-1.7.1 six-1.16.0
(miniforge3) ❯ pip install "jax[cpu]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# ...
Successfully installed jax-0.2.20 jaxlib-0.1.71
(miniforge3) ❯ ipython                                                               
Python 3.9.6 | packaged by conda-forge | (default, Jul 11 2021, 03:35:11) 
Type 'copyright', 'credits' or 'license' for more information
IPython 7.26.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax
/Users/gerardoduran/miniforge3/lib/python3.9/site-packages/jax/lib/__init__.py:31: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-1-cb15c4215ef7> in <module>
----> 1 import jax

~/miniforge3/lib/python3.9/site-packages/jax/__init__.py in <module>
     35 # We want the exported object to be the class, so we first import the module
     36 # to make sure a later import doesn't overwrite the class.
---> 37 from . import config as _config_module
     38 del _config_module
     39 

~/miniforge3/lib/python3.9/site-packages/jax/config.py in <module>
     16 
     17 # flake8: noqa: F401
---> 18 from jax._src.config import config

~/miniforge3/lib/python3.9/site-packages/jax/_src/config.py in <module>
     25 import warnings
     26 
---> 27 from jax import lib
     28 from jax.lib import jax_jit
     29 

~/miniforge3/lib/python3.9/site-packages/jax/lib/__init__.py in <module>
     72 
     73 from jaxlib import xla_client
---> 74 from jaxlib import lapack
     75 from jaxlib import pocketfft
     76 

jaxlib/lapack.pyx in init lapack()

~/miniforge3/lib/python3.9/site-packages/scipy/linalg/__init__.py in <module>
    193 """  # noqa: E501
    194 
--> 195 from .misc import *
    196 from .basic import *
    197 from .decomp import *

~/miniforge3/lib/python3.9/site-packages/scipy/linalg/misc.py in <module>
      1 import numpy as np
      2 from numpy.linalg import LinAlgError
----> 3 from .blas import get_blas_funcs
      4 from .lapack import get_lapack_funcs
      5 

~/miniforge3/lib/python3.9/site-packages/scipy/linalg/blas.py in <module>
    211 import functools
    212 
--> 213 from scipy.linalg import _fblas
    214 try:
    215     from scipy.linalg import _cblas

ImportError: dlopen(/Users/gerardoduran/miniforge3/lib/python3.9/site-packages/scipy/linalg/_fblas.cpython-39-darwin.so, 2): no suitable image found.  Did find:
        /Users/gerardoduran/miniforge3/lib/python3.9/site-packages/scipy/linalg/_fblas.cpython-39-darwin.so: mach-o, but wrong architecture
        /Users/gerardoduran/miniforge3/lib/python3.9/site-packages/scipy/linalg/_fblas.cpython-39-darwin.so: mach-o, but wrong architecture

This answer seems to imply that your python installation went wrong somewhere: https://stackoverflow.com/questions/39477023/error-mach-o-but-wrong-architecture-after-installing-anaconda-on-mac

Can you check that and try again?
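To see which architecture a compiled extension such as _fblas.cpython-39-darwin.so was actually built for, one can read its Mach-O header: a thin 64-bit file starts with MH_MAGIC_64 followed by a CPU type, and a universal ("fat") file starts with FAT_MAGIC and a table of architectures. A minimal sketch using the constants from Apple's <mach-o/loader.h> and <mach-o/fat.h> (it ignores 32-bit and FAT_MAGIC_64 files; on macOS, file or lipo -info gives the same answer):

```python
import struct
from typing import List

MH_MAGIC_64 = 0xFEEDFACF  # thin 64-bit Mach-O, header in little-endian on these CPUs
FAT_MAGIC = 0xCAFEBABE    # universal binary, header always big-endian
CPU_TYPES = {0x01000007: "x86_64", 0x0100000C: "arm64"}


def archs_from_header(data: bytes) -> List[str]:
    """Return the CPU architectures named in the first bytes of a Mach-O file."""
    if len(data) < 8:
        raise ValueError("too short to be a Mach-O header")
    (magic_le,) = struct.unpack_from("<I", data, 0)
    (magic_be,) = struct.unpack_from(">I", data, 0)
    if magic_le == MH_MAGIC_64:
        (cputype,) = struct.unpack_from("<I", data, 4)
        return [CPU_TYPES.get(cputype, hex(cputype))]
    if magic_be == FAT_MAGIC:
        (nfat,) = struct.unpack_from(">I", data, 4)
        archs = []
        for i in range(nfat):
            # Each struct fat_arch is five uint32s; cputype comes first.
            (cputype,) = struct.unpack_from(">I", data, 8 + 20 * i)
            archs.append(CPU_TYPES.get(cputype, hex(cputype)))
        return archs
    raise ValueError("not a thin 64-bit or universal Mach-O header")


def macho_archs(path: str) -> List[str]:
    with open(path, "rb") as f:
        return archs_from_header(f.read(4096))
```

"mach-o, but wrong architecture" means the interpreter's own architecture (platform.machine()) is not in the list that macho_archs returns for the library being loaded.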

Thank you for trying out. I (and the JAX team) appreciate it :)

Lots of users have hit the "mach-o, but wrong architecture" error, so it doesn't look like a JAX issue to me.

@yashk2810: Thanks! Unfortunately I find that your wheel built at head reproduces the same LLVM error on my machine. I'm hoping to get a chance to try to fix it on my side, but haven't had a moment yet.

Interesting. Can you paste your log, how you installed it, and your OS?

Sure! Here you go:

conda create -n jax-test python=3.9 numpy scipy
conda activate jax-test
python -m pip install https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.72-cp39-none-macosx_11_0_arm64.whl
python -m pip install jax

Then in Python:

Python 3.9.7 | packaged by conda-forge | (default, Sep  2 2021, 17:55:16)
[Clang 11.1.0 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jaxlib
>>> jaxlib.__version__
'0.1.72'
>>> import jax.numpy as jnp
/opt/homebrew/Caskroom/miniforge/base/envs/jax-test/lib/python3.9/site-packages/jax/lib/__init__.py:31: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
>>> jnp.sqrt(2)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
zsh: abort      python

(All other jnp functions I've tried fail with the same error; it's not limited to sqrt...)

Output from 'conda env export':
name: jax-test
channels:
  - conda-forge
dependencies:
  - ca-certificates=2021.5.30=h4653dfc_0
  - libblas=3.9.0=11_osxarm64_openblas
  - libcblas=3.9.0=11_osxarm64_openblas
  - libcxx=12.0.1=h168391b_0
  - libgfortran=5.0.0.dev0=11_0_1_hf114ba7_23
  - libgfortran5=11.0.1.dev0=hf114ba7_23
  - liblapack=3.9.0=11_osxarm64_openblas
  - libopenblas=0.3.17=openmp_h5dd58f0_1
  - llvm-openmp=12.0.1=hf3c4609_1
  - ncurses=6.2=h9aa5885_4
  - numpy=1.21.2=py39h1f3b974_0
  - openssl=1.1.1l=h3422bc3_0
  - pip=21.2.4=pyhd8ed1ab_0
  - python=3.9.7=h54d631c_0_cpython
  - python_abi=3.9=2_cp39
  - readline=8.1=hedafd6a_0
  - scipy=1.7.0=py39h5060c3b_0
  - setuptools=58.0.4=py39h2804cbe_0
  - sqlite=3.36.0=h72a2b83_1
  - tk=8.6.11=he1e0b03_1
  - tzdata=2021a=he74cb21_1
  - wheel=0.37.0=pyhd8ed1ab_1
  - xz=5.2.5=h642e427_1
  - zlib=1.2.11=h31e879b_1009
  - pip:
    - absl-py==0.13.0
    - flatbuffers==2.0
    - jax==0.2.20
    - jaxlib==0.1.72
    - opt-einsum==3.3.0
    - six==1.16.0
prefix: /opt/homebrew/Caskroom/miniforge/base/envs/jax-test

@yashk2810: I did a tiny bit of digging and I think that @hawkinsp's worries were probably right. Knowing very little about how the build infrastructure works, my best guess is that the relevant change is that TF used to configure llvm-project manually:

https://github.com/tensorflow/tensorflow/blob/4039feeb743bc42cd0a3d8146ce63fc05d23eb8d/third_party/llvm/llvm.bzl#L310-L317

But now this is delegated to the bazel support in llvm-project directly, which doesn't seem to correctly handle this target. In particular, when compiling any of the LLVM targets, the CMake variables are no longer set correctly. For example, for the dependencies of jaxlib v0.1.70, the build variable LLVM_NATIVE_ARCH=AArch64 was set correctly, but now it is set using -DLLVM_NATIVE_ARCH="X86".

Anyways, this is probably TMI here, but I'd say that it looks like the issue lives pretty high up the tree of dependencies!

I can confirm that that logic is the culprit. If I swap out that line as follows (this also isn't the right logic, but it was a test):

-    "@bazel_tools//src/conditions:darwin": native_arch_defines("X86", "x86_64-unknown-darwin"),
+    "@bazel_tools//src/conditions:darwin": native_arch_defines("AArch64", "arm64-apple-darwin"),

Then jaxlib seems to work as expected. (For reference, here's the patch that I applied to the jax source at v0.1.71 that seemed to propagate the change I wanted: https://gist.github.com/dfm/bc2cf413bb4ad0b1d6fb11a96a406ef4)
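The effect of that select() can be modeled in plain Python. This is a hypothetical, simplified illustration, not the Starlark in llvm-project's config.bzl: the real native_arch_defines emits more LLVM_NATIVE_* defines, and Bazel matches config_setting labels rather than string keys. The triples are the ones quoted in this thread.

```python
from typing import Dict, List, Tuple


def native_arch_defines(arch: str, triple: str) -> List[str]:
    # Simplified stand-in: only two of the defines the real macro produces.
    return [f'LLVM_NATIVE_ARCH="{arch}"', f'LLVM_HOST_TRIPLE="{triple}"']


# Keys are "<os>" or "<os>_<cpu>"; values are (native arch, host triple).
# Old logic: any Darwin host is assumed to be x86-64.
OLD_LOGIC: Dict[str, Tuple[str, str]] = {
    "windows": ("X86", "x86_64-pc-win32"),
    "darwin": ("X86", "x86_64-unknown-darwin"),
    "linux_aarch64": ("AArch64", "aarch64-unknown-linux-gnu"),
}

# Fixed logic: Darwin on arm64 gets its own entry, as proposed in this thread.
NEW_LOGIC: Dict[str, Tuple[str, str]] = {
    "windows": ("X86", "x86_64-pc-win32"),
    "darwin_arm64": ("AArch64", "arm64-apple-darwin"),
    "darwin": ("X86", "x86_64-apple-darwin"),
    "linux_aarch64": ("AArch64", "aarch64-unknown-linux-gnu"),
}

DEFAULT = ("X86", "x86_64-unknown-linux-gnu")


def pick_defines(table: Dict[str, Tuple[str, str]], os_name: str, cpu: str) -> List[str]:
    # Most specific key first, loosely like Bazel preferring the more
    # specialized matching condition in a select().
    for key in (f"{os_name}_{cpu}", os_name):
        if key in table:
            return native_arch_defines(*table[key])
    return native_arch_defines(*DEFAULT)
```

With OLD_LOGIC an arm64 Mac gets LLVM_NATIVE_ARCH="X86", matching the build variables found above; dfm's experiment suggests this is what led to the "64-bit code requested on a subtarget that doesn't support it" abort.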

@dfm Any idea on how to trigger that error only using jaxlib? That would be nice as a test in the conda package of jaxlib to verify that everything is working (jax and jaxlib are built separately on conda-forge).
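Because an LLVM fatal error aborts the whole process instead of raising a Python exception, such a test probably has to run the computation in a child interpreter and inspect the exit status. A sketch, assuming jaxlib and numpy are installed and reusing the jaxlib.xla_client calls from the M1 session further down this thread (that API is from the jaxlib 0.1.7x era and has changed since); smoke_test and JAXLIB_SNIPPET are hypothetical names:

```python
import subprocess
import sys
import textwrap
from typing import Tuple

# Compile and run sin(3.0) on XLA's CPU backend using only jaxlib.
JAXLIB_SNIPPET = textwrap.dedent("""
    import numpy as np
    from jaxlib import xla_client as xc
    c = xc.XlaBuilder("smoke")
    x = xc.ops.Parameter(c, 0, xc.Shape.array_shape(np.dtype(np.float32), ()))
    xc.ops.Sin(x)
    backend = xc.get_local_backend("cpu")
    exe = backend.compile(c.Build())
    out = exe.execute([backend.buffer_from_pyval(np.array(3.0, dtype=np.float32))])
    assert abs(float(out[0].to_py()) - np.sin(3.0)) < 1e-5
""")


def smoke_test(code: str = JAXLIB_SNIPPET, timeout: float = 120.0) -> Tuple[bool, str]:
    """Run `code` in a fresh interpreter so an LLVM abort cannot kill the caller."""
    proc = subprocess.run([sys.executable, "-c", code],
                          capture_output=True, text=True, timeout=timeout)
    # On POSIX, abort() shows up as a negative return code (killed by SIGABRT),
    # and a failed assert as exit status 1; both count as failure here.
    return proc.returncode == 0, proc.stderr


if __name__ == "__main__":
    ok, err = smoke_test()
    print("jaxlib smoke test passed" if ok else f"jaxlib smoke test failed:\n{err}")
```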

@dfm Can you try applying the following patch to LLVM and verifying the resulting wheel works for you?

diff --git a/utils/bazel/llvm-project-overlay/llvm/config.bzl b/utils/bazel/llvm-project-overlay/llvm/config.bzl
index 514f79bcf2b6..8a8e54e844a7 100644
--- a/utils/bazel/llvm-project-overlay/llvm/config.bzl
+++ b/utils/bazel/llvm-project-overlay/llvm/config.bzl
@@ -75,7 +75,10 @@ os_defines = select({
 # TODO: We should split out host vs. target here.
 llvm_config_defines = os_defines + select({
     "@bazel_tools//src/conditions:windows": native_arch_defines("X86", "x86_64-pc-win32"),
-    "@bazel_tools//src/conditions:darwin": native_arch_defines("X86", "x86_64-unknown-darwin"),
+    "@bazel_tools//src/conditions:darwin": select({
+         "@bazel_tools//platforms:arm": native_arch_defines("AArch64", "arm64-apple-darwin"),
+         "//conditions:default": native_arch_defines("X86", "x86_64-apple-darwin"),
+     }),
     "@bazel_tools//src/conditions:linux_aarch64": native_arch_defines("AArch64", "aarch64-unknown-linux-gnu"),
     "//conditions:default": native_arch_defines("X86", "x86_64-unknown-linux-gnu"),
 }) + [

@hawkinsp: I gave that a shot and it fails because, unless I'm misunderstanding something, bazel doesn't seem to support nested selects like this (I can pull up the exact error, but I didn't save it originally). Unfortunately it also looks like there is a bug in the released versions of bazel which means that /src/conditions:darwin_arm64 doesn't work. The simplest diff that I could find to implement what I think is the correct logic (mainly copied from here) was: https://gist.github.com/dfm/845dfbd3dc1c17f75e7cb0cba7b0febb

@dfm Yeah I worried that might not work because of the nested selects. Your version looks reasonable to me. Do you want to send that to upstream LLVM (I can, if you don't want to, but you did the work!)? That's all we need to do here to get JAX fixed!

@hawkinsp: Sure, thanks! I'm happy to see if I can figure out the llvm review system and report back :D

Confirmed that it works 🎉. Built with Bazel master and https://github.com/freedomtan/tensorflow/tree/bazel_native_build_on_m1

>>> import platform
>>> platform.uname()
uname_result(system='Darwin', node='Josefs-MBP-2.lan', release='20.3.0', version='Darwin Kernel Version 20.3.0: Thu Jan 21 00:06:51 PST 2021; root:xnu-7195.81.3~1/RELEASE_ARM64_T8101', machine='arm64')
>>> import numpy as np
>>> from jax.lib import xla_client as xc
>>> xops = xc.ops
>>> c = xc.XlaBuilder("simple_scalar")
>>> param_shape = xc.Shape.array_shape(np.dtype(np.float32), ())
>>> x = xops.Parameter(c, 0, param_shape)
>>> y = xops.Sin(x)
>>> computation = c.Build()
>>> cpu_backend = xc.get_local_backend("cpu")
2021-03-07 09:26:16.684549: W external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
>>> compiled_computation = cpu_backend.compile(computation)
>>> host_input = np.array(3.0, dtype=np.float32)
>>> device_input = cpu_backend.buffer_from_pyval(host_input)
>>> device_out = compiled_computation.execute([device_input ,])
>>> device_out[0].to_py()
array(0.14112, dtype=float32)

@mattjj @hawkinsp If you want a PR I would be happy to create one, but maybe it makes more sense to wait until Bazel has released a working native arm64 build and TensorFlow has the necessary code in master.

Hi there,

Would you mind documenting the steps you took to resolve the problem? I have a similar issue: I managed to install numpyro, JAX and jaxlib, but when I import the packages I get the following warning:

/opt/homebrew/Caskroom/miniforge/base/envs/bhm-at-scale/lib/python3.9/site-packages/jax/lib/__init__.py:31: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see #5501 in the event of problems.
warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "

and an ImportError:
ImportError: dlopen(/opt/homebrew/Caskroom/miniforge/base/envs/bhm-at-scale/lib/python3.9/site-packages/jaxlib/xla_extension.so, 2): Symbol not found: __ZN3jax12_GLOBAL__N_121CompiledFunctionCache16kDefaultCapacityE
Referenced from: /opt/homebrew/Caskroom/miniforge/base/envs/bhm-at-scale/lib/python3.9/site-packages/jaxlib/xla_extension.so

@saminehbagheri That's a warning, not an error. JAX should work as normal, but we just want you to be aware it's pretty minimally tested at this point on ARM.

Thanks for the reply. I reformulated my post. True, that's a warning, but I also get an ImportError: a symbol is missing from xla_extension.so.

I'm getting the same import error (and had to set global variables as described here, for GRPCIO to install).

ImportError: dlopen(/Users/asdf/miniforge3/envs/tensorflow/lib/python3.9/site-packages/jaxlib/xla_extension.so, 2): Symbol not found: __ZN3jax12_GLOBAL__N_121CompiledFunctionCache16kDefaultCapacityE
E Referenced from: /Users/asdf/miniforge3/envs/tensorflow/lib/python3.9/site-packages/jaxlib/xla_extension.so
E Expected in: flat namespace
E in /Users/asdf/miniforge3/envs/tensorflow/lib/python3.9/site-packages/jaxlib/xla_extension.so

@annakoop @saminehbagheri: that import error is most likely caused by jax and jaxlib versions that don't match.
Since the LLVM problem discussed above has not been fixed, make sure you build jaxlib 0.1.70 and pair it with jax 0.2.19.
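
Because this kind of undefined-symbol ImportError usually points at a jax/jaxlib mismatch, it can help to print the installed pair before digging further. A rough sketch using importlib.metadata; the KNOWN_GOOD_PAIRS table only records the combinations mentioned in this thread and is an assumption, not an official compatibility matrix:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# jaxlib -> jax pairings reported to work in this thread (illustrative only):
# 0.1.70/0.2.19 for the custom M1 builds, 0.1.61/0.2.10 under emulation.
KNOWN_GOOD_PAIRS = {
    "0.1.70": "0.2.19",
    "0.1.61": "0.2.10",
}


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def check_pair(jax_version: Optional[str], jaxlib_version: Optional[str]) -> str:
    """Describe whether a jax/jaxlib pair matches one of the pairs above."""
    if jax_version is None or jaxlib_version is None:
        return "jax and/or jaxlib is not installed"
    expected_jax = KNOWN_GOOD_PAIRS.get(jaxlib_version)
    if expected_jax is None:
        return f"jaxlib {jaxlib_version} is not in the table; check the JAX changelog"
    if expected_jax == jax_version:
        return f"ok: jax {jax_version} with jaxlib {jaxlib_version}"
    return (f"mismatch: jaxlib {jaxlib_version} was paired with jax {expected_jax}, "
            f"found jax {jax_version}")


print(check_pair(installed_version("jax"), installed_version("jaxlib")))
```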

Thank you all for the detailed information, this is very helpful.

I was running jax/jaxlib under emulation on the M1 and started seeing this same import error with the recent versions (including the 0.1.70 and 0.2.19 combo). However, it does work if I install jaxlib 0.1.61 and jax 0.2.10 (...but then I can't use jaxopt). Does anyone know why this is happening? Should we not expect to be able to run JAX under emulation going forward?
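
When comparing reports, it matters whether Python runs natively on arm64 or as x86_64 code under Rosetta 2, since the two need different jaxlib wheels. A sketch using ctypes and the sysctl.proc_translated key that Apple documents for this purpose; it is untested here, and the helper name is hypothetical:

```python
import ctypes
import ctypes.util
import errno
import platform
from typing import Optional


def running_under_rosetta() -> Optional[bool]:
    """True if this process is x86_64 code translated by Rosetta 2, False if
    native, None if not on macOS or the answer cannot be determined."""
    if platform.system() != "Darwin":
        return None
    libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
    sysctlbyname = libc.sysctlbyname
    sysctlbyname.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_int),
                             ctypes.POINTER(ctypes.c_size_t), ctypes.c_void_p,
                             ctypes.c_size_t]
    sysctlbyname.restype = ctypes.c_int
    value = ctypes.c_int(0)
    size = ctypes.c_size_t(ctypes.sizeof(value))
    if sysctlbyname(b"sysctl.proc_translated", ctypes.byref(value),
                    ctypes.byref(size), None, 0) != 0:
        # ENOENT means the key does not exist (e.g. Intel Macs): a native process.
        return False if ctypes.get_errno() == errno.ENOENT else None
    return value.value == 1


print("platform.machine():", platform.machine())  # 'arm64' when native on M1
print("translated by Rosetta:", running_under_rosetta())
```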

Thanks in advance!

Quick update on the LLVM issue. I did get a tiny patch merged that should start getting us there. They sensibly did not want to merge the full patch that I shared above because it shouldn't be necessary. But I'm now finding that something about the TensorFlow Bazel configuration means LLVM can't figure out the correct platform, even after updating to this commit and using the most recent version of Bazel. I've gone down a real rabbit hole with this one and I'm still coming up empty, unfortunately!

Still having the xla issue with jaxlib 0.1.70 and jax 0.2.19, testing some different configurations...

Not sure if this is the right issue, but I'm getting the same error on x86 macOS with Python 3.9.7.

I've just installed JAX:

  absl-py            conda-forge/noarch::absl-py-0.14.0-pyhd8ed1ab_0
  jax                conda-forge/noarch::jax-0.2.21-pyhd8ed1ab_0
  jaxlib             conda-forge/osx-64::jaxlib-0.1.71-py39h757cd7f_0
  opt_einsum         conda-forge/noarch::opt_einsum-3.3.0-pyhd8ed1ab_1
  python-flatbuffers conda-forge/noarch::python-flatbuffers-2.0-pyhd8ed1ab_0

Any updates or progress on this front? I patched LLVM according to the diffs above, with no success.

M1 Pro and M1 Max 🥲

jaxlib wheels for 0.1.73 are live: https://pypi.org/project/jaxlib/0.1.73/#files

Can you try it out and see if the issue is fixed?

With that wheel I get: 'cyclone' is not a recognized processor for this target (ignoring processor)

Same error here, as well.

Python 3.9.5 | packaged by conda-forge | (default, Jun 19 2021, 00:24:55) 
Type 'copyright', 'credits' or 'license' for more information
IPython 7.24.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax.numpy as jnp
/Users/nicholas/miniforge3/lib/python3.9/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "

In [2]: a = jnp.arange(5)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!

I get the same error as above:

>>> rng_key = random.PRNGKey(0)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!

I think the LLVM fix may not have landed, or it was reverted.

Yeah - TensorFlow patches out the LLVM fix because it's incompatible with their macOS x86_64 builds for reasons that I don't totally understand (these discussions are happening somewhere that I don't have access to). I didn't have much luck working around this in my experiments, so I'm just hoping that it gets sorted out upstream eventually :D

For now, the jaxlib==0.1.70 wheels that I built are working just fine on my M1, so I've just been using those:

python -m pip install jax jaxlib==0.1.70 -f "https://dfm.io/custom-wheels/jaxlib/index.html"
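
The command above invokes pip as python -m pip, which installs into the same interpreter that will later import jaxlib; a bare pip on PATH can belong to a different Python (for example an x86_64 one). A small sketch that builds the same command for whichever interpreter runs it; pip_install_cmd is a hypothetical helper, not part of pip:

```python
import shlex
import sys
from typing import List


def pip_install_cmd(*requirements: str, find_links: str = "") -> List[str]:
    """Build a pip install command bound to the running interpreter."""
    cmd = [sys.executable, "-m", "pip", "install", *requirements]
    if find_links:
        # -f / --find-links: also look for wheels linked from this page
        cmd += ["-f", find_links]
    return cmd


cmd = pip_install_cmd("jax", "jaxlib==0.1.70",
                      find_links="https://dfm.io/custom-wheels/jaxlib/index.html")
print(shlex.join(cmd))
# To actually install: subprocess.run(cmd, check=True)
```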

Hope this helps!

This is the solution to use until the fix! Thank you very much.

I just submitted a fix upstream to TF: tensorflow/tensorflow@cd76ed3

I'll build jaxlib again to see if the fix works.

Or if someone can build it before I do and confirm it works, it would be very much appreciated 😃

Looks like it doesn't work.

Same error. I think the uploaded wheel for 0.1.73 was never tested.

'cyclone' is not a recognized processor for this target (ignoring processor)

I'm hoping that MacBook ARM (M1/M1X) is going to be a properly supported platform. Same for Windows.

@erwincoumans As we mentioned above, we don't actually have any M1 hardware ourselves on the team. So we can't test anything. The M1 build is community supported at the moment.

(I would imagine we will support the M1 build ourselves in the future, but we can't yet.)

Sorry about the breakages, but as Peter said, we don't have the capability to test.

But thank you to everyone here who is doing the testing for us. We really appreciate it.

I am working on a fix right now to get this resolved.

So a new jaxlib wheel is ready: https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

Can someone please try it out on their Mac M1 and see if it works?

pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

Thank you!

Still getting the same cyclone error on a fresh install. Note that I use the nightly build of scipy.

Thanks for your effort, @yashk2810. Looks like the updated wheel won't install on M1.

(base) nicholas@atlanta ~ % pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl


ERROR: jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl is not a supported wheel on this platform.

Are you running python 3.9 wherever you are installing it?

Also can you try pip install -U pip and then run the install command?
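
The "not a supported wheel on this platform" error means pip found no match between the wheel's filename tags and the running interpreter. Per PEP 427, jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl only fits CPython 3.9 on an arm64 macOS 11+ platform, and an older pip may not recognize the macOS 11 platform tags. A rough sketch that splits the filename and compares the tags to the current interpreter; it is far cruder than pip's real matching (the packaging library's tags module), and both helper names are hypothetical:

```python
import platform
import sys
from typing import Dict, NamedTuple


class WheelTags(NamedTuple):
    name: str
    version: str
    python_tag: str
    abi_tag: str
    platform_tag: str


def parse_wheel_filename(filename: str) -> WheelTags:
    """Split a wheel filename per PEP 427 (an optional build tag is ignored)."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[:-4].split("-")
    if len(parts) not in (5, 6):
        raise ValueError(f"unexpected wheel filename: {filename}")
    py, abi, plat = parts[-3:]
    return WheelTags(parts[0], parts[1], py, abi, plat)


def rough_compat_report(tags: WheelTags) -> Dict[str, bool]:
    """Very rough check: assumes CPython and a single python/platform tag."""
    here_py = f"cp{sys.version_info.major}{sys.version_info.minor}"
    return {
        "python_tag_ok": tags.python_tag == here_py,
        "arch_ok": tags.platform_tag.endswith("_" + platform.machine()),
    }


tags = parse_wheel_filename("jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl")
print(tags)
print(rough_compat_report(tags))
```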

Hi @yashk2810 , yes please see below:

(base) nicholas@atlanta ~ % pip install -U pip
Requirement already satisfied: pip in ./miniforge3/lib/python3.9/site-packages (21.3)
(base) nicholas@atlanta ~ % python -V
Python 3.9.5

Oh, weird. Something must have been broken in my terminal. In a fresh zsh session the installation worked; however, the wheel still throws the same error as before.

(base) nicholas@atlanta ~ % pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl


Collecting jaxlib==0.1.74
  Downloading https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl (36.9 MB)
     |████████████████████████████████| 36.9 MB 5.8 MB/s             
Requirement already satisfied: numpy>=1.18 in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (1.21.2)
Requirement already satisfied: flatbuffers<3.0,>=1.12 in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (2.0)
Requirement already satisfied: absl-py in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (0.14.1)
Requirement already satisfied: scipy in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (1.7.1)
Requirement already satisfied: six in ./miniforge3/lib/python3.9/site-packages (from absl-py->jaxlib==0.1.74) (1.16.0)
Installing collected packages: jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.1.70
    Uninstalling jaxlib-0.1.70:
      Successfully uninstalled jaxlib-0.1.70
Successfully installed jaxlib-0.1.74
(base) nicholas@atlanta ~ % ipython
Python 3.9.5 | packaged by conda-forge | (default, Jun 19 2021, 00:24:55) 
Type 'copyright', 'credits' or 'license' for more information
IPython 7.24.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax.numpy as jnp
/Users/nicholas/miniforge3/lib/python3.9/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "

In [2]: a = jnp.arange(5)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
zsh: abort      ipython
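
Because the LLVM fatal error calls abort(), it takes down the whole ipython session. When testing candidate wheels, it can be less disruptive to run the failing snippet in a child interpreter and inspect its exit status. A minimal sketch for POSIX systems, where subprocess reports death by a signal as a negative return code; the snippet and helper name are assumptions:

```python
import signal
import subprocess
import sys

# The snippet that aborted ipython in the reports above.
JAX_SMOKE_TEST = "import jax.numpy as jnp; print(jnp.arange(5))"


def run_isolated(code: str, timeout: float = 120.0) -> str:
    """Run `code` in a fresh interpreter so an abort() cannot kill the caller."""
    proc = subprocess.run([sys.executable, "-c", code],
                          capture_output=True, text=True, timeout=timeout)
    if proc.returncode == 0:
        return "ok"
    if proc.returncode == -signal.SIGABRT:
        # llvm::report_fatal_error ends in abort(), i.e. SIGABRT; its message
        # ("LLVM ERROR: ...") is written to stderr, so show the last line.
        last = proc.stderr.strip().splitlines()[-1:] or [""]
        return f"aborted: {last[0]}"
    return f"failed with exit code {proc.returncode}"


if __name__ == "__main__":
    print(run_isolated(JAX_SMOKE_TEST))
```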

Also on 3.9

That 0.1.74 wheel won't run here either.

Thread 0 Crashed:: Dispatch queue: com.apple.main-thread
0   libsystem_kernel.dylib        	0x0000000183cc0e68 __pthread_kill + 8
1   libsystem_pthread.dylib       	0x0000000183cf343c pthread_kill + 292
2   libsystem_c.dylib             	0x0000000183c3b454 abort + 124
3   xla_extension.so              	0x00000001054001cc llvm::report_fatal_error(llvm::Twine const&, bool) + 452
4   xla_extension.so              	0x0000000105400008 llvm::report_fatal_error(char const*, bool) + 56
5   xla_extension.so              	0x0000000103cba204 llvm::X86Subtarget::initSubtargetFeatures(llvm::StringRef, llvm::StringRef, llvm::StringRef) + 480
6   xla_extension.so              	0x0000000103cba3bc llvm::X86Subtarget::X86Subtarget(llvm::Triple const&, llvm::StringRef, llvm::StringRef, llvm::StringRef, llvm::X86TargetMachine const&, llvm::MaybeAlign, unsigned int, unsigned int) + 356
7   xla_extension.so              	0x0000000103cbb8c0 llvm::X86TargetMachine::getSubtargetImpl(llvm::Function const&) const + 1184
8   xla_extension.so              	0x0000000103cbba74 llvm::X86TargetMachine::getTargetTransformInfo(llvm::Function const&) + 92
9   xla_extension.so              	0x0000000104ee9c70 llvm::TargetTransformInfoWrapperPass::getTTI(llvm::Function const&) + 60
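
Frames 5 to 8 show LLVM constructing an X86Subtarget inside a process running on the M1, which suggests that the native arm64 binary picked an x86 target for its JIT rather than the wheel itself being built for the wrong architecture. To rule out the simpler case, the CPU type can be read from the Mach-O header of xla_extension.so. A stdlib-only sketch; it handles thin 64-bit Mach-O files, only flags universal binaries, and assumes the library sits in the jaxlib package directory as the ImportError paths earlier in this thread show:

```python
import importlib.util
import os
import struct

# Constants from <mach-o/loader.h>, <mach-o/fat.h> and <mach/machine.h>.
MH_MAGIC_64 = 0xFEEDFACF
FAT_MAGIC = 0xCAFEBABE
CPU_TYPE_X86_64 = 0x01000007
CPU_TYPE_ARM64 = 0x0100000C
CPU_NAMES = {CPU_TYPE_X86_64: "x86_64", CPU_TYPE_ARM64: "arm64"}


def arch_from_header(header: bytes) -> str:
    """Decode the magic and cputype fields at the start of a Mach-O file."""
    if len(header) < 8:
        return "not a Mach-O file"
    (magic_be,) = struct.unpack(">I", header[:4])
    if magic_be == FAT_MAGIC:
        return "universal (fat) binary"
    magic, cputype = struct.unpack("<II", header[:8])
    if magic != MH_MAGIC_64:
        return "not a 64-bit little-endian Mach-O file"
    return CPU_NAMES.get(cputype, f"unknown cputype {cputype:#x}")


def macho_arch(path: str) -> str:
    with open(path, "rb") as f:
        return arch_from_header(f.read(8))


if __name__ == "__main__":
    # find_spec locates jaxlib without importing it (importing may crash).
    spec = importlib.util.find_spec("jaxlib")
    if spec and spec.submodule_search_locations:
        so = os.path.join(spec.submodule_search_locations[0], "xla_extension.so")
        if os.path.exists(so):
            print(so, "->", macho_arch(so))
```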

I updated the 0.1.74 wheel with a different patch in upstream TF: https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

Can someone see if this works?

pip install -U pip
pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

Thank you!

Hi @yashk2810 , I really appreciate you looking into this. Unfortunately, seems like the same error crops up:

(base) nicholas@atlanta ~ % pip install -U pip
Requirement already satisfied: pip in ./miniforge3/lib/python3.9/site-packages (21.3)
Collecting pip
  Downloading pip-21.3.1-py3-none-any.whl (1.7 MB)
     |████████████████████████████████| 1.7 MB 2.9 MB/s            
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.3
    Uninstalling pip-21.3:
      Successfully uninstalled pip-21.3
Successfully installed pip-21.3.1
(base) nicholas@atlanta ~ % pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl
Collecting jaxlib==0.1.74
  Downloading https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl (37.0 MB)
     |████████████████████████████████| 37.0 MB 5.7 MB/s             
Requirement already satisfied: scipy in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (1.7.1)
Requirement already satisfied: absl-py in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (0.14.1)
Requirement already satisfied: numpy>=1.18 in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (1.21.2)
Requirement already satisfied: flatbuffers<3.0,>=1.12 in ./miniforge3/lib/python3.9/site-packages (from jaxlib==0.1.74) (2.0)
Requirement already satisfied: six in ./miniforge3/lib/python3.9/site-packages (from absl-py->jaxlib==0.1.74) (1.16.0)
Installing collected packages: jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.1.70
    Uninstalling jaxlib-0.1.70:
      Successfully uninstalled jaxlib-0.1.70
Successfully installed jaxlib-0.1.74
(base) nicholas@atlanta ~ %  
(base) nicholas@atlanta ~ % ipython
Python 3.9.5 | packaged by conda-forge | (default, Jun 19 2021, 00:24:55) 
Type 'copyright', 'credits' or 'license' for more information
IPython 7.24.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax.numpy as jnp
/Users/nicholas/miniforge3/lib/python3.9/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "

In [2]: jnp.arange(5)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
zsh: abort      ipython

I've now gotten the main branch version of jaxlib building on my M1 and cross-compiling on GitHub Actions. It seems to be running fine for me, and you can try it using:

python -m pip install jax jaxlib==0.1.74 -f "https://dfm.io/custom-wheels/jaxlib/index.html"

I'm still horrifyingly patching LLVM, via TensorFlow (here's the diff). It's a bit tricky from the outside to synchronize all the moving parts, but @yashk2810 if you want to chat offline, I might be able to give some tips for getting this to work without spamming this thread that's already pretty noisy. My email is on my GitHub profile and website if you're interested!

@dfm I can confirm it works here. Fantastic!

@dfm it works - fantastic job!