Java API runs slower than Python
benlihua opened this issue · comments
I compared Java and Python performance on CPU with MKL enabled. Java (20 ms) takes about twice as long as Python (10 ms).
BERT parameters:
batch_size=1, seq_length=24, num_hidden_layers=6, num_attention_heads=12
ENV variables:
export OMP_NUM_THREADS=8
export KMP_BLOCKTIME=0
export KMP_AFFINITY=granularity=fine,verbose,compact,1,0
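To rule out a difference in threading configuration, it may help to confirm that each process actually sees these variables. Below is a minimal sketch in Python; the helper name `threading_env` is only illustrative and is not part of cuBERT. The same check on the Java side could be done with `System.getenv`.

```python
import os

# The three variables listed above. The helper name is hypothetical.
THREADING_VARS = ("OMP_NUM_THREADS", "KMP_BLOCKTIME", "KMP_AFFINITY")


def threading_env(environ=None):
    """Return the threading-related variables as seen by this process."""
    env = os.environ if environ is None else environ
    # Report a placeholder instead of raising when a variable is missing.
    return {name: env.get(name, "<unset>") for name in THREADING_VARS}


if __name__ == "__main__":
    for name, value in threading_env().items():
        print(f"{name}={value}")
```

If any value prints as `<unset>`, the exports did not reach that process (for example, because it was launched from a different shell or by an IDE).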
Then I tested the Java Native Interface (JNI) instead of JNA and saw the same behavior. Could something in cuBERT.so (maybe the MKL configuration) be making Java run slower than Python?
I think I found the problem. With MKL_VERBOSE=1 set, Java prints "MKL_VERBOSE Intel(R) MKL 2019.0 Update 3 Product build 20190220.." while Python prints "MKL_VERBOSE oneMKL 2021.0 Update 4 Product build 20210904...". So the two runtimes load different MKL versions. Could the older MKL version be the cause of the worse performance?
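To make this comparison repeatable, the banner lines can be parsed and compared programmatically. This is only a sketch: the regular expression is inferred from the two banner lines quoted above, so other MKL releases may format the banner differently, and `parse_mkl_banner` is a hypothetical helper name.

```python
import re

# Pattern inferred from the two banners quoted above (an assumption,
# not an official format description).
BANNER = re.compile(
    r"MKL_VERBOSE\s+(?P<product>.+?)\s+(?P<version>\d{4}\.\d+)"
    r"\s+Update\s+(?P<update>\d+)\s+Product build\s+(?P<build>\d{8})"
)


def parse_mkl_banner(line):
    """Return (product, version, update, build), or None if no match."""
    m = BANNER.search(line)
    if m is None:
        return None
    return (m["product"], m["version"], int(m["update"]), m["build"])


# Banner lines copied from the Java and Python runs.
java_banner = "MKL_VERBOSE Intel(R) MKL 2019.0 Update 3 Product build 20190220.."
python_banner = "MKL_VERBOSE oneMKL 2021.0 Update 4 Product build 20210904..."

java_mkl = parse_mkl_banner(java_banner)
python_mkl = parse_mkl_banner(python_banner)
print("Java:  ", java_mkl)
print("Python:", python_mkl)
print("Same MKL build:", java_mkl == python_mkl)
```

A mismatch here only shows that the two processes load different MKL builds. Whether that explains the latency gap would still need to be confirmed by running both against the same MKL version.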