dmlc / gluon-nlp

NLP made easy

Home Page: https://nlp.gluon.ai/

[TVM Integration] Support TVM conversion with FP16 data type

sxjscience opened this issue

Description

We recently fixed the TVM integration in GluonNLP for the fp32 dtype. However, fp16 is still not supported. We should:

  • Revise the test to add FP16 (see the first sketch below the list):
    @pytest.mark.serial
    @pytest.mark.seed(123)
    @pytest.mark.parametrize('model_name',
                             ['google_albert_base_v2',
                              'google_en_cased_bert_base',
                              'google_electra_small',
                              'fairseq_bart_base'])
    @pytest.mark.parametrize('batch_size,seq_length', [(2, 4), (1, 4)])
    @pytest.mark.parametrize('layout', ['NT', 'TN'])
    @pytest.mark.skipif(not tvm_enabled(),
                        reason='TVM is not supported. So this test is skipped.')
    # @pytest.mark.skip('TVM issue https://github.com/dmlc/gluon-nlp/issues/1425.')
    def test_tvm_integration(model_name, batch_size, seq_length, layout, ctx):
  • Revise the benchmark to add TVM + FP16 (see the second sketch below the list):
    def compile_tvm_graph_runtime(model, model_name, layout, compute_layout,
                                  batch_size, seq_length, dtype, instance_type):
        key = (model_name, layout, compute_layout, batch_size, seq_length, dtype, instance_type)
        if key in _TVM_RT_CACHE:
            return _TVM_RT_CACHE[key]
        tvm = try_import_tvm()
        from tvm import relay
        from tvm.contrib import graph_runtime
        from gluonnlp.utils.tvm_utils import get_ec2_tvm_flags, update_tvm_convert_map
        flags = get_ec2_tvm_flags()[instance_type]
        update_tvm_convert_map()
        token_ids_shape = (batch_size, seq_length) if layout == 'NT' else (seq_length, batch_size)
        valid_length_shape = (batch_size,)
        if 'bart' in model_name:
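
A minimal sketch of how an FP16 axis could be added to the test above. The standalone test name, the reduced parameter grid, and the cast/tolerance comments are assumptions about what the change might look like, not the actual patch:

    import pytest

    @pytest.mark.parametrize('model_name', ['google_electra_small'])
    @pytest.mark.parametrize('batch_size,seq_length', [(1, 4)])
    @pytest.mark.parametrize('layout', ['NT', 'TN'])
    @pytest.mark.parametrize('dtype', ['float32', 'float16'])  # new dtype axis
    def test_tvm_integration_dtype(model_name, batch_size, seq_length, layout, dtype):
        # Sketch only: the real test would load the backbone, cast it with
        # model.cast(dtype) before export, pass dtype into the TVM conversion
        # helper, and compare outputs against MXNet with a looser tolerance
        # for float16.
        ...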
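
For the benchmark helper, dtype already appears in the cache key; below is a minimal sketch of forwarding it to the Relay importer. The function name, the input names (data0/data1), the target string, and the opt level are placeholders rather than the helper's real values:

    import tvm
    from tvm import relay

    def build_tvm_lib(sym, arg_params, aux_params, token_ids_shape,
                      valid_length_shape, dtype='float16', target='llvm'):
        # Assumption: the exported MXNet symbol takes two inputs named data0/data1.
        shape_dict = {'data0': token_ids_shape, 'data1': valid_length_shape}
        # relay.frontend.from_mxnet accepts a dtype argument, so the converted
        # graph and its parameters can be created directly in float16.
        mod, params = relay.frontend.from_mxnet(sym, shape_dict, dtype=dtype,
                                                arg_params=arg_params,
                                                aux_params=aux_params)
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build(mod, target=target, params=params)
        return lib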