google-research / maskgit

Official Jax Implementation of MaskGIT



TypeError: take_along_axis indices must be of integer type, got float32

wamiq-reyaz opened this issue

Hey,

I was trying to run the unmodified notebook in an environment with jax 0.3.13 and jaxlib 0.3.10 on an Ubuntu 18.04 machine with CUDA 11.7 and cuDNN 8.2, but I get the error

TypeError: take_along_axis indices must be of integer type, got float32

when running

elif run_mode == 'pmap':
    sample_rngs = jax.random.split(sample_rng, jax.local_device_count())
    results = p_generate_256_samples(pmap_input_tokens, sample_rngs)
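The error itself is raised by jnp.take_along_axis, which rejects index arrays that are not of an integer dtype. A minimal sketch reproducing it outside the notebook; the toy arrays below are illustrative assumptions standing in for the real model outputs:

```python
import jax.numpy as jnp

# Toy confidence scores (assumption): batch of 1, sequence of 4 tokens.
confidence = jnp.array([[0.9, 0.1, 0.5, 0.3]])

# Index array that ended up as float32 instead of an integer type.
float_idx = jnp.array([[1.0]], dtype=jnp.float32)

error_message = None
try:
    jnp.take_along_axis(confidence, float_idx, axis=-1)
except TypeError as err:
    # JAX checks the index dtype and raises before doing any gathering.
    error_message = str(err)

print(error_message)
```

Swapping float_idx for an int32 array with the same value makes the call succeed, which is what points at the dtype of the indices rather than their values.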

Any help?

Same issue for me in Colab

Same issue

Got the same error and traced it down to line 154 in parallel_decode.py

mask_len = jnp.maximum(
        1,
        jnp.minimum(jnp.sum(unknown_map, axis=-1, keepdims=True) - 1, mask_len))
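One plausible reason this line yields float32: under JAX's type promotion rules, jnp.minimum of an integer array and a float32 array returns float32. The sketch below assumes mask_len reaches this line as a float32 array (for example from a jnp.floor of a ratio times the sequence length); that assumption is not checked against parallel_decode.py itself:

```python
import jax.numpy as jnp

# Boolean map of still-unknown tokens: batch of 1, 4 tokens.
unknown_map = jnp.array([[True, False, True, True]])

# Assumption: mask_len is produced by floating-point arithmetic.
mask_len = jnp.floor(jnp.array([[4 * 0.5]], dtype=jnp.float32))

# Summing booleans gives an integer array...
num_unknown = jnp.sum(unknown_map, axis=-1, keepdims=True) - 1
# ...but mixing it with a float32 mask_len promotes the result to float32.
clamped = jnp.maximum(1, jnp.minimum(num_unknown, mask_len))

print(num_unknown.dtype, clamped.dtype)
```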

I patched it by changing it to

mask_len = jnp.maximum(
        1,
        jnp.minimum(jnp.sum(unknown_map, axis=-1, keepdims=True) - 1, mask_len)).astype('int32') 
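The .astype('int32') cast works because it restores an integer dtype before the value is used as an index. A simplified sketch of the patched expression feeding jnp.take_along_axis; cutoff_confidence is a hypothetical helper for illustration, not a function from the MaskGIT repository:

```python
import jax.numpy as jnp


def cutoff_confidence(confidence, unknown_map, mask_len):
    """Hypothetical helper: clamps mask_len as in the patch, then returns the
    score at that position in the ascending-sorted confidences."""
    mask_len = jnp.maximum(
        1,
        jnp.minimum(jnp.sum(unknown_map, axis=-1, keepdims=True) - 1, mask_len),
    ).astype('int32')  # the patch: cast back to an integer index dtype
    sorted_confidence = jnp.sort(confidence, axis=-1)
    return jnp.take_along_axis(sorted_confidence, mask_len, axis=-1)


confidence = jnp.array([[0.9, 0.1, 0.5, 0.3]])
unknown_map = jnp.array([[True, True, True, True]])
mask_len = jnp.array([[2.0]], dtype=jnp.float32)  # float, as in the report

cut_off = cutoff_confidence(confidence, unknown_map, mask_len)
print(cut_off)
```

Casting right after the clamp, as the patch does, means every later use of mask_len sees an integer; casting only at the take_along_axis call site would be an alternative if mask_len is needed as a float elsewhere.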

I have no idea whether this makes sense, but it at least made that error message go away. Later on, however, I get an "Original error: UNIMPLEMENTED: DNN library is not found." error, so this might not be the correct solution.

+1

Reference to PR #9