google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Metal: fp64 operations with jax.numpy base functions not supported

aboucaud opened this issue

Description

Running the following

import jax
jax.config.update("jax_enable_x64", True)
jax.numpy.linspace(0, 1, 10)

produces

XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
...
<unknown>:0: error: failed to legalize operation 'func.func'

The same error occurs with logspace, arange, etc.
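One way to see what these calls have in common is to inspect their result dtypes. The sketch below is illustrative and not part of the original report. It uses jax.eval_shape, which computes output shapes and dtypes without performing the computation, so it should be safe to run even with the Metal device as the default. With jax_enable_x64 on, all three calls produce 64-bit dtypes, which is consistent with the "data types are not supported" error above. Whether int64 (from arange) is also a problem on Metal is not established in this thread.

```python
import jax
import jax.numpy as jnp

# Same setting as in the report: 64-bit types become the defaults.
jax.config.update("jax_enable_x64", True)

# Each zero-argument function reproduces one of the failing calls.
calls = {
    "linspace": lambda: jnp.linspace(0, 1, 10),
    "logspace": lambda: jnp.logspace(0, 1, 10),
    "arange": lambda: jnp.arange(10),
}

# eval_shape traces each function abstractly and returns a ShapeDtypeStruct
# (shape + dtype) instead of running the computation.
result_dtypes = {name: jax.eval_shape(fn).dtype for name, fn in calls.items()}

for name, dtype in result_dtypes.items():
    print(f"{name}: {dtype}")
```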

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.23
numpy:  1.26.4
python: 3.11.4 (main, Jun 19 2023, 22:36:35) [Clang 14.0.3 (clang-1403.0.22.14.1)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='altair', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:49 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6020', machine='arm64')

FP64 is not supported in the jax-metal backend, and as of now it is unknown when support will be added. We will post an update here if the situation changes.
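Until fp64 is available in jax-metal, one possible workaround is to run the float64 parts of a program on the CPU backend. This is a general JAX mechanism, not something the maintainers suggested in this thread. The minimal sketch below assumes the CPU backend is available alongside the Metal plugin and uses jax.default_device to place the computation there.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Assumption: the CPU backend (bundled with jaxlib) is available next to Metal.
cpu = jax.devices("cpu")[0]

# Inside this block, new arrays and eagerly executed operations are placed
# on the CPU device instead of the default (Metal) device.
with jax.default_device(cpu):
    x = jnp.linspace(0, 1, 10)

print(x.dtype, x.devices())
```

Arrays created under jax.default_device are not committed to that device. Later operations on them outside the `with` block may therefore be dispatched to the default (Metal) device again. Keeping the float64 work inside the block, or committing the result with `jax.device_put(x, cpu)`, avoids that.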

Previous report: #16435

Thanks for the prompt reply, and sorry, I missed the previous issue.

If I may suggest, it would be helpful to mention this limitation on
https://developer.apple.com/metal/jax/