jerryji1993 / DNABERT

DNABERT: pre-trained Bidirectional Encoder Representations from Transformers model for DNA-language in genome

Home Page: https://doi.org/10.1093/bioinformatics/btab083

Readme section 5.2

mepster opened this issue

The README outlines installation and includes a usage tutorial.

Section 5 says:

Visualiazation of DNABERT consists of 2 steps. Calcualate attention scores and Plot.

However, section 5.2 seems to be incomplete. It consists only of:

####5.2 Plotting tool

Do you have any further instructions on how to use the plotting tool?

Thanks for making DNABERT available!

P.S. The following generates one plot for me. Not sure I know what it is. :-)

pip install matplotlib
pip install seaborn

cd data_process_template

export KMER=6
export MODEL_PATH=../ft/$KMER

python ../visualize.py --model_path $MODEL_PATH --kmer $KMER

[generated plot attached]
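To go a bit further, here is a minimal plotting sketch of what section 5.2 might have looked like. It assumes the attention-score step (section 5.1) saved a per-position score array with NumPy, e.g. as atten.npy; the file name, array shape, and output path are assumptions rather than the repository's documented interface.

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Hypothetical input: a 1-D array of per-position attention scores
# produced by the attention-calculation step (file name assumed).
scores = np.load('atten.npy')
scores = np.atleast_2d(scores)  # seaborn heatmaps expect a 2-D array

plt.figure(figsize=(12, 2))
sns.heatmap(scores, cmap='YlGnBu', cbar=True, yticklabels=False)
plt.xlabel('Position in sequence')
plt.title('DNABERT attention scores')
plt.tight_layout()
plt.savefig('atten_heatmap.png', dpi=150)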

You can use BertViz as well for visualization. DNABERT is available on the HuggingFace platform; find it here: https://huggingface.co/zhihan1996. Load it with HuggingFace, then use BertViz to visualize the attention weights.
I assume you can load it like this:

import IPython
from transformers import AutoTokenizer, AutoModel
from bertviz import head_view, model_view

# Load the checkpoint by its hub id rather than the full URL.
tokenizer = AutoTokenizer.from_pretrained('zhihan1996/DNA_bert_6', do_lower_case=False)
model = AutoModel.from_pretrained('zhihan1996/DNA_bert_6', output_attentions=True)

def call_html():
    display(IPython.core.display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))

sequence = ''  # paste your DNA sequence here
call_html()
inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']
device = 'cpu'  # or 'cuda' if a GPU is available; must match the model's device
attention = model(input_ids.to(device))[-1]  # last element is the tuple of per-layer attention tensors

input_id_list = input_ids[0].tolist() # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list) 
model_view(attention, tokens)

I haven't run this code with DNABERT, but I wrote it for ESM and ProtTrans models, where it works, and I assume it will work with DNABERT as well. There might be slight differences you will need to handle, but this is the general way of doing it.
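One caveat I have not verified against this exact checkpoint: DNABERT's vocabulary is built from k-mers, so the tokenizer generally expects the sequence already split into overlapping 6-mers separated by spaces rather than a raw string. A small helper like this (the function name is my own) can do the conversion before encode_plus:

# Convert a raw DNA string into overlapping k-mers separated by spaces,
# the input format the DNA_bert_6 vocabulary is built around.
def seq_to_kmers(seq, k=6):
    return ' '.join(seq[i:i + k] for i in range(len(seq) - k + 1))

# Example: seq_to_kmers('CACAGCAC') -> 'CACAGC ACAGCA CAGCAC'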

Thanks @Moeinh77 !

I would greatly appreciate your wise feedback on the following.

If I modify your code to this

%matplotlib widget
from bertviz import head_view, model_view
from transformers import AutoTokenizer, AutoModel
import IPython

# I downloaded the model to a local directory 'DNA_bert_6'
tokenizer = AutoTokenizer.from_pretrained('DNA_bert_6', do_lower_case=False)
model = AutoModel.from_pretrained('DNA_bert_6', output_attention=True) 

def call_html():
    display(IPython.core.display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))

sequence = 'CACAGCACAGCCCAGCCAAGCCAGGCCAGCCCAGCCCAGCCAAGCCACGCCACTCCACTACACTAGACTAGGCTAGGCTAGGCCAGGCCCGGCCCTGCCCTGCCCTGTCCTGTCCTGTCCTGTCCTGTCCTGTCCTGCCCTGCACTGCAGTGCAGCGCAGCCCAGCCCAGCCCCGCCCCCCCCCCTCCCCTGCCCTGTCCTGTACTGTAGTGTAGGGTAGGGTAGGGGAGGGGTGGGGTCGGGTCTGGTCTGGTCTGGTCTGGACTGGAATGGAACGGAACAGAACAGAACAGCACAGCCCAGCCAAGCCAGGCCAGGCCAGGACAGGAGAGGAGTGGAGTGGAGTGGAGTGGTGTGGTTTGGTTTGGTTTAGTTTAATTTAAGTTAAGATAAGAGAAGAGGAGAGGCGAGGCAAGGCAGGGCAGGGCAGGGCAGGGGAGGGGAGGGGAGGGGAGTGGAGTCGAGTCGAGTCGCGTCGCCTCGCCTCGCCTTGCCTTGCCTTGCCTTGCCTTGCCCTGCCCTGCCCTGCCCTGTCCTGTGCTGTGCTGTGCCGTGCCATGCCACGCCACACCACAC'
#print(sequence, len(sequence))

call_html()
inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']
device = 'cpu'
attention = model(input_ids.to(device))[-1]
print(attention)

input_id_list = input_ids[0].tolist() # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)
model_view(attention, tokens)

I get the error

============================================================
<class 'transformers.tokenization_bert.BertTokenizerFast'>

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_131706/782624532.py in <module>
      7 # Also note I commented out "output_attention=True"
      8 tokenizer = AutoTokenizer.from_pretrained('DNA_bert_6', do_lower_case=False)
----> 9 model = AutoModel.from_pretrained('DNA_bert_6', output_attention=True)
     10 
     11 def call_html():

~/Repos/DNABERT/src/transformers/modeling_auto.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    382         for config_class, model_class in MODEL_MAPPING.items():
    383             if isinstance(config, config_class):
--> 384                 return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
    385         raise ValueError(
    386             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"

~/Repos/DNABERT/src/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    645 
    646         # Instantiate model.
--> 647         model = cls(config, *model_args, **model_kwargs)
    648 
    649         if state_dict is None and not from_tf:

TypeError: __init__() got an unexpected keyword argument 'output_attention'

But if I comment out 'output_attention' like this:

model = AutoModel.from_pretrained('DNA_bert_6')#, output_attention=True) 

then I get the error:

============================================================
<class 'transformers.tokenization_bert.BertTokenizerFast'>

tensor([[ 0.6328, -0.1791,  0.0956, -0.5731, -0.1763,  0.0412,  0.0398, -0.1986,
         -0.1466,  0.0585, -0.3806,  0.3187,  0.2130, -0.2569,  0.0940, -0.2212,
          ... (remaining values of the 1x768 tensor omitted) ...
          0.3177,  0.0729, -0.6886, -0.0284,  0.3048,  0.0920,  0.2971, -0.0899]],
       grad_fn=<TanhBackward0>)

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipykernel_131706/694873838.py in <module>
     35 input_id_list = input_ids[0].tolist() # Batch index 0
     36 tokens = tokenizer.convert_ids_to_tokens(input_id_list)
---> 37 model_view(attention, tokens)

~/anaconda3/envs/dnabert37/lib/python3.7/site-packages/bertviz/model_view.py in model_view(attention, tokens, sentence_b_start, prettify_tokens, display_mode, encoder_attention, decoder_attention, cross_attention, encoder_tokens, decoder_tokens, include_layers, include_heads, html_action)
     60             raise ValueError("If you specify 'attention' you may not specify any encoder-decoder arguments. This"
     61                              " argument is only for self-attention models.")
---> 62         n_heads = num_heads(attention)
     63         if include_layers is None:
     64             include_layers = list(range(num_layers(attention)))

~/anaconda3/envs/dnabert37/lib/python3.7/site-packages/bertviz/util.py in num_heads(attention)
     24 
     25 def num_heads(attention):
---> 26     return attention[0][0].size(0)
     27 
     28 

IndexError: Dimension specified as 0 but tensor has no dimensions


(Note: it is not raising the ValueError on line 60 of model_view.py; the error comes from line 62.)

And at that point I am stuck. Any ideas? Thanks a lot!

Hi @mepster, I don't know specifically why this error is raised, but I suggest trying BertModel and BertTokenizer instead of AutoModel and AutoTokenizer. Also, when you comment out output_attention there are no attention weights to visualize, so avoid commenting it out and try BertModel; hopefully that will fix it.
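For completeness, a minimal sketch of that suggestion (untested with this checkpoint). Note that the working snippet earlier in the thread spelled the keyword output_attentions, plural, while the TypeError above complains about output_attention; without that flag, model(...)[-1] is the pooler output, which matches the 1x768 tensor printed above and the "tensor has no dimensions" IndexError from model_view.

from transformers import BertTokenizer, BertModel
from bertviz import model_view

# Load the local checkpoint with BertModel/BertTokenizer as suggested above,
# keeping output_attentions=True (plural) so the forward pass returns attentions.
tokenizer = BertTokenizer.from_pretrained('DNA_bert_6', do_lower_case=False)
model = BertModel.from_pretrained('DNA_bert_6', output_attentions=True)

sequence = 'CACAGC ACAGCA CAGCAC'  # hypothetical input, pre-split into 6-mers
inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']

# With output_attentions=True, the last element of the output tuple is a
# per-layer tuple of attention tensors, which is what BertViz expects.
attention = model(input_ids)[-1]

tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
model_view(attention, tokens)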