jerryji1993 / DNABERT

DNABERT: pre-trained Bidirectional Encoder Representations from Transformers model for DNA-language in genome

Home page: https://doi.org/10.1093/bioinformatics/btab083



Attention scores out of memory

flpgrz opened this issue

Hi, I'm trying to execute step 5 to extract attention scores.

I use the following settings:

    --max_seq_length 1536
    --per_gpu_pred_batch_size=64
    --fp16

The dataset contains 89,848 samples. I get the following out-of-memory error when the attention_scores NumPy array is allocated:

    02/08/2022 09:39:01 - INFO - __main__ -   Saving features into cached file ./cached_dev_ft_0_1536_dnaprom
    02/08/2022 09:41:15 - INFO - __main__ -   ***** Running prediction  *****
    02/08/2022 09:41:15 - INFO - __main__ -     Num examples = 89848
    02/08/2022 09:41:15 - INFO - __main__ -     Batch size = 64
    Traceback (most recent call last):
      File "run_finetune.py", line 1284, in <module>
        main()
      File "run_finetune.py", line 1191, in main
        attention_scores, probs = visualize(args, model, tokenizer, prefix=prefix, kmer=kmer)
      File "run_finetune.py", line 612, in visualize
        attention_scores = np.zeros([len(pred_dataset), 12, args.max_seq_length, args.max_seq_length])
    MemoryError: Unable to allocate 18.5 TiB for an array with shape (89848, 12, 1536, 1536) and data type float64

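The size in the error follows directly from the array's shape: it depends only on the number of examples, the 12 attention heads, and max_seq_length squared. Neither --per_gpu_pred_batch_size nor --fp16 enters the shape, and np.zeros defaults to float64 (as the error message confirms). The sketch below uses a hypothetical helper, buffer_bytes (not part of DNABERT), to compute the allocation without performing it, reproducing the 18.5 TiB figure and showing what smaller dtypes would need.

```python
import numpy as np

def buffer_bytes(shape, dtype):
    """Bytes that np.zeros(shape, dtype) would request, computed without allocating."""
    # int64 product avoids overflow for very large shapes
    return int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize

# Shape from the traceback: (num_examples, num_heads, max_seq_length, max_seq_length)
shape = (89848, 12, 1536, 1536)
TIB = 2 ** 40  # bytes per tebibyte

for dtype in (np.float64, np.float32, np.float16):
    print(f"{np.dtype(dtype).name}: {buffer_bytes(shape, dtype) / TIB:.1f} TiB")
```

Halving the precision only halves the size, so even float16 still needs several TiB. A fix has to avoid materializing the full (N, 12, L, L) array rather than just shrink its dtype.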

Any ideas on how to solve this? Thanks.
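One possible workaround, sketched under an assumption not confirmed here: if the downstream scoring only needs a per-token vector rather than every head's full L×L map, each prediction batch can be reduced as soon as it is produced (for example after converting the model's attention output with .cpu().numpy()), keeping only the reduced scores. The function name reduce_attention_streaming and the specific reduction (attention from position 0, the [CLS] token, summed over heads) are hypothetical. Check them against what visualize in run_finetune.py actually derives from attention_scores before relying on this.

```python
import numpy as np

def reduce_attention_streaming(batches, num_examples, max_seq_length):
    """Hypothetical: keep one score per token instead of the full (heads, L, L) map.

    Assumed reduction: the score for token j is the attention that position 0
    ([CLS]) pays to j, summed over all heads.
    """
    # Only (num_examples, max_seq_length) float32 is kept in memory.
    scores = np.zeros((num_examples, max_seq_length), dtype=np.float32)
    start = 0
    for attention in batches:
        # attention: (batch, heads, L, L) NumPy array for one prediction batch
        batch_scores = attention[:, :, 0, :].sum(axis=1)  # -> (batch, L)
        scores[start:start + len(batch_scores)] = batch_scores
        start += len(batch_scores)
    if start != num_examples:
        raise ValueError(f"expected {num_examples} examples, got {start}")
    return scores

# Demo with small random "attention" batches standing in for model output
rng = np.random.default_rng(0)
fake_batches = (rng.random((4, 12, 16, 16), dtype=np.float32) for _ in range(3))
scores = reduce_attention_streaming(fake_batches, num_examples=12, max_seq_length=16)
print(scores.shape)
```

With the numbers in this issue, the kept array would be 89,848 × 1,536 float32 values, about 0.55 GB, instead of 18.5 TiB.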