luyug / Condenser

EMNLP 2021 - Pre-training architectures for dense retrieval

cocondenser-marco pretraining data

1024er opened this issue

Hi,
I am trying to reproduce coCondenser on the MS MARCO data, but I only got 37.4 (MRR@10) on the MS MARCO dev set. Could you help me?
The MS MARCO data was downloaded from https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.tsv.gz (22 GB), and I extracted spans with the following code:

```python
import json

import nltk
from transformers import AutoTokenizer

nltk.download('punkt')  # one-time download for sent_tokenize
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # BERT vocabulary used by Luyu/condenser


def encode_one(line):
    # Split the document into sentences and tokenize each sentence separately.
    spans = nltk.sent_tokenize(line.strip())
    if len(spans) < 2:
        return None
    tokenized = [
        tokenizer(
            s,
            add_special_tokens=False,
            truncation=False,
            return_attention_mask=False,
            return_token_type_ids=False,
        )["input_ids"] for s in spans
    ]
    tokenized = [span for span in tokenized if len(span) > 0]  # drop empty sentences
    return json.dumps({'spans': tokenized})
```
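This is not the exact driver script I used, just a sketch of how the TSV is fed through encode_one, writing one JSON line per document:

```python
# Illustrative driver: msmarco-docs.tsv rows are docid \t url \t title \t body;
# only the body column is passed to encode_one here.
with open('msmarco-docs.tsv') as f, \
        open('pretrain_data/msmarco/msmarco.json', 'w') as out:
    for line in f:
        body = line.rstrip('\n').split('\t')[-1]
        encoded = encode_one(body)
        if encoded is not None:
            out.write(encoded + '\n')
```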

The hyperparameters are as follows (on 16x RTX 3090):

```bash
python -m torch.distributed.launch --nnodes=2 --node_rank=$1 --master_addr 10.104.91.11 \
    --master_port 2222 --nproc_per_node 8 run_co_pre_training.py \
    --output_dir coco_msmarco_output_1e-4_bs2048 \
    --model_name_or_path Luyu/condenser \
    --do_train \
    --fp16 \
    --save_steps 2000 \
    --save_total_limit 10 \
    --model_type bert \
    --per_device_train_batch_size 128 \
    --gradient_accumulation_steps 1 \
    --warmup_ratio 0.1 \
    --learning_rate 1e-4 \
    --num_train_epochs 8 \
    --dataloader_drop_last \
    --overwrite_output_dir \
    --dataloader_num_workers 32 \
    --n_head_layers 2 \
    --skip_from 6 \
    --max_seq_length 128 \
    --train_path pretrain_data/msmarco/msmarco.json \
    --weight_decay 0.01 \
    --late_mlm \
    --cache_chunk_size 32
```
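For reference, the effective global batch size implied by this command is 2 nodes x 8 GPUs per node x 128 sequences per device x 1 accumulation step = 2048 sequences per update, which matches the bs2048 in the output directory name.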

At the end of training, the logs show:
{'loss': 13.2283, 'learning_rate': 1.5294203533362537e-05, 'epoch': 6.9}
{'loss': 13.2219, 'learning_rate': 1.0731493648707841e-05, 'epoch': 7.23}
{'loss': 13.2096, 'learning_rate': 6.168783764053146e-06, 'epoch': 7.56}
{'loss': 13.1718, 'learning_rate': 1.6060738793984525e-06, 'epoch': 7.88}


On the MS MARCO dev set:
MRR @10: 0.3741504639104922
QueriesRanked: 6980
recall@1: 0.251432664756447
recall@50: 0.6607449856733524
recall@all: 0.6607449856733524
#####################

@luyug could you please help check whether the pre-training data is the same as what you used?

A few things,

  • We try to group short sentences into longer spans.
  • The code was not designed with multi-node training in mind. Single node was assumed. I think you need to check and make sure the code still works as intended on two nodes.
  • We initialized with a fully trained Condenser, while you are using headless weights (--model_name_or_path Luyu/condenser).


I made some modifications to my process following your advice:
(1) merged adjacent spans to maxlen (128)
(2) ran the code with 8x3090 on one node
(3) downloaded the condenser.tar.gz with head weights from the server link, and changed --model_name_or_path to point to it.

However, the final performance is still obviously lower than that of your released coCondenser model:

#####################
MRR @10: 0.3765704734615908
QueriesRanked: 6980
recall@1: 0.25128939828080227
recall@50: 0.6656160458452722
recall@all: 0.6656160458452722
#####################

The pre-training log is here:

[INFO|trainer.py:358] 2022-04-01 21:50:57,905 >> Using amp fp16 backend
[INFO|trainer.py:791] 2022-04-01 21:50:59,341 >> ***** Running training *****
[INFO|trainer.py:792] 2022-04-01 21:50:59,341 >>   Num examples = 2993765
[INFO|trainer.py:793] 2022-04-01 21:50:59,341 >>   Num Epochs = 8
[INFO|trainer.py:794] 2022-04-01 21:50:59,341 >>   Instantaneous batch size per device = 256
[INFO|trainer.py:795] 2022-04-01 21:50:59,341 >>   Total train batch size (w. parallel, distributed & accumulation) = 2048
[INFO|trainer.py:796] 2022-04-01 21:50:59,341 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:797] 2022-04-01 21:50:59,341 >>   Total optimization steps = 11688
...
...
{'loss': 8.8922, 'learning_rate': 1.1293634496919918e-05, 'epoch': 7.19}
{'loss': 8.8489, 'learning_rate': 6.540421324815575e-06, 'epoch': 7.53}
{'loss': 8.8576, 'learning_rate': 1.7872081527112328e-06, 'epoch': 7.87}
100%|██████████| 11688/11688 [14:19:55<00:00,  4.41s/it]
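For anyone checking the setup, the step count is consistent with the reported data size: with --dataloader_drop_last, floor(2993765 examples / 2048 global batch size) = 1461 steps per epoch, and 1461 x 8 epochs = 11688 total optimization steps, matching the log above.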

The span-merging code is as follows:

```python
import json

import nltk

# `tokenizer` is the same BERT tokenizer as in the first snippet above.


def encode_one(line, maxlen=128):
    # Sentence-split and tokenize as before, then greedily merge adjacent
    # sentences into spans of at most `maxlen` tokens.
    spans = nltk.sent_tokenize(line.strip())
    tokenized = [
        tokenizer(
            s,
            add_special_tokens=False,
            truncation=False,
            return_attention_mask=False,
            return_token_type_ids=False,
        )["input_ids"] for s in spans
    ]
    tokenized_spans = []
    tokenized_span = []
    for span in tokenized:
        if len(span) > 0:
            # Start a new span when adding this sentence would exceed maxlen.
            if len(span) + len(tokenized_span) > maxlen:
                if len(tokenized_span) > 0:
                    tokenized_spans.append(tokenized_span)
                tokenized_span = []
            tokenized_span.extend(span)
    if len(tokenized_span) > 0:
        tokenized_spans.append(tokenized_span)
    # Documents that yield fewer than two spans are skipped.
    if len(tokenized_spans) < 2:
        return None
    return json.dumps({'spans': tokenized_spans})
```
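A quick sanity check of the merging (the sample text and the small maxlen are only for illustration):

```python
sample = (
    "MS MARCO is a collection of datasets for deep learning in search. "
    "The document ranking task uses a large web document corpus. "
    "Each document has a URL, a title, and a body. "
    "Passages are ranked by relevance to a query."
)
out = encode_one(sample, maxlen=32)  # small maxlen just to force multiple spans
if out is not None:
    spans = json.loads(out)['spans']
    print(len(spans), [len(s) for s in spans])  # number of spans and their token counts
```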

@luyug
Do you have time to look at this question? Thank you so much ~

Sorry for the delayed reply; too many things to deal with.

I am not super sure what is going on. Maybe it is just the stochastic nature of deep neural network training.

Unfortunately, I lost the original pre-processing script in one of our cluster upgrades, so I don't have the exact command to compare against. If you have the resources, I'd suggest trying one or some of the following:

  • Lower the span size to about 80, which I think aligns better with the actual MS MARCO passages.
  • Keep a 10 percent short-sentence probability (a rough sketch of both of these tweaks is shown after this list).
  • Train longer. Maybe your random seed just needs longer to converge.
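A minimal sketch of how both tweaks could sit on top of the merging loop above; the maxlen=80 and 10% values are just the suggestions from this thread, not the original pre-processing script:

```python
import random


def merge_spans(tokenized, maxlen=80, short_prob=0.1):
    """Greedy merge as above, but targeting ~80-token spans and, with 10%
    probability, keeping a sentence as its own short span instead of merging it."""
    spans, current = [], []
    for sent in tokenized:
        if not sent:
            continue
        if random.random() < short_prob:
            # Flush the current buffer, then keep this sentence as a short span.
            if current:
                spans.append(current)
                current = []
            spans.append(sent)
            continue
        if current and len(current) + len(sent) > maxlen:
            spans.append(current)
            current = []
        current.extend(sent)
    if current:
        spans.append(current)
    return spans if len(spans) >= 2 else None
```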