
# xla_fp8_testing

## Run the benchmark

### With fp32 precision

```sh
python benchmark.py
```
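
For orientation, here is a minimal sketch of what a benchmark script along these lines could look like: a small Keras model compiled with XLA (`jit_compile=True`) and timed over a few iterations. The model shape, iteration count, and flag handling below are assumptions for illustration, not the actual contents of `benchmark.py`.

```python
# Hypothetical sketch of a benchmark harness; the real benchmark.py may differ.
import argparse
import time

import tensorflow as tf

parser = argparse.ArgumentParser()
parser.add_argument("--fp8", action="store_true")    # FP8 casting: see the fp8 section below.
parser.add_argument("--mixed", action="store_true")
args = parser.parse_args()

if args.mixed:
    # Compute in float16 while keeping variables in float32.
    tf.keras.mixed_precision.set_global_policy("mixed_float16")

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4096, activation="relu"),
    tf.keras.layers.Dense(4096),
])

@tf.function(jit_compile=True)  # Compile with XLA so its GEMM rewriter can run.
def step(x):
    return model(x)

x = tf.random.normal([8192, 4096])
step(x)  # Warm-up: triggers tracing and XLA compilation.

start = time.time()
for _ in range(10):
    out = step(x)
out.numpy()  # Force execution to finish before reading the clock.
print(f"avg step time: {(time.time() - start) / 10:.4f} s")
```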

### With fp32 + fp8 precision

```sh
TF_CPP_VMODULE=gemm_rewriter=1 TF_XLA_FLAGS="--tf_xla_auto_jit=2" \
    XLA_FLAGS="--xla_gpu_enable_cublaslt=true" python benchmark.py --fp8
```
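
Here `TF_CPP_VMODULE=gemm_rewriter=1` turns on VLOG output from XLA's GEMM rewriter, so the logs show when a matmul gets rewritten to a cuBLASLt call; `--tf_xla_auto_jit=2` enables XLA auto-clustering, and `--xla_gpu_enable_cublaslt=true` lets XLA lower GEMMs to cuBLASLt. One common way to exercise the FP8 path, sketched below as an assumption about what `benchmark.py` might do rather than a description of it, is to cast operands to a float8 dtype and back around a matmul inside an XLA-compiled function, which is the pattern the rewriter can fuse into a single FP8 GEMM (the `tf.float8_e4m3fn` dtype requires TensorFlow 2.11+):

```python
# Sketch (assumption): an FP8 cast pattern that XLA's gemm_rewriter can fuse
# into a single cuBLASLt FP8 GEMM on supported (Hopper) GPUs.
import tensorflow as tf

@tf.function(jit_compile=True)
def fp8_matmul(x, w):
    # Cast operands to FP8 (E4M3) and immediately back to the compute dtype.
    # Inside an XLA-compiled function this convert(f8) -> dot pattern can be
    # rewritten into an FP8 GEMM.
    x8 = tf.cast(tf.cast(x, tf.float8_e4m3fn), tf.float32)
    w8 = tf.cast(tf.cast(w, tf.float8_e4m3fn), tf.float32)
    return tf.matmul(x8, w8)

x = tf.random.normal([16, 32])
w = tf.random.normal([32, 64])
print(fp8_matmul(x, w).shape)  # (16, 64)
```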

### With mixed precision

```sh
python benchmark.py --mixed
```
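
Presumably (an assumption, not confirmed by the source) the `--mixed` flag switches on the standard Keras mixed-precision policy, under which layers compute in float16 while variables stay in float32:

```python
# Assumed effect of --mixed: the standard Keras mixed-precision policy.
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")
print(tf.keras.mixed_precision.global_policy())  # compute: float16, variables: float32
```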

### With mixed + fp8 precision

```sh
TF_CPP_VMODULE=gemm_rewriter=1 TF_XLA_FLAGS="--tf_xla_auto_jit=2" \
    XLA_FLAGS="--xla_gpu_enable_cublaslt=true" python benchmark.py --fp8 --mixed
```

## Dump HLO graphs

```sh
TF_DUMP_GRAPH_PREFIX=/tmp/generated \
TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
XLA_FLAGS="--xla_dump_hlo_as_dot --xla_dump_to=/tmp/generated \
--xla_gpu_enable_cublaslt=true --xla_gpu_simplify_all_fp_conversions=false \
--xla_dump_hlo_pass_re=.*" python benchmark.py --fp8 --mixed
```
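
With `--xla_dump_hlo_pass_re=.*`, XLA writes the HLO module after every pass into `/tmp/generated`, alongside the `.dot` graphs requested by `--xla_dump_hlo_as_dot`. One quick sanity check is to scan the dumps for the cuBLASLt FP8 custom call; a small sketch, where both the file glob and the call-target string are assumptions that may vary across XLA versions:

```python
# Sketch: scan the dumped HLO text files for the cuBLASLt FP8 custom call.
# The "*.txt" glob and the "__cublas$lt$matmul$f8" target string are
# assumptions and may differ across XLA versions.
import glob

for path in sorted(glob.glob("/tmp/generated/*.txt")):
    with open(path) as f:
        if "__cublas$lt$matmul$f8" in f.read():
            print(f"FP8 GEMM custom call found in: {path}")
```

The `.dot` graphs can be rendered with Graphviz, e.g. `dot -Tsvg <file>.dot -o graph.svg`.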
