google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

GroupedConv distributed training failure

sali1997s opened this issue

  • Docker: nvcr.io/nvidia/jax:23.10-t5x-py3
  • GPU/TPU model and memory: 8 × A100 40 GB

I think there is a problem with grouped convolution in distributed training.
A plain (non-grouped) convolution works fine with distributed training, and a grouped convolution works fine on a single device; only the combination of the two fails.

Error log:

Traceback (most recent call last):
  File "/WavLMJax/distributed_grouped_conv_failure.py", line 86, in
    train_step(logical_initialized_state, jax.device_put(batch, mesh_sharding(PartitionSpec('data', None, None))), model)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: during context [hlo verifier]: Expected instruction to have shape equal to f32[128,1,384], actual shape is f32[128,48,384]:
%multiply.33 = f32[128,48,384]{2,1,0} multiply(f32[128,1,384]{2,1,0} %dynamic-slice.6, f32[128,1,384]{2,1,0} %dynamic-slice.6), metadata={op_name="jit(train_step)/jit(main)/mul" source_file="/usr/local/lib/python3.10/dist-packages/optax/_src/transform.py" source_line=98}
Failed after pipeline-start
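The shapes in the verifier message can be read against Flax's convolution kernel layout, (kernel_size, in_features // feature_group_count, features). One possible reading, which is an interpretation and not something the log states: f32[128,48,384] is one per-device shard of a (128, 48, 768) kernel whose output-feature axis is split two ways, and the mismatch is in the middle (per-group input channel) axis, 1 versus 48. The sketch below, assuming a hypothetical 4 × 2 mesh with axes 'data' and 'model' on 8 simulated CPU devices and that hypothetical (128, 48, 768) kernel, computes per-device shard shapes under a few PartitionSpecs, which can help check such a reading against your own sharding rules.

```python
import os

# Simulate 8 CPU devices (the report used 8 GPUs); must be set before JAX is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 4 x 2 mesh: 4-way data parallelism, 2-way model parallelism.
mesh = Mesh(np.array(jax.devices("cpu")[:8]).reshape(4, 2), axis_names=("data", "model"))

# Hypothetical full kernel in Flax layout: (kernel_size, in_features // groups, features).
kernel_shape = (128, 48, 768)

specs = {
    "replicated": P(),
    "split output features": P(None, None, "model"),
    "split per-group inputs": P(None, "model", None),
}
# shard_shape gives the shape each device holds for a given global shape.
shard_shapes = {
    name: NamedSharding(mesh, spec).shard_shape(kernel_shape)
    for name, spec in specs.items()
}
for name, shape in shard_shapes.items():
    print(f"{name:>24}: {shape}")
```

If one of these matches the f32[128,48,384] operand in your own run, it indicates which kernel axis your logical rules shard; it does not by itself explain why the partitioner expected an f32[128,1,384] slice.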

Code:
https://colab.research.google.com/drive/117FrrCLar8TVcXncT8kUsZEykallgEqX?usp=sharing (distributed grouped conv)
https://colab.research.google.com/drive/1xmvMAfz4NzNmp7EAxF8EIysGYAsP_jCV?usp=sharing (single device grouped conv)