texttron / tevatron

Tevatron - A flexible toolkit for neural retrieval research and development.

Home Page: http://tevatron.ai

Long queries throw a "Dtype object" error due to the predefined max_length in the batch

salrowili opened this issue

I think there is a bug in src/tevatron/driver/jax_train.py, at this line:

https://github.com/texttron/tevatron/blob/0e939457444f78284ab0471da74a0c74bc76a833/src/tevatron/driver/jax_train.py#L147C43-L147C56

The line hardcodes max_length to 32 on the assumption that no query exceeds that length, which breaks as soon as data_args.q_max_len is set above 32. My custom dataset has a couple of examples whose queries reach about 128 tokens. It would be great if you could fix this, because the error Python raises is cryptic and gives no indication that this line is the cause; it took me two days to trace the problem back to it. I worked around it by setting max_length to 128 instead of 32. One possible fix would be to replace 32 with data_args.q_max_len:

return dict(tokenizer.pad(qq, max_length=data_args.q_max_len, padding='max_length', return_tensors='np')), dict(
    tokenizer.pad(dd, max_length=data_args.p_max_len, padding='max_length', return_tensors='np'))
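
A minimal sketch of the likely failure mode, assuming that padding to a fixed max_length only extends short rows and never truncates long ones. The helpers pad_without_truncation and batch_is_rectangular are hypothetical stand-ins written for illustration; they are not part of tevatron or transformers. With max_length=32, a 100-token query keeps its 100 tokens, so the batch rows end up with different lengths and cannot form a regular integer array:

```python
import numpy as np


def pad_without_truncation(seqs, max_length, pad_id=0):
    """Hypothetical stand-in for padding='max_length':
    short rows are padded up to max_length, long rows are left untouched."""
    return [s + [pad_id] * max(0, max_length - len(s)) for s in seqs]


def batch_is_rectangular(rows):
    """True when every row has the same length, i.e. the batch can be a 2-D array."""
    return len({len(r) for r in rows}) == 1


# Two fake queries: 10 token ids and 100 token ids.
queries = [list(range(1, 11)), list(range(1, 101))]

short = pad_without_truncation(queries, max_length=32)   # hardcoded value
fixed = pad_without_truncation(queries, max_length=128)  # e.g. q_max_len = 128

print([len(r) for r in short])  # [32, 100] -> ragged batch
print([len(r) for r in fixed])  # [128, 128] -> regular batch

# Only the regular batch converts cleanly to an integer array.
batch = np.array(fixed, dtype=np.int32)
print(batch.shape)  # (2, 128)
```

This also suggests why replacing 32 with data_args.q_max_len should work: as long as queries are truncated to q_max_len during preprocessing, no row can be longer than the padding target.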

Thank you
Sultan