borisdayma / dalle-mini

DALL·E Mini - Generate images from a text prompt

Home Page: https://www.craiyon.com

Training Error: TypeError: true_fun and false_fun output must have identical types

Netruk44 opened this issue · comments

Hello,

I'm attempting to fine-tune dalle-mega, but I hit this error while trying to get things up and running:

Traceback (most recent call last):
  File "/home/netruk44/ml/workspace/repos/dalle-mini/tools/train/train.py", line 1742, in <module>
    main()
  File "/home/netruk44/ml/workspace/repos/dalle-mini/tools/train/train.py", line 1702, in main
    state, train_metrics = p_train_step(state, batch, train_time)
  File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/experimental/pjit.py", line 367, in wrapped
    args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
  File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/experimental/pjit.py", line 344, in infer_params
    jaxpr, normalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/experimental/pjit.py", line 568, in _pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, global_in_avals)
  File "/home/netruk44/ml/workspace/repos/dalle-mini/tools/train/train.py", line 1299, in train_step
    gradients_norm = maybe_fn(
  File "/home/netruk44/ml/workspace/repos/dalle-mini/tools/train/train.py", line 1283, in maybe_fn
    return jax.lax.cond(
  File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/_src/lax/control_flow/conditionals.py", line 254, in cond
    return _cond(*args, **kwargs)
  File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/_src/lax/control_flow/conditionals.py", line 223, in _cond
    _check_tree_and_avals("true_fun and false_fun output",
  File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py", line 105, in _check_tree_and_avals
    raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have identical types, got
FrozenDict({
    lm_head: {
        kernel: 'DIFFERENT ShapedArray(float16[]) vs. ShapedArray(float32[])',
    },
    model: {
        decoder: {
            embed_positions: {
                embedding: 'DIFFERENT ShapedArray(float16[]) vs. ShapedArray(float32[])',
            },
            embed_tokens: {
                embedding: 'DIFFERENT ShapedArray(float16[]) vs. ShapedArray(float32[])',
            },
            final_ln: {
                bias: 'DIFFERENT ShapedArray(float16[]) vs. ShapedArray(float32[])',
            },
            layernorm_embedding: {
                bias: 'DIFFERENT ShapedArray(float16[]) vs. ShapedArray(float32[])',
                scale: 'DIFFERENT ShapedArray(float16[]) vs. ShapedArray(float32[])',
            },
        },
    },
}).
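
For context, the failure looks like JAX's check that both branches of jax.lax.cond must return pytrees with identical shapes and dtypes. Here's a tiny standalone snippet (not from train.py, just to illustrate the mechanism) that should raise the same TypeError:

import jax
import jax.numpy as jnp

def true_fn(_):
    return jnp.float32(0)   # this branch returns a float32 scalar

def false_fn(_):
    return jnp.float16(0)   # this branch returns a float16 scalar -> dtype mismatch

# jax.lax.cond requires both branch outputs to have identical types, so this
# raises: TypeError: true_fun and false_fun output must have identical types
jax.lax.cond(True, true_fn, false_fn, 0)

In train.py the mismatch seems to be the same thing, just spread across the whole params pytree: one branch produces float16 leaves (norms computed from the fp16 checkpoint weights) and the other produces the hard-coded float32 zeros.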

I believe that I have worked around the error by changing this line: https://github.com/borisdayma/dalle-mini/blob/main/tools/train/train.py#L1294

# Old
            zeros_norm = jax.tree_util.tree_map(lambda _: jnp.float32(0), params)
#                                                                  ^^
# Fixed
            zeros_norm = jax.tree_util.tree_map(lambda _: jnp.float16(0), params)
#                                                                  ^^

However, I'm not very familiar with this code or what it's doing, so I don't know whether this is the 'correct' solution. I thought I'd at least mention the problem in case it's something that needs to be addressed. It's also possible I've misconfigured something somewhere, so this might just be a problem on my end lol.
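
For what it's worth, a variant that wouldn't hard-code either dtype might be to build the zero tree from each leaf's own dtype, so it matches whatever precision the checkpoint happens to use. Just a sketch (toy params tree below for illustration; in train.py, params is the model's parameter pytree), not something I've tested beyond my setup:

import jax
import jax.numpy as jnp

# Toy stand-in for the model params; the real tree may mix float16/float32
# depending on the checkpoint.
params = {"kernel": jnp.ones((2, 2), dtype=jnp.float16),
          "bias": jnp.ones((2,), dtype=jnp.float32)}

# Dtype-agnostic zero tree: each leaf gets a zero scalar in its own dtype, so
# both cond branches return identically-typed pytrees for any checkpoint.
zeros_norm = jax.tree_util.tree_map(lambda p: jnp.zeros((), dtype=p.dtype), params)

print({k: v.dtype for k, v in zeros_norm.items()})
# {'kernel': dtype('float16'), 'bias': dtype('float32')}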

It looks like the error is likely coming from the choice of checkpoint I passed into train.py using model_name_or_path.

When I start fine-tuning from the dalle-mini/dalle-mini/mega-1-fp16:latest checkpoint, I get the error above, but if I fine-tune from dalle-mini/dalle-mini/mega-1:latest, it works as-is without any modifications.
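
If it helps anyone reproduce this, a quick way to confirm the checkpoint dtypes is to load the params the way the inference notebook does and look at the leaf dtypes (roughly; this needs a wandb login, and the training script may load the checkpoint differently):

import jax
import jax.numpy as jnp
from dalle_mini import DalleBart

# Load the fp16 wandb artifact without initializing fresh weights
model, params = DalleBart.from_pretrained(
    "dalle-mini/dalle-mini/mega-1-fp16:latest", dtype=jnp.float16, _do_init=False
)

# Collect the distinct dtypes across all parameter leaves
print({leaf.dtype for leaf in jax.tree_util.tree_leaves(params)})
# for the fp16 checkpoint I'd expect this to print just {dtype('float16')}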

I'll leave this issue open in case you'd like to do something with it, but I'd also be fine with just closing it 😄

Oh interesting, thanks for reporting it