NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.

Home Page: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html

Replacing nn.Linear w/ te.Linear FP8 convergence issue

viclzhu opened this issue · comments

Hi,

I'm seeing higher losses using te.Linear in place of nn.Linear directly in transformer models such as Llama, which I assume is expected due to the nature of FP8.

However, I don't see a loss increase of the same magnitude when replacing the transformer layer with te.TransformerLayer. I took a look, and among the various optimizations te.TransformerLayer implements, one appears to be reusing a single FP8 scaling factor across multiple GEMM invocations within the fused layernorm_mlp. Was this implemented to reduce such cascading precision losses, and are there other notable changes that mitigate them?

I'm seeing that in sufficiently large models (e.g. Llama 70B), this loss difference causes a divergence in validation loss, leading to convergence issues, whereas without FP8 the validation loss continues to decrease.

If you have any insight into this, or pointers to references discussing this problem and its solutions in more depth, that would be greatly appreciated.

Thanks!

Hi @viclzhu. There is no recipe difference between using te.Linear alone and using te.TransformerLayer, so different behavior in the loss curves is not expected. This suggests to me that your problem comes from incorrect usage somehow. Could you provide some more information on the software you are using and the specifics of the training (for example, are you seeing the higher loss from the very beginning or only after some number of iterations)? One of the typical integration issues we have seen is not passing the correct distributed group as fp8_group when running under the fp8_autocast context manager.

In general, we do not expect to see higher losses when using FP8 compared to BF16; the loss curves should overlap.
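
For reference, a minimal sketch of the usage pattern mentioned above, with the amax-reduction group passed explicitly as fp8_group; the layer shapes and bfloat16 dtype here are illustrative assumptions, and the script assumes torch.distributed has already been initialized (e.g. via torchrun):

# Sketch: passing the correct process group to fp8_autocast.
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

# Delayed-scaling recipe: E4M3 forward, E5M2 backward.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_compute_algo="max")

# Under pure data parallelism the default (WORLD) group is the group to reduce amaxes over.
fp8_group = dist.distributed_c10d._get_default_group()

layer = te.Linear(16, 16, params_dtype=torch.bfloat16)
inp = torch.randn(8, 16, dtype=torch.bfloat16, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group):
    out = layer(inp)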

Thanks @ptrendx for the information! Yes, in this case we're running PyTorch FSDP FULL_SHARD on an HF Llama model with the nn.Linear layers directly replaced by te.Linear (TE v1.2), and we see the higher loss from the very beginning.

We're passing the same fp8_group=dist.distributed_c10d._get_default_group() WORLD process group in both tests and are using the HYBRID recipe with amax_history_len=1024 and amax_compute_algo=max.

Do you know if anyone else has tried running tests along these lines and has any results from that? I see that with TE v1.4 there is an example of replacing the Llama transformer layer with the TE transformer layer.

Thanks again!
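
For context, a minimal sketch of the replacement described above; the helper name swap_linear_with_te_linear is hypothetical (not from this thread or the TE API), and it copies both weight and bias when a bias is present:

# Hypothetical helper: recursively swap nn.Linear modules for te.Linear, copying parameters.
import torch
import torch.nn as nn
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

def swap_linear_with_te_linear(module):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            te_linear = te.Linear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                params_dtype=child.weight.dtype,
            )
            with torch.no_grad():
                te_linear.weight.copy_(child.weight)
                if child.bias is not None:
                    te_linear.bias.copy_(child.bias)
            setattr(module, name, te_linear)
        else:
            swap_linear_with_te_linear(child)

# Recipe settings matching the runs described above.
fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,
    amax_history_len=1024,
    amax_compute_algo="max",
)

# Usage (assuming an HF Llama model already loaded as `model`):
#   swap_linear_with_te_linear(model)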

Hmmm, those settings look OK. Are you able to provide a reproduction script for this issue so that we can investigate?

Yes, here's a script reproducing the issue, comparing the output of nn.Linear in BF16 to te.Linear in FP8 on a single GPU.
Please let me know if you see anything wrong with my script or need any more information, thanks!

Environment:

  • PT 2.2.0
  • TE v1.2.1
  • H100 GPU
  • Python 3.10.8

Run command: torchrun --standalone --nnodes=1 --nproc-per-node=1 test_te_linear.py.

# Filename: test_te_linear.py
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import transformer_engine.pytorch as te
from torch.distributed.distributed_c10d import _get_default_group
from transformer_engine.common.recipe import Format, DelayedScaling


def setup():
    seed = 7
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)


def check_tol(t1, t2, atol=5e-2, rtol=1e-8):
    # atol from https://github.com/NVIDIA/TransformerEngine/blob/v1.2.1/tests/pytorch/test_numerics.py#L888
    result = torch.allclose(t1, t2, atol=atol, rtol=rtol)
    if not result:
        diff = torch.abs(t1 - t2).flatten()
        m = torch.argmax(diff)
        msg = (
            f"Outputs not close enough in tensor. "
            f"Location of the maximum difference: {m.item()} "
            f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
            f"(diff {diff[m].item()})."
        )
        print(msg)


def main():
    setup()

    all_gpus = dist.new_group(backend="nccl")
    all_gpus_ranks = dist.get_process_group_ranks(all_gpus)
    default_pg_ranks = dist.get_process_group_ranks(
        dist.distributed_c10d._get_default_group()
    )
    print(f"all_ranks: {dist.get_process_group_ranks(all_gpus)}")
    print(
        f"default_pg: {dist.get_process_group_ranks(dist.distributed_c10d._get_default_group())}"
    )
    assert all_gpus_ranks == default_pg_ranks

    # Set base dtype used
    dtype = torch.bfloat16  # [torch.float32, torch.bfloat16]
    device = "cuda"

    # Setup FP8 recipe
    amax_history_len = 32
    fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass
    fp8_recipe = DelayedScaling(
        fp8_format=fp8_format,
        amax_history_len=amax_history_len,
        amax_compute_algo="max",
    )

    # FP8 context manager
    ctx_mgr = te.fp8_autocast
    ctx_mgr_kwargs = {
        "enabled": True,
        "fp8_recipe": fp8_recipe,
        "fp8_group": all_gpus,
    }

    hidden_size = 16  # needs to be divisible by 16

    # Instantiate nn Linear layer
    nn_layer = nn.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
    # Instantiate TE Linear layer
    te_layer = te.Linear(hidden_size, hidden_size, params_dtype=dtype, device=device)

    # Copy nn Layer weights into TE Layer
    with torch.no_grad():
        te_layer.weight.copy_(nn_layer.weight)

    assert torch.equal(te_layer.weight, nn_layer.weight)

    print(f"nn_layer.weight: {nn_layer.weight}")
    print(f"te_layer.weight: {te_layer.weight}")

    bs = 8
    x = torch.rand((bs, hidden_size), dtype=dtype).to(device)

    output_nn = nn_layer(x)
    loss_nn = output_nn.sum()
    print(f"nn loss 0: {loss_nn}")

    output_nn = nn_layer(x)
    loss_nn = output_nn.sum()
    print(f"nn loss 1: {loss_nn}")

    # Run forward for FP8 TE layer multiple times
    # Note: looks like 1st forward is different
    # likely due to initial scaling factor
    for i in range(amax_history_len + 2):
        with ctx_mgr(**ctx_mgr_kwargs):
            output_te = te_layer(x)
        loss_te = output_te.sum()
        if i == amax_history_len - 1:
            print(f"***TE loss {i}: {loss_te}")
        else:
            print(f"TE loss {i}: {loss_te}")

    check_tol(output_te, output_nn)


if __name__ == "__main__":
    main()

Example log:

all_ranks: [0]
default_pg: [0]
nn_layer.weight: Parameter containing:
tensor([[ 0.2275, -0.0025,  0.2207,  0.0464,  0.0996, -0.1455,  0.0874, -0.0095,
          0.0669,  0.0366,  0.2061,  0.1104,  0.1270, -0.0354, -0.0116,  0.0493],
        [ 0.1963, -0.0898, -0.1631, -0.0222, -0.0757,  0.2256,  0.0938, -0.1514,
         -0.1416, -0.2109, -0.1768,  0.2246,  0.0359, -0.0145,  0.0204,  0.0977],
        [ 0.0728,  0.0771, -0.1289,  0.0273, -0.1562,  0.1660, -0.1270,  0.1377,
          0.1030, -0.2412, -0.1738, -0.1836, -0.1602, -0.0232,  0.0452, -0.1436],
        [-0.1260,  0.2490,  0.0767, -0.2344,  0.0889,  0.0884, -0.2256,  0.1367,
         -0.0713,  0.1074, -0.2041,  0.2344, -0.0938, -0.0086,  0.0845,  0.2275],
        [-0.2383, -0.0510, -0.2227,  0.0579, -0.0317, -0.1206, -0.1455, -0.1758,
         -0.1123, -0.1865,  0.2256, -0.1494, -0.0112, -0.2217,  0.0003,  0.1924],
        [-0.1621, -0.2021,  0.0869,  0.1982,  0.0679,  0.2246, -0.2100,  0.1826,
         -0.2070,  0.0173, -0.1885, -0.0688, -0.1445, -0.2139,  0.1582, -0.0162],
        [-0.1011, -0.2383,  0.2002,  0.0300,  0.2432, -0.0161,  0.2051, -0.2314,
          0.1621,  0.0544, -0.0603, -0.2344, -0.2461,  0.2451, -0.2324,  0.1582],
        [ 0.1348, -0.0496, -0.1123, -0.0199, -0.0349, -0.0859,  0.2061, -0.1592,
          0.2461,  0.0986,  0.1045, -0.0693, -0.2207,  0.1982, -0.1377, -0.2422],
        [ 0.0153, -0.2295, -0.0315, -0.2402,  0.0630,  0.1729,  0.1787,  0.1270,
         -0.1641, -0.2344,  0.0579, -0.1514, -0.2441, -0.2236,  0.0309, -0.0505],
        [-0.1484,  0.1797, -0.2080, -0.0664, -0.1318, -0.1406,  0.2109,  0.0947,
          0.2021, -0.2080, -0.2266,  0.0214,  0.2334, -0.1426,  0.2031,  0.0903],
        [-0.1719,  0.1836, -0.0933, -0.0220, -0.2178,  0.0942,  0.1138, -0.1235,
         -0.2256,  0.2090, -0.1670, -0.0981,  0.1226, -0.2285,  0.0476, -0.0388],
        [ 0.2480, -0.0618,  0.1289, -0.2393,  0.1187,  0.2217,  0.0776,  0.2061,
          0.1494, -0.0996,  0.1924,  0.1523,  0.0474, -0.0141, -0.0334, -0.1553],
        [ 0.0284, -0.0481, -0.1758,  0.2197,  0.1475, -0.0654, -0.0557,  0.1357,
          0.0273, -0.1328, -0.0231, -0.2217, -0.0598,  0.0962,  0.0986, -0.1211],
        [ 0.2168,  0.1885,  0.2227,  0.0415,  0.1299, -0.0229,  0.1094, -0.0928,
          0.0796,  0.0815, -0.0615, -0.2012,  0.0986,  0.0854, -0.2188, -0.1187],
        [ 0.1050, -0.2109,  0.0532, -0.2188, -0.2256,  0.1572,  0.1533,  0.1348,
          0.1318, -0.1602, -0.0591,  0.2275,  0.1670, -0.0635,  0.1641, -0.1011],
        [ 0.2217, -0.0188,  0.2090, -0.0354, -0.0032, -0.0933,  0.0664,  0.1553,
          0.0425, -0.0913,  0.1167,  0.2217, -0.2031, -0.1436,  0.1729, -0.0330]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
te_layer.weight: Parameter containing:
tensor([[ 0.2275, -0.0025,  0.2207,  0.0464,  0.0996, -0.1455,  0.0874, -0.0095,
          0.0669,  0.0366,  0.2061,  0.1104,  0.1270, -0.0354, -0.0116,  0.0493],
        [ 0.1963, -0.0898, -0.1631, -0.0222, -0.0757,  0.2256,  0.0938, -0.1514,
         -0.1416, -0.2109, -0.1768,  0.2246,  0.0359, -0.0145,  0.0204,  0.0977],
        [ 0.0728,  0.0771, -0.1289,  0.0273, -0.1562,  0.1660, -0.1270,  0.1377,
          0.1030, -0.2412, -0.1738, -0.1836, -0.1602, -0.0232,  0.0452, -0.1436],
        [-0.1260,  0.2490,  0.0767, -0.2344,  0.0889,  0.0884, -0.2256,  0.1367,
         -0.0713,  0.1074, -0.2041,  0.2344, -0.0938, -0.0086,  0.0845,  0.2275],
        [-0.2383, -0.0510, -0.2227,  0.0579, -0.0317, -0.1206, -0.1455, -0.1758,
         -0.1123, -0.1865,  0.2256, -0.1494, -0.0112, -0.2217,  0.0003,  0.1924],
        [-0.1621, -0.2021,  0.0869,  0.1982,  0.0679,  0.2246, -0.2100,  0.1826,
         -0.2070,  0.0173, -0.1885, -0.0688, -0.1445, -0.2139,  0.1582, -0.0162],
        [-0.1011, -0.2383,  0.2002,  0.0300,  0.2432, -0.0161,  0.2051, -0.2314,
          0.1621,  0.0544, -0.0603, -0.2344, -0.2461,  0.2451, -0.2324,  0.1582],
        [ 0.1348, -0.0496, -0.1123, -0.0199, -0.0349, -0.0859,  0.2061, -0.1592,
          0.2461,  0.0986,  0.1045, -0.0693, -0.2207,  0.1982, -0.1377, -0.2422],
        [ 0.0153, -0.2295, -0.0315, -0.2402,  0.0630,  0.1729,  0.1787,  0.1270,
         -0.1641, -0.2344,  0.0579, -0.1514, -0.2441, -0.2236,  0.0309, -0.0505],
        [-0.1484,  0.1797, -0.2080, -0.0664, -0.1318, -0.1406,  0.2109,  0.0947,
          0.2021, -0.2080, -0.2266,  0.0214,  0.2334, -0.1426,  0.2031,  0.0903],
        [-0.1719,  0.1836, -0.0933, -0.0220, -0.2178,  0.0942,  0.1138, -0.1235,
         -0.2256,  0.2090, -0.1670, -0.0981,  0.1226, -0.2285,  0.0476, -0.0388],
        [ 0.2480, -0.0618,  0.1289, -0.2393,  0.1187,  0.2217,  0.0776,  0.2061,
          0.1494, -0.0996,  0.1924,  0.1523,  0.0474, -0.0141, -0.0334, -0.1553],
        [ 0.0284, -0.0481, -0.1758,  0.2197,  0.1475, -0.0654, -0.0557,  0.1357,
          0.0273, -0.1328, -0.0231, -0.2217, -0.0598,  0.0962,  0.0986, -0.1211],
        [ 0.2168,  0.1885,  0.2227,  0.0415,  0.1299, -0.0229,  0.1094, -0.0928,
          0.0796,  0.0815, -0.0615, -0.2012,  0.0986,  0.0854, -0.2188, -0.1187],
        [ 0.1050, -0.2109,  0.0532, -0.2188, -0.2256,  0.1572,  0.1533,  0.1348,
          0.1318, -0.1602, -0.0591,  0.2275,  0.1670, -0.0635,  0.1641, -0.1011],
        [ 0.2217, -0.0188,  0.2090, -0.0354, -0.0032, -0.0933,  0.0664,  0.1553,
          0.0425, -0.0913,  0.1167,  0.2217, -0.2031, -0.1436,  0.1729, -0.0330]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
nn loss 0: 10.375
nn loss 1: 10.375
TE loss 0: -1.484375
TE loss 1: -1.578125
TE loss 2: -1.578125
TE loss 3: -1.578125
TE loss 4: -1.578125
TE loss 5: -1.578125
TE loss 6: -1.578125
TE loss 7: -1.578125
TE loss 8: -1.578125
TE loss 9: -1.578125
TE loss 10: -1.578125
TE loss 11: -1.578125
TE loss 12: -1.578125
TE loss 13: -1.578125
TE loss 14: -1.578125
TE loss 15: -1.578125
TE loss 16: -1.578125
TE loss 17: -1.578125
TE loss 18: -1.578125
TE loss 19: -1.578125
TE loss 20: -1.578125
TE loss 21: -1.578125
TE loss 22: -1.578125
TE loss 23: -1.578125
TE loss 24: -1.578125
TE loss 25: -1.578125
TE loss 26: -1.578125
TE loss 27: -1.578125
TE loss 28: -1.578125
TE loss 29: -1.578125
TE loss 30: -1.578125
***TE loss 31: -1.578125
TE loss 32: -1.578125
TE loss 33: -1.578125
Outputs not close enough in tensor. Location of the maximum difference: 43 with 0.52734375 vs 0.7890625 (diff 0.26171875).

@ptrendx Did you get a chance to review the reproduction script?

Hi @viclzhu and @goswamig. I looked at the posted example - the problem with it is that, by default, both nn.Linear and te.Linear use a bias, and its values were not copied between the modules.

When I changed the script to also copy the bias,

    with torch.no_grad():
        te_layer.weight.copy_(nn_layer.weight)
        te_layer.bias.copy_(nn_layer.bias)

    assert torch.equal(te_layer.weight, nn_layer.weight)
    assert torch.equal(te_layer.bias, nn_layer.bias)

I got these results:

  • with BF16 (enabled=False in the ctx_mgr_kwargs):
nn loss 0: 10.375
nn loss 1: 10.375
TE loss 0: 10.375
TE loss 1: 10.375
TE loss 2: 10.375
TE loss 3: 10.375
TE loss 4: 10.375
TE loss 5: 10.375
TE loss 6: 10.375
TE loss 7: 10.375
TE loss 8: 10.375
TE loss 9: 10.375
TE loss 10: 10.375
TE loss 11: 10.375
TE loss 12: 10.375
TE loss 13: 10.375
TE loss 14: 10.375
TE loss 15: 10.375
TE loss 16: 10.375
TE loss 17: 10.375
TE loss 18: 10.375
TE loss 19: 10.375
TE loss 20: 10.375
TE loss 21: 10.375
TE loss 22: 10.375
TE loss 23: 10.375
TE loss 24: 10.375
TE loss 25: 10.375
TE loss 26: 10.375
TE loss 27: 10.375
TE loss 28: 10.375
TE loss 29: 10.375
TE loss 30: 10.375
***TE loss 31: 10.375
TE loss 32: 10.375
TE loss 33: 10.375
  • and with FP8 (enabled=True):
nn loss 0: 10.375
nn loss 1: 10.375
TE loss 0: 10.6875
TE loss 1: 10.625
TE loss 2: 10.625
TE loss 3: 10.625
TE loss 4: 10.625
TE loss 5: 10.625
TE loss 6: 10.625
TE loss 7: 10.625
TE loss 8: 10.625
TE loss 9: 10.625
TE loss 10: 10.625
TE loss 11: 10.625
TE loss 12: 10.625
TE loss 13: 10.625
TE loss 14: 10.625
TE loss 15: 10.625
TE loss 16: 10.625
TE loss 17: 10.625
TE loss 18: 10.625
TE loss 19: 10.625
TE loss 20: 10.625
TE loss 21: 10.625
TE loss 22: 10.625
TE loss 23: 10.625
TE loss 24: 10.625
TE loss 25: 10.625
TE loss 26: 10.625
TE loss 27: 10.625
TE loss 28: 10.625
TE loss 29: 10.625
TE loss 30: 10.625
***TE loss 31: 10.625
TE loss 32: 10.625
TE loss 33: 10.625

I also no longer get the "Outputs not close enough" error message.

Oh awesome, thanks for the catch and sanity check!

I'll look closer at my implementation then; something else must be going wrong.

I actually just re-ran the script with your bias fix in my environment, along with changing the input x generation from rand() to randn(), and I see a loss difference greater than the 5e-2 tolerance.

Additionally, I see the losses are still different in your output, just by a smaller amount (10.625 - 10.375 = 0.25).
So can I confirm that this loss difference is expected? (A rough estimate of the magnitude of FP8 rounding error is sketched after the log below.)

nn_layer.weight: Parameter containing:
tensor([[ 0.2275, -0.0025,  0.2207,  0.0464,  0.0996, -0.1455,  0.0874, -0.0095,
          0.0669,  0.0366,  0.2061,  0.1104,  0.1270, -0.0354, -0.0116,  0.0493],
        [ 0.1963, -0.0898, -0.1631, -0.0222, -0.0757,  0.2256,  0.0938, -0.1514,
         -0.1416, -0.2109, -0.1768,  0.2246,  0.0359, -0.0145,  0.0204,  0.0977],
        [ 0.0728,  0.0771, -0.1289,  0.0273, -0.1562,  0.1660, -0.1270,  0.1377,
          0.1030, -0.2412, -0.1738, -0.1836, -0.1602, -0.0232,  0.0452, -0.1436],
        [-0.1260,  0.2490,  0.0767, -0.2344,  0.0889,  0.0884, -0.2256,  0.1367,
         -0.0713,  0.1074, -0.2041,  0.2344, -0.0938, -0.0086,  0.0845,  0.2275],
        [-0.2383, -0.0510, -0.2227,  0.0579, -0.0317, -0.1206, -0.1455, -0.1758,
         -0.1123, -0.1865,  0.2256, -0.1494, -0.0112, -0.2217,  0.0003,  0.1924],
        [-0.1621, -0.2021,  0.0869,  0.1982,  0.0679,  0.2246, -0.2100,  0.1826,
         -0.2070,  0.0173, -0.1885, -0.0688, -0.1445, -0.2139,  0.1582, -0.0162],
        [-0.1011, -0.2383,  0.2002,  0.0300,  0.2432, -0.0161,  0.2051, -0.2314,
          0.1621,  0.0544, -0.0603, -0.2344, -0.2461,  0.2451, -0.2324,  0.1582],
        [ 0.1348, -0.0496, -0.1123, -0.0199, -0.0349, -0.0859,  0.2061, -0.1592,
          0.2461,  0.0986,  0.1045, -0.0693, -0.2207,  0.1982, -0.1377, -0.2422],
        [ 0.0153, -0.2295, -0.0315, -0.2402,  0.0630,  0.1729,  0.1787,  0.1270,
         -0.1641, -0.2344,  0.0579, -0.1514, -0.2441, -0.2236,  0.0309, -0.0505],
        [-0.1484,  0.1797, -0.2080, -0.0664, -0.1318, -0.1406,  0.2109,  0.0947,
          0.2021, -0.2080, -0.2266,  0.0214,  0.2334, -0.1426,  0.2031,  0.0903],
        [-0.1719,  0.1836, -0.0933, -0.0220, -0.2178,  0.0942,  0.1138, -0.1235,
         -0.2256,  0.2090, -0.1670, -0.0981,  0.1226, -0.2285,  0.0476, -0.0388],
        [ 0.2480, -0.0618,  0.1289, -0.2393,  0.1187,  0.2217,  0.0776,  0.2061,
          0.1494, -0.0996,  0.1924,  0.1523,  0.0474, -0.0141, -0.0334, -0.1553],
        [ 0.0284, -0.0481, -0.1758,  0.2197,  0.1475, -0.0654, -0.0557,  0.1357,
          0.0273, -0.1328, -0.0231, -0.2217, -0.0598,  0.0962,  0.0986, -0.1211],
        [ 0.2168,  0.1885,  0.2227,  0.0415,  0.1299, -0.0229,  0.1094, -0.0928,
          0.0796,  0.0815, -0.0615, -0.2012,  0.0986,  0.0854, -0.2188, -0.1187],
        [ 0.1050, -0.2109,  0.0532, -0.2188, -0.2256,  0.1572,  0.1533,  0.1348,
          0.1318, -0.1602, -0.0591,  0.2275,  0.1670, -0.0635,  0.1641, -0.1011],
        [ 0.2217, -0.0188,  0.2090, -0.0354, -0.0032, -0.0933,  0.0664,  0.1553,
          0.0425, -0.0913,  0.1167,  0.2217, -0.2031, -0.1436,  0.1729, -0.0330]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
te_layer.weight: Parameter containing:
tensor([[ 0.2275, -0.0025,  0.2207,  0.0464,  0.0996, -0.1455,  0.0874, -0.0095,
          0.0669,  0.0366,  0.2061,  0.1104,  0.1270, -0.0354, -0.0116,  0.0493],
        [ 0.1963, -0.0898, -0.1631, -0.0222, -0.0757,  0.2256,  0.0938, -0.1514,
         -0.1416, -0.2109, -0.1768,  0.2246,  0.0359, -0.0145,  0.0204,  0.0977],
        [ 0.0728,  0.0771, -0.1289,  0.0273, -0.1562,  0.1660, -0.1270,  0.1377,
          0.1030, -0.2412, -0.1738, -0.1836, -0.1602, -0.0232,  0.0452, -0.1436],
        [-0.1260,  0.2490,  0.0767, -0.2344,  0.0889,  0.0884, -0.2256,  0.1367,
         -0.0713,  0.1074, -0.2041,  0.2344, -0.0938, -0.0086,  0.0845,  0.2275],
        [-0.2383, -0.0510, -0.2227,  0.0579, -0.0317, -0.1206, -0.1455, -0.1758,
         -0.1123, -0.1865,  0.2256, -0.1494, -0.0112, -0.2217,  0.0003,  0.1924],
        [-0.1621, -0.2021,  0.0869,  0.1982,  0.0679,  0.2246, -0.2100,  0.1826,
         -0.2070,  0.0173, -0.1885, -0.0688, -0.1445, -0.2139,  0.1582, -0.0162],
        [-0.1011, -0.2383,  0.2002,  0.0300,  0.2432, -0.0161,  0.2051, -0.2314,
          0.1621,  0.0544, -0.0603, -0.2344, -0.2461,  0.2451, -0.2324,  0.1582],
        [ 0.1348, -0.0496, -0.1123, -0.0199, -0.0349, -0.0859,  0.2061, -0.1592,
          0.2461,  0.0986,  0.1045, -0.0693, -0.2207,  0.1982, -0.1377, -0.2422],
        [ 0.0153, -0.2295, -0.0315, -0.2402,  0.0630,  0.1729,  0.1787,  0.1270,
         -0.1641, -0.2344,  0.0579, -0.1514, -0.2441, -0.2236,  0.0309, -0.0505],
        [-0.1484,  0.1797, -0.2080, -0.0664, -0.1318, -0.1406,  0.2109,  0.0947,
          0.2021, -0.2080, -0.2266,  0.0214,  0.2334, -0.1426,  0.2031,  0.0903],
        [-0.1719,  0.1836, -0.0933, -0.0220, -0.2178,  0.0942,  0.1138, -0.1235,
         -0.2256,  0.2090, -0.1670, -0.0981,  0.1226, -0.2285,  0.0476, -0.0388],
        [ 0.2480, -0.0618,  0.1289, -0.2393,  0.1187,  0.2217,  0.0776,  0.2061,
          0.1494, -0.0996,  0.1924,  0.1523,  0.0474, -0.0141, -0.0334, -0.1553],
        [ 0.0284, -0.0481, -0.1758,  0.2197,  0.1475, -0.0654, -0.0557,  0.1357,
          0.0273, -0.1328, -0.0231, -0.2217, -0.0598,  0.0962,  0.0986, -0.1211],
        [ 0.2168,  0.1885,  0.2227,  0.0415,  0.1299, -0.0229,  0.1094, -0.0928,
          0.0796,  0.0815, -0.0615, -0.2012,  0.0986,  0.0854, -0.2188, -0.1187],
        [ 0.1050, -0.2109,  0.0532, -0.2188, -0.2256,  0.1572,  0.1533,  0.1348,
          0.1318, -0.1602, -0.0591,  0.2275,  0.1670, -0.0635,  0.1641, -0.1011],
        [ 0.2217, -0.0188,  0.2090, -0.0354, -0.0032, -0.0933,  0.0664,  0.1553,
          0.0425, -0.0913,  0.1167,  0.2217, -0.2031, -0.1436,  0.1729, -0.0330]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
nn_layer.bias: Parameter containing:
tensor([-0.0466,  0.1934,  0.2148,  0.0437, -0.0957,  0.0388, -0.1553,  0.1553,
         0.2422,  0.1157,  0.2021,  0.2480,  0.0718,  0.0835,  0.1138,  0.0918],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
te_layer.bias: Parameter containing:
tensor([-0.0466,  0.1934,  0.2148,  0.0437, -0.0957,  0.0388, -0.1553,  0.1553,
         0.2422,  0.1157,  0.2021,  0.2480,  0.0718,  0.0835,  0.1138,  0.0918],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
nn loss 0: 14.5625
nn loss 1: 14.5625
TE loss 0: 14.1875
TE loss 1: 14.75
TE loss 2: 14.75
TE loss 3: 14.75
TE loss 4: 14.75
TE loss 5: 14.75
TE loss 6: 14.75
TE loss 7: 14.75
TE loss 8: 14.75
TE loss 9: 14.75
TE loss 10: 14.75
TE loss 11: 14.75
TE loss 12: 14.75
TE loss 13: 14.75
TE loss 14: 14.75
TE loss 15: 14.75
TE loss 16: 14.75
TE loss 17: 14.75
TE loss 18: 14.75
TE loss 19: 14.75
TE loss 20: 14.75
TE loss 21: 14.75
TE loss 22: 14.75
TE loss 23: 14.75
TE loss 24: 14.75
TE loss 25: 14.75
TE loss 26: 14.75
TE loss 27: 14.75
TE loss 28: 14.75
TE loss 29: 14.75
TE loss 30: 14.75
***TE loss 31: 14.75
TE loss 32: 14.75
TE loss 33: 14.75
Outputs not close enough in tensor. Location of the maximum difference: 8 with -1.203125 vs -1.1171875 (diff 0.0859375).
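
As a rough sanity check on the size of the remaining differences above, the sketch below (independent of TransformerEngine) round-trips inputs and weights through PyTorch's torch.float8_e4m3fn dtype to simulate E4M3 quantization and measures how far a small BF16 matmul drifts. The shapes and value ranges are assumptions chosen to resemble the script above, and the simulation omits TE's per-tensor scaling factors, so it only indicates an order of magnitude rather than reproducing TE's actual recipe. It assumes a PyTorch version with float8 dtypes (2.1+):

# Sketch: estimate the order of magnitude of E4M3 rounding error on a small linear layer.
import torch

torch.manual_seed(7)
hidden_size, bs = 16, 8
w = torch.empty(hidden_size, hidden_size, dtype=torch.bfloat16).uniform_(-0.25, 0.25)
x = torch.randn(bs, hidden_size, dtype=torch.bfloat16)

def fake_e4m3(t):
    # Round-trip through float8_e4m3fn; TE additionally applies per-tensor scaling,
    # which this simplistic simulation omits.
    return t.to(torch.float8_e4m3fn).to(torch.bfloat16)

out_ref = x @ w.t()
out_q = fake_e4m3(x) @ fake_e4m3(w).t()
print("sum(ref):      ", out_ref.sum().item())
print("sum(quantized):", out_q.sum().item())
print("max |diff|:    ", (out_ref - out_q).abs().max().item())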