Readme section 5.2
mepster opened this issue · comments
The README outlines installation and a tutorial for use.
Section 5 says:
Visualization of DNABERT consists of 2 steps: Calculate attention scores and Plot.
However section 5.2 seems to be incomplete. It consists only of:
####5.2 Plotting tool
Do you have any further instructions how to use the plotting tool?
Thanks for making DNABERT available!
You can use BertViz as well for visualization. DNABERT is available on the HuggingFace platform; find it here: https://huggingface.co/zhihan1996. Load it using HuggingFace, then use BertViz for visualization of the attention weights.
I assume you can load it like this:
from transformers import AutoTokenizer, AutoModel
from bertviz import head_view, model_view
tokenizer = AutoTokenizer.from_pretrained('https://huggingface.co/zhihan1996/DNA_bert_6', do_lower_case=False)
model = AutoModel.from_pretrained('https://huggingface.co/zhihan1996/DNA_bert_6', output_attentions=True)
def call_html():
display(IPython.core.display.HTML('''
<script src="/static/components/requirejs/require.js"></script>
<script>
requirejs.config({
paths: {
base: '/static/base',
"d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
},
});
</script>
'''))
sequence = # your sequence
call_html()
inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']
attention = model(input_ids.to(device))[-1]
input_id_list = input_ids[0].tolist() # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)
model_view(attention, tokens)
I haven't run this code, however, I have written this code and it works with ESM and ProtTrans models, I assume it works with DNABert as well. Though there might be slight differences that you should handle, but this is the general way of doing it.
Thanks @Moeinh77 !
I would greatly appreciate your wise feedback on the following.
If I modify your code to this
%matplotlib widget
from bertviz import head_view, model_view
from transformers import AutoTokenizer, AutoModel
import IPython
# I downloaded the model to a local directory 'DNA_bert_6'
tokenizer = AutoTokenizer.from_pretrained('DNA_bert_6', do_lower_case=False)
model = AutoModel.from_pretrained('DNA_bert_6', output_attention=True)
def call_html():
display(IPython.core.display.HTML('''
<script src="/static/components/requirejs/require.js"></script>
<script>
requirejs.config({
paths: {
base: '/static/base',
"d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
},
});
</script>
'''))
sequence = 'CACAGCACAGCCCAGCCAAGCCAGGCCAGCCCAGCCCAGCCAAGCCACGCCACTCCACTACACTAGACTAGGCTAGGCTAGGCCAGGCCCGGCCCTGCCCTGCCCTGTCCTGTCCTGTCCTGTCCTGTCCTGTCCTGCCCTGCACTGCAGTGCAGCGCAGCCCAGCCCAGCCCCGCCCCCCCCCCTCCCCTGCCCTGTCCTGTACTGTAGTGTAGGGTAGGGTAGGGGAGGGGTGGGGTCGGGTCTGGTCTGGTCTGGTCTGGACTGGAATGGAACGGAACAGAACAGAACAGCACAGCCCAGCCAAGCCAGGCCAGGCCAGGACAGGAGAGGAGTGGAGTGGAGTGGAGTGGTGTGGTTTGGTTTGGTTTAGTTTAATTTAAGTTAAGATAAGAGAAGAGGAGAGGCGAGGCAAGGCAGGGCAGGGCAGGGCAGGGGAGGGGAGGGGAGGGGAGTGGAGTCGAGTCGAGTCGCGTCGCCTCGCCTCGCCTTGCCTTGCCTTGCCTTGCCTTGCCCTGCCCTGCCCTGCCCTGTCCTGTGCTGTGCTGTGCCGTGCCATGCCACGCCACACCACAC'
#print(sequence, len(sequence))
call_html()
inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']
device = 'cpu'
attention = model(input_ids.to(device))[-1]
print(attention)
input_id_list = input_ids[0].tolist() # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)
model_view(attention, tokens)
I get the error
============================================================
<class 'transformers.tokenization_bert.BertTokenizerFast'>
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_131706/782624532.py in <module>
7 # Also note I commented out "output_attention=True"
8 tokenizer = AutoTokenizer.from_pretrained('DNA_bert_6', do_lower_case=False)
----> 9 model = AutoModel.from_pretrained('DNA_bert_6', output_attention=True)
10
11 def call_html():
~/Repos/DNABERT/src/transformers/modeling_auto.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
382 for config_class, model_class in MODEL_MAPPING.items():
383 if isinstance(config, config_class):
--> 384 return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
385 raise ValueError(
386 "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
~/Repos/DNABERT/src/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
645
646 # Instantiate model.
--> 647 model = cls(config, *model_args, **model_kwargs)
648
649 if state_dict is None and not from_tf:
TypeError: __init__() got an unexpected keyword argument 'output_attention'
But if I comment out 'output_attention' like this:
model = AutoModel.from_pretrained('DNA_bert_6')#, output_attention=True)
then I get the error:
============================================================
<class 'transformers.tokenization_bert.BertTokenizerFast'>
tensor([[ 0.6328, -0.1791, 0.0956, -0.5731, -0.1763, 0.0412, 0.0398, -0.1986,
-0.1466, 0.0585, -0.3806, 0.3187, 0.2130, -0.2569, 0.0940, -0.2212,
-0.1783, 0.4327, -0.1394, 0.3384, 0.0865, -0.5700, -0.2146, -0.8339,
0.5354, 0.2046, -0.1657, -0.3046, 0.0650, -0.3055, -0.1878, 0.3082,
-0.3724, 0.1021, -0.1528, 0.0377, -0.0578, -0.0887, 0.0704, -0.1101,
-0.0773, -0.7024, 0.5353, 0.3198, 0.3903, -0.2056, -0.5153, 0.4037,
0.3320, 0.1946, 0.5429, -0.2637, 0.6125, -0.0789, 0.1112, 0.7812,
0.0152, -0.0238, -0.1098, -0.4792, -0.4241, -0.0332, -0.0680, 0.0209,
0.3960, -0.3283, 0.3168, -0.3727, -0.6130, 0.4901, 0.1921, 0.1362,
0.2909, 0.3878, 0.1141, 0.2536, 0.3613, 0.1088, 0.1635, 0.6388,
0.0469, -0.0086, 0.4082, 0.0579, -0.2604, 0.0479, -0.1125, -0.0638,
0.0445, -0.3106, 0.5073, 0.7865, -0.6308, -0.1794, -0.2932, 0.3073,
0.7342, -0.5972, -0.1660, -0.2717, 0.0208, 0.4985, 0.0841, 0.2110,
-0.1882, -0.5626, -0.0462, 0.3095, 0.3285, -0.3160, -0.6916, 0.4086,
-0.4803, 0.3591, 0.7099, -0.4982, -0.0836, 0.0289, -0.0527, 0.0656,
0.0248, -0.0470, -0.1063, -0.2031, 0.1532, -0.6646, -0.2786, -0.1763,
0.2845, 0.0204, 0.0968, 0.4840, -0.1410, 0.0664, 0.7044, -0.1921,
-0.0878, 0.7824, -0.0782, -0.3539, -0.4634, 0.1302, 0.0379, -0.3873,
0.4631, 0.6339, 0.3224, -0.3064, 0.7668, -0.6782, -0.7405, 0.2909,
0.3727, 0.0092, 0.4768, -0.3835, 0.0044, 0.3539, -0.2782, 0.0300,
-0.4242, -0.0869, 0.2858, 0.1832, -0.2757, -0.0408, 0.3069, 0.3939,
-0.2805, -0.3999, -0.6797, 0.5870, 0.3561, 0.1782, 0.0386, 0.2089,
-0.4678, 0.5330, -0.0612, -0.0711, -0.3962, 0.3030, 0.2728, 0.7759,
-0.2187, 0.4681, 0.8431, -0.1215, 0.0179, -0.3706, 0.2816, 0.3015,
0.4501, -0.5754, 0.0781, 0.1056, 0.2775, -0.5127, 0.2704, -0.0778,
-0.4359, -0.3127, -0.3191, -0.0699, 0.1088, 0.1312, 0.2286, -0.1017,
-0.4808, 0.0402, -0.4942, -0.0275, 0.3092, 0.3350, 0.5882, 0.3709,
0.3021, 0.5267, 0.0873, -0.1373, -0.3230, -0.4778, -0.1138, -0.2788,
0.3652, -0.1727, 0.4334, 0.3024, -0.5430, 0.1484, 0.3245, -0.1186,
0.6870, 0.4794, 0.2478, 0.2582, -0.3006, 0.1935, -0.6847, 0.2272,
-0.2336, -0.2726, 0.3207, -0.1575, -0.4818, -0.5784, -0.1170, -0.1105,
-0.2947, 0.0261, 0.0238, -0.2571, -0.1420, 0.1536, 0.0366, -0.2715,
0.2502, -0.4401, -0.3884, 0.4987, -0.1169, -0.3366, -0.4294, 0.1843,
0.0435, -0.0389, 0.1158, 0.0505, 0.5883, -0.4889, 0.3945, 0.6702,
-0.1314, 0.5126, 0.1850, 0.0217, 0.3293, 0.3679, 0.0524, 0.4328,
0.1084, -0.0995, -0.3061, -0.4815, 0.8196, -0.1747, -0.7363, -0.3690,
-0.4089, -0.3178, -0.3008, 0.0842, 0.5494, -0.3024, 0.0686, -0.0024,
0.5566, -0.4074, -0.1000, 0.5726, 0.2140, 0.3414, 0.2051, -0.5276,
-0.0900, -0.4059, 0.6995, -0.6243, 0.1423, 0.0790, -0.5684, -0.7503,
0.0812, 0.2904, -0.0825, -0.5300, 0.2102, 0.2130, 0.1844, -0.2488,
-0.3667, -0.2187, -0.0555, 0.1509, -0.5548, -0.1680, -0.1770, -0.3859,
0.2160, 0.6033, -0.0487, 0.0178, 0.2910, 0.6036, -0.5233, 0.0991,
-0.3996, -0.5462, 0.4345, -0.0706, -0.1149, 0.6331, -0.4878, -0.2699,
0.3667, -0.1658, -0.2681, -0.4240, -0.2277, -0.0285, -0.4568, -0.5831,
0.2132, -0.4749, -0.3003, 0.2729, 0.2058, 0.2281, -0.2277, 0.4400,
0.3480, 0.4202, -0.1285, -0.0944, 0.3023, 0.1910, -0.0172, -0.3824,
0.2430, 0.6200, -0.2266, -0.2784, 0.4599, 0.2767, -0.4933, -0.0863,
0.6689, 0.5872, 0.1265, -0.2042, 0.0360, 0.5149, -0.3186, -0.2453,
-0.8713, 0.5262, 0.4512, -0.4293, 0.2160, -0.2895, 0.5623, -0.0468,
0.3609, 0.0478, 0.5212, 0.3534, 0.6020, -0.5200, -0.0011, 0.1896,
0.1710, -0.4412, 0.3991, 0.3104, -0.1117, 0.1043, 0.2124, -0.6773,
0.3843, -0.1273, 0.6962, 0.5702, -0.2138, 0.4218, -0.1075, -0.1051,
0.7056, 0.5510, -0.2203, -0.4441, -0.1577, 0.0099, 0.7935, 0.0558,
-0.3306, 0.4195, 0.2848, -0.0217, 0.3778, 0.1274, 0.0162, -0.5062,
-0.1455, -0.1946, -0.3807, -0.5371, -0.1086, -0.3016, -0.4707, -0.4386,
-0.2609, 0.2080, 0.5916, -0.1465, 0.6091, -0.3506, 0.5574, 0.1640,
-0.0207, -0.0840, -0.0200, -0.5686, 0.4613, 0.0057, -0.1011, -0.2511,
-0.6345, 0.4575, 0.0240, 0.4279, -0.5957, 0.0820, -0.5342, -0.3385,
0.2554, 0.3760, -0.2332, 0.3575, -0.6959, 0.3710, 0.3159, -0.5322,
0.7105, 0.2933, 0.1016, 0.2927, 0.1187, -0.2822, -0.2821, -0.1209,
0.0618, 0.2846, 0.0760, -0.1153, -0.0721, 0.3869, 0.1510, 0.0321,
0.7199, 0.3929, 0.2044, 0.3253, -0.1251, 0.3145, 0.2683, 0.3081,
0.1782, -0.2001, 0.1666, -0.0118, 0.1197, 0.3457, 0.7933, 0.1135,
-0.2071, 0.0576, -0.3225, -0.1244, 0.3777, 0.2173, -0.1528, 0.2772,
0.1458, -0.0610, -0.7065, 0.2422, -0.3388, -0.1959, -0.0714, 0.3345,
0.5365, 0.1304, -0.2894, 0.4952, -0.3224, 0.5075, -0.7477, -0.3850,
0.0277, -0.1673, 0.0037, -0.8024, -0.2407, -0.3995, 0.3362, -0.2044,
0.4753, 0.2201, 0.3023, 0.4583, 0.5867, -0.5060, -0.1469, 0.4736,
0.4303, -0.1274, 0.3662, 0.4385, -0.2217, 0.0379, 0.4506, 0.0816,
0.1533, 0.0871, -0.2132, -0.5684, -0.0166, 0.6519, -0.1739, 0.2253,
0.0837, 0.4929, 0.0724, 0.2116, -0.2263, -0.2795, 0.2216, 0.0188,
-0.2020, -0.7023, -0.2023, -0.2113, -0.0123, 0.1137, -0.1966, -0.1871,
0.0188, -0.0423, 0.2619, -0.3462, 0.4886, -0.0335, -0.1431, -0.6855,
0.3971, 0.0438, -0.6949, 0.4389, -0.4733, 0.4454, -0.3445, 0.0646,
-0.2358, -0.2089, 0.7600, -0.4973, -0.4342, 0.0079, 0.1723, 0.0921,
-0.0841, 0.5822, 0.2439, -0.4994, -0.1663, -0.4547, -0.1404, 0.0726,
-0.0211, 0.3611, 0.0262, -0.5103, -0.1072, -0.4839, 0.5536, -0.7744,
0.4779, -0.1998, 0.4857, 0.3890, 0.2740, 0.0436, 0.0469, 0.1837,
-0.0826, 0.3785, 0.0765, 0.2037, 0.4014, 0.2767, 0.1950, 0.0096,
-0.2013, 0.3759, -0.5088, 0.1774, -0.2251, -0.8234, -0.4353, -0.1012,
0.4806, -0.4940, 0.0107, -0.0193, 0.1832, 0.3600, 0.1512, -0.1332,
0.2977, -0.1987, 0.0201, -0.0363, -0.1626, 0.3619, -0.7445, 0.1957,
-0.3357, -0.1715, -0.4738, 0.0771, -0.1537, -0.4313, -0.0137, -0.0934,
0.4966, 0.7250, -0.5224, 0.0362, 0.3701, -0.0547, -0.0383, -0.4689,
-0.6313, 0.1891, -0.4483, 0.2346, -0.0071, 0.2558, 0.5207, -0.6604,
0.1391, -0.3340, 0.7140, -0.7954, 0.1538, 0.1664, 0.4493, -0.3938,
0.5326, 0.8514, -0.6008, 0.5339, 0.2108, 0.1381, 0.2696, -0.2429,
0.1073, 0.4450, 0.4119, -0.3839, -0.0521, -0.0234, -0.2992, -0.5740,
-0.4005, -0.1870, 0.5233, 0.6859, 0.4890, -0.1400, 0.2023, 0.3482,
-0.2123, 0.0724, -0.3865, 0.7546, -0.1868, -0.5763, -0.1203, -0.1293,
0.5864, -0.0392, 0.2688, 0.0660, 0.1680, 0.3938, 0.2407, 0.0487,
0.3158, -0.0237, -0.3317, -0.0711, -0.5653, 0.2615, 0.0772, -0.0171,
0.2291, 0.3454, 0.0991, -0.5713, -0.1992, -0.0625, -0.3561, 0.3915,
-0.3596, 0.2565, 0.0120, -0.2206, 0.4874, 0.5524, -0.2768, 0.3071,
0.4948, 0.1323, -0.4300, 0.1927, 0.2829, -0.0604, -0.2676, -0.0276,
0.3177, 0.0729, -0.6886, -0.0284, 0.3048, 0.0920, 0.2971, -0.0899]],
grad_fn=<TanhBackward0>)
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
/tmp/ipykernel_131706/694873838.py in <module>
35 input_id_list = input_ids[0].tolist() # Batch index 0
36 tokens = tokenizer.convert_ids_to_tokens(input_id_list)
---> 37 model_view(attention, tokens)
~/anaconda3/envs/dnabert37/lib/python3.7/site-packages/bertviz/model_view.py in model_view(attention, tokens, sentence_b_start, prettify_tokens, display_mode, encoder_attention, decoder_attention, cross_attention, encoder_tokens, decoder_tokens, include_layers, include_heads, html_action)
60 raise ValueError("If you specify 'attention' you may not specify any encoder-decoder arguments. This"
61 " argument is only for self-attention models.")
---> 62 n_heads = num_heads(attention)
63 if include_layers is None:
64 include_layers = list(range(num_layers(attention)))
~/anaconda3/envs/dnabert37/lib/python3.7/site-packages/bertviz/util.py in num_heads(attention)
24
25 def num_heads(attention):
---> 26 return attention[0][0].size(0)
27
28
IndexError: Dimension specified as 0 but tensor has no dimensions
(Note, it is not raising that ValueError on line 60 of model_view.py. The error is on line 62.)
And at that point I am stuck. Any ideas? Thanks a lot!
Hi @mepster I don't specifically know why this error is raised but I suggest trying BertModel and BertTokenizer instead of the AutoModel and AutoTokenizer. When you comment out the output_attention then there is no attention for visualizing. So avoid commenting it out and try the BertModel hopefully it will fix it up.