TypeError: take_along_axis indices must be of integer type, got float32
wamiq-reyaz opened this issue
Hey,
I was trying to run the untouched notebook in an environment with jax-0.3.13 and jaxlib-0.3.10 on an Ubuntu 18.04 machine with CUDA 11.7 and cuDNN 8.2, but I get the error
TypeError: take_along_axis indices must be of integer type, got float32
when running
elif run_mode == 'pmap':
    sample_rngs = jax.random.split(sample_rng, jax.local_device_count())
    results = p_generate_256_samples(pmap_input_tokens, sample_rngs)
Any help?
Same issue for me in Colab
Same issue
Got the same error and traced it to line 154 of parallel_decode.py:
mask_len = jnp.maximum(
1,
jnp.minimum(jnp.sum(unknown_map, axis=-1, keepdims=True) - 1, mask_len))
I patched it by changing it to:
mask_len = jnp.maximum(
1,
jnp.minimum(jnp.sum(unknown_map, axis=-1, keepdims=True) - 1, mask_len)).astype('int32')
I'm not sure this is the right fix, but it at least makes that error message go away. Later on I get a different error, "Original error: UNIMPLEMENTED: DNN library is not found.", so this may not be the complete solution.
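A minimal sketch of why the cast helps. It assumes, without confirmation from this thread, that parallel_decode.py later passes mask_len as the indices argument of jnp.take_along_axis on sorted confidence scores, and that mask_len starts out as a float because it comes from something like jnp.floor(count * mask_ratio). The names compute_mask_len and mask_lowest_confidence are hypothetical. Because jnp.minimum/jnp.maximum promote an int count and a float mask_len to float32, the clamped result is still float32 unless it is explicitly cast, which matches the error in the title.

```python
import jax.numpy as jnp


def compute_mask_len(unknown_map, mask_ratio):
    """Hypothetical helper mirroring the clamping at line 154.

    Assumption: mask_len is derived via jnp.floor, which returns float32.
    """
    unknown_count = jnp.sum(unknown_map, axis=-1, keepdims=True)  # int32
    mask_len = jnp.floor(unknown_count * mask_ratio)  # float32
    # int32 vs float32 promotes to float32, so the clamp alone keeps float32.
    mask_len = jnp.maximum(1, jnp.minimum(unknown_count - 1, mask_len))
    # The patch from this thread: cast so it can be used as indices.
    return mask_len.astype(jnp.int32)


def mask_lowest_confidence(confidence, mask_len):
    """Hypothetical: True where confidence is below the mask_len-th smallest value."""
    sorted_conf = jnp.sort(confidence, axis=-1)
    # take_along_axis requires integer indices; float32 here raises TypeError.
    cut_off = jnp.take_along_axis(sorted_conf, mask_len, axis=-1)
    return confidence < cut_off


unknown_map = jnp.array([[True, True, True, True, False, False]])
# Already-known positions get infinite confidence so they are never re-masked.
confidence = jnp.array([[0.9, 0.1, 0.5, 0.3, jnp.inf, jnp.inf]])

mask_len = compute_mask_len(unknown_map, mask_ratio=0.5)
print(mask_len.dtype, mask_len)  # int32 [[2]]
print(mask_lowest_confidence(confidence, mask_len))
# [[False  True False  True False False]]
```

Dropping the .astype(jnp.int32) line leaves mask_len as float32, which take_along_axis rejects with the TypeError reported above.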
+1
Reference to PR #9