HazyResearch / m2

Repo for "Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture"


training data

41924076 opened this issue

Hi, thank you for your nice work! Did you train your M2-BERT-128 through M2-BERT-32K models (shown in the paper) on the LoCo V0 or the LoCo V1 training set?

Thank you for your reply!

Hello, you mentioned in the paper: "For all M2-BERT configurations, we use a learning rate of 5e-6, a true batch size of 32, 1 epoch of fine-tuning, a maximum gradient norm of 1.0, and a ratio of 32 negative passages per query-positive passage pair."

Do you use random negatives or BM25 hard negatives?

When computing the OPL loss at each step, do you use only the similarity between the query and a single passage, where that passage is one of the 33 passages (32 negatives plus 1 positive)?
Or do you use both the query-positive similarity and the query-negative similarity when computing the MSE in the OPL loss at each step?

Thank you so much!

We use random negatives for fine-tuning. At each step for OPL, we calculate the similarity between the query and the positive passage as well as the similarity between the query and a negative passage.
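For concreteness, here is a minimal PyTorch sketch of one fine-tuning step that combines the settings quoted above (learning rate 5e-6, max gradient norm 1.0, 32 negatives per query-positive pair) with the similarity computation described in this reply. The encoder, optimizer choice, and random features are placeholders for illustration, not the repo's actual training code.

import torch
import torch.nn.functional as F

# Stand-ins for illustration: a linear layer in place of the M2-BERT encoder,
# and pre-tokenized inputs replaced by random 768-d features.
encoder = torch.nn.Linear(768, 768)
optimizer = torch.optim.AdamW(encoder.parameters(), lr=5e-6)  # lr from the paper; optimizer choice assumed

query_feats = torch.randn(1, 768)
passage_feats = torch.randn(33, 768)               # 1 positive followed by 32 random negatives
labels = torch.cat([torch.ones(1), torch.zeros(32)])

q_emb = encoder(query_feats)                       # (1, d)
p_embs = encoder(passage_feats)                    # (33, d)
sims = F.cosine_similarity(q_emb, p_embs, dim=-1)  # (33,) query-passage cosine similarities
loss = F.mse_loss(sims, labels)                    # push the positive toward 1, negatives toward 0

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)  # max grad norm 1.0 from the paper
optimizer.step()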

Thank you for your reply!

Hello, I find it a little hard to get the QMSum score above 45. Do you use the whole LoCo V0 as the training set, with no duplication, deletion, or particular data proportions? :)

Besides, do the public checkpoints only go through pretraining and LoCo V0 fine-tuning? :)

We use all of LoCo V0 as our fine-tuning dataset with 32 negatives for every query-positive passage pair. For QMSUM and the other Tau Scrolls datasets on HuggingFace, we use the given train-validation-test split and evaluate on the validation split. The public checkpoints only go through pretraining and LoCoV0 fine-tuning. However, we plan to release an updated version of the QMSUM dataset (as well as several new datasets) in LoCoV1 soon!
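For reference, the Tau Scrolls QMSum splits can be pulled from HuggingFace roughly as follows. This is only a sketch: the dataset and config names follow the public tau/scrolls dataset card, and the LoCo packaging of this data may differ.

from datasets import load_dataset

# QMSum as packaged in Tau Scrolls on HuggingFace, with its predefined splits.
qmsum = load_dataset("tau/scrolls", "qmsum")

train_split = qmsum["train"]       # training split (the data available for fine-tuning)
val_split = qmsum["validation"]    # validation split, used for evaluation per the reply above
print(len(train_split), len(val_split))
print(val_split[0].keys())         # SCROLLS-style fields, e.g. 'id', 'pid', 'input', 'output'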

Thank you for your reply! Have a nice day!

Hello, sorry to bother you. I'd like to try reproducing the training of M2-BERT on LoCo V0 or V1 at a sequence length of 2048.

According to your replies and the paper, M2-BERT for LoCo V0 was trained with the OPL loss. However, the main branch on GitHub does not provide a training command with parameter values (https://github.com/HazyResearch/m2/blob/main/bert/EMBEDDINGS.md#training), and the code does not include the OPL loss.

Additionally, I noticed that the jonsf branch (https://github.com/HazyResearch/m2/tree/jonsf-patch-1) includes the OPL loss, but I'm unsure whether gather_loco_training_example in that branch works with OPL.

The training script provided in the jonsf branch (https://github.com/HazyResearch/m2/blob/jonsf-patch-1/bert/EMBEDDINGS.md#training) also uses GradCache's multiple_negatives_ranking_loss instead of OPL. Can that command produce results similar to those reported in the paper?

If it's convenient for you, could you please provide an OPL (or multiple_negatives_ranking) training script that can roughly reproduce the public checkpoints?

Hello, in the jonsf branch, we include both OPL and multiple negatives ranking loss (MNRL) with grad caching. You can use either for training your own checkpoints of M2-BERT-2k.

We are currently exploring improved training techniques with both loss functions, so we will be sure to share which turns out better! Thanks!
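If it helps in the meantime, here is a generic MNRL fine-tuning sketch using the standard sentence-transformers API. It is not the jonsf-branch script: the checkpoint path is a placeholder, the data is a toy list, and whether plain MNRL matches the paper's numbers is exactly the open question above.

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("path/to/m2-bert-2k-checkpoint")  # placeholder: your M2-BERT embedding model

# MNRL treats the other positives in a batch as negatives; an explicit hard or
# random negative can be supplied as a third text in each InputExample.
train_examples = [
    InputExample(texts=["query 1", "relevant passage 1"]),
    InputExample(texts=["query 2", "relevant passage 2"]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
train_loss = losses.MultipleNegativesRankingLoss(model)

# One epoch with the learning rate quoted earlier in the thread; the jonsf-branch
# script adds grad caching on top of this idea to fit long sequences in memory.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    optimizer_params={"lr": 5e-6},
)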

Thank you for your reply!!!

Hi, thank you for publishing your results and sharing your training code!
It looks like you are importing orthogonal passage loss (OPL) as sentence_transformers.losses.OrthogonalPassageLoss.
However, when I check sentence-transformers, it doesn't have this loss, so I'm assuming this is a fork of sentence_transformers that you haven't shared?

From the above and your paper, it sounds like you are doing something like the pseudocode below:

import torch.nn.functional as F

def opl_loss(model, query, documents, labels):
    q_embedding = model(query)          # (1, d) query embedding
    d_embeddings = model(documents)     # (33, d): 1 positive + 32 randomly sampled negatives
    # cosine similarity between the query and every candidate passage
    pairwise_cosine_sim = F.cosine_similarity(q_embedding, d_embeddings, dim=-1)
    # regress the similarities onto the relevance labels (1 for the positive, 0 for each negative)
    loss = F.mse_loss(pairwise_cosine_sim, labels)
    return loss

Does this sound about right?

Yes, we have a fork of the SentenceTransformers codebase, in which we add orthogonal projection loss (OPL). We include the instructions for importing it in the M2 codebase, but here is the link to that fork as well.
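Roughly, using it mirrors the other sentence-transformers losses. The snippet below is only a sketch of that pattern based on the import path mentioned above: the checkpoint path is a placeholder, and the constructor arguments and label convention are assumptions, not the fork's documented interface.

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses  # the fork, installed in place of upstream

model = SentenceTransformer("path/to/m2-bert-2k-checkpoint")  # placeholder checkpoint

# Assumed label convention, matching the pseudocode above: 1.0 for the
# query-positive pair and 0.0 for each query-negative pair.
train_examples = [
    InputExample(texts=["query 1", "positive passage 1"], label=1.0),
    InputExample(texts=["query 1", "random negative passage"], label=0.0),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)

train_loss = losses.OrthogonalPassageLoss(model)  # hypothetical constructor; exists only in the fork
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1)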

Let me know if you have any further questions!

Thank you for the quick reply. This is exactly what I was looking for.