dmlc / dlpack

common in-memory tensor structure

Home Page: https://dmlc.github.io/dlpack/latest


[RFC] Add bfloat16 data type support

ElaineBao opened this issue

I'm raising this issue for dlpack bfloat16 support.

Bfloat16 is a popular 16-bit floating-point format for machine learning, supported by multiple hardware platforms, e.g. TPUs. Compared to fp16, bfloat16 has a greater dynamic range, so it is useful for quantities such as gradients that can fall outside the dynamic range of fp16. Compared to fp32, using bfloat16 reduces the size of data in memory and allows larger models to fit in the same amount of memory. Because of these advantages, supporting bfloat16 has become a trend across frameworks: TensorFlow already supports the bfloat16 data type, and we are now adding bfloat16 support to MXNet.

Dlpack is an open in-memory tensor structure for sharing tensors among deep learning frameworks. Supporting bfloat16 would make dlpack more flexible and better integrated for data sharing between frameworks.
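
For reference, dlpack describes a tensor's element type with a small DLDataType struct (shown here in slightly simplified form, paraphrased from dlpack.h). Adding bfloat16 only requires a new type code, since the bits and lanes fields already express the rest:

typedef struct {
  uint8_t code;   /*!< type code from DLDataTypeCode, e.g. kDLFloat */
  uint8_t bits;   /*!< number of bits per lane, e.g. 16 for bfloat16 */
  uint16_t lanes; /*!< number of lanes, 1 for scalar types */
} DLDataType;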

Current status of dlpack and bfloat16 support in frameworks:

1. Pytorch:

PyTorch has two interfaces for converting data to and from the dlpack format: tsor = torch.utils.dlpack.from_dlpack(dl) converts a dlpack-defined tensor to a PyTorch-defined tensor, and dl = torch.utils.dlpack.to_dlpack(tsor) converts a PyTorch-defined tensor to a dlpack-defined tensor. When using the to_dlpack function, getDLDataType checks which data types are enabled for data sharing through dlpack:

DLDataType getDLDataType(const Tensor& t) {
  DLDataType dtype;
  dtype.lanes = 1;
  dtype.bits = t.element_size() * 8;
  switch (t.scalar_type()) {
    case ScalarType::Byte:
      dtype.code = DLDataTypeCode::kDLUInt;
      break;
    case ScalarType::Char:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    case ScalarType::Double:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    case ScalarType::Float:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    // ... other scalar types ...
    case ScalarType::BFloat16:
      throw std::logic_error("BFloat16 is not supported by dlpack");
      break;
    // ...
  }
  return dtype;
}

For now, since dlpack does not yet support bfloat16, getDLDataType throws an error when it encounters the bfloat16 data type. Once dlpack supports bfloat16, this code can easily be changed.
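
For illustration, once dlpack defines a bfloat16 type code (kDLBfloat in the proposal below), the BFloat16 case in getDLDataType above could simply become something like:

    case ScalarType::BFloat16:
      dtype.code = DLDataTypeCode::kDLBfloat;  // map torch BFloat16 to the new dlpack code
      break;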

2. MXNet:

Similar to PyTorch, MXNet has arr = mx.nd.from_dlpack(dl), dl = mx.nd.to_dlpack_for_read(arr) and dl = mx.nd.to_dlpack_for_write(arr) for dlpack/MXNet data sharing, and DTypeTransform is used to check the data types.

static DLDataType DTypeTransform(int type_flag) {
    switch (type_flag) {
      case mshadow::kFloat32: return DLDataType{kDLFloat, 32, 1};
      case mshadow::kFloat64: return DLDataType{kDLFloat, 64, 1};
      case mshadow::kFloat16: return DLDataType{kDLFloat, 16, 1};
      case mshadow::kBfloat16: return DLDataType{kDLBfloat, 16, 1}; // add this line to support bfloat16
      // ... other type flags and a default case ...
    }
}

Adding bfloat16 data type support in this function lets us use this data type for inputs, parameters, or outputs of operator computation.
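
For the from_dlpack direction, the reverse mapping would also need a bfloat16 branch. A rough sketch (the function name and structure here are illustrative only, not the actual MXNet code):

// Illustrative sketch only; not the actual MXNet implementation.
static int DLDataTypeToMXNetType(const DLDataType& dtype) {
  CHECK_EQ(dtype.lanes, 1) << "vectorized types are not supported";
  if (dtype.code == kDLBfloat && dtype.bits == 16) return mshadow::kBfloat16;
  if (dtype.code == kDLFloat  && dtype.bits == 16) return mshadow::kFloat16;
  if (dtype.code == kDLFloat  && dtype.bits == 32) return mshadow::kFloat32;
  if (dtype.code == kDLFloat  && dtype.bits == 64) return mshadow::kFloat64;
  // ... integer types etc. ...
  LOG(FATAL) << "Unsupported DLDataType";
  return -1;
}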

3. Tensorflow:

TensorFlow does not support dlpack yet, but there is a discussion about it (issue). TensorFlow already supports bfloat16.

As discussed above, bfloat16 is well supported in various frameworks, and dlpack is also becoming more and more popular. So it would be really great if dlpack had bfloat16 data type support.

Proposal for supporting bfloat16 in dlpack:

Here is a draft proposal for supporting bfloat16 in dlpack. The modification to dlpack is very simple: just add a single line to DLDataTypeCode:

typedef enum {
  kDLInt = 0U,
  kDLUInt = 1U,
  kDLFloat = 2U,
  kDLBfloat = 3U, // add this line to support bfloat16
} DLDataTypeCode;
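
With that addition, a bfloat16 element type can be described just like the existing ones, for example:

DLDataType bf16;
bf16.code  = kDLBfloat;  // the new type code
bf16.bits  = 16;         // bfloat16 is a 16-bit format
bf16.lanes = 1;          // scalar (non-vectorized) element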

And it's done.

Do you have any ideas? Thank you @soumith @piiswrong @Yangqing @naibaf7 @bhack @edgarriba @tqchen @prigoyal @zdevito @pengzhao-intel @ZhennanQin

@tqchen and all, any comments on this RFC? If not, we will consider it accepted :)

Given that many of the related folks could be on holiday, we will need to wait for a longer period of time, then hold a vote to include it.

Please also add more description to the RFC covering the background and the proposed change.

It makes sense. Sorry, I forgot about the US holidays :(
Yixin is working on a whole picture of BF16 in the industry and our plan.

@tqchen & all, the description is updated :)

DLPack is mainly about exchange rather than a single framework. Please do update the RFC to include discussion of its implications for the related frameworks, namely PyTorch, TensorFlow, MXNet, and other frameworks that use dlpack.

Hi, @tqchen & all, the description is updated, including discussions about frameworks. :)

@tqchen @szha could you take a look?
What's the next step to move things forward?

Thanks for the RFC. There are two things that I would like to bring up.

  • First of all, bfloat itself is only defined for 16 bits, leaving other bit widths undefined; we should give some thought to what that means for other bit widths.
  • Compatibility in terms of downstream frameworks. TVM uses type code 3 for an Opaque handle type, which can be defined by frameworks. While that type code has not yet been upstreamed, it would be great to skip code 3 and use 4 instead for the bfloat type.

Hi, @tqchen & all,
For the first concern, as indicated by Wikipedia:

The bfloat16 (Brain Floating Point) floating-point format is a computer number format occupying 16 bits in computer memory; it represents a wide dynamic range of numeric values by using a floating radix point. This format is a truncated (16-bit) version of the 32-bit IEEE 754 single-precision floating-point format (binary32) with the intent of accelerating machine learning and near-sensor computing.

So it is common understanding in the industry that bfloat is only defined for 16 bits.
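
For concreteness, a minimal sketch of that truncation relationship, assuming plain round-toward-zero truncation (real implementations typically use round-to-nearest-even):

#include <stdint.h>
#include <string.h>

/* bfloat16 keeps the sign bit, the 8 exponent bits and the top 7 mantissa
 * bits of an IEEE-754 binary32 value, i.e. its upper 16 bits. */
static uint16_t float_to_bfloat16(float f) {
  uint32_t u;
  memcpy(&u, &f, sizeof(u));
  return (uint16_t)(u >> 16);      /* drop the low 16 mantissa bits */
}

static float bfloat16_to_float(uint16_t b) {
  uint32_t u = (uint32_t)b << 16;  /* low mantissa bits become zero */
  float f;
  memcpy(&f, &u, sizeof(f));
  return f;
}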

And for the second concern, OK, we'll skip code 3 and use 4 instead for the bfloat type :)
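
In other words, the enum from the proposal above would then become:

typedef enum {
  kDLInt = 0U,
  kDLUInt = 1U,
  kDLFloat = 2U,
  // code 3 is skipped because TVM uses it for an opaque handle type
  kDLBfloat = 4U,
} DLDataTypeCode;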

Hi, I've opened PR #48 for bfloat16 data type support.
Feel free to let me know if there are any updates.

@tqchen any suggestions? We're working on the MXNet PR now (apache/mxnet#17265), so we need this improvement soon.

Let us wait for another week; if there are no further comments, we can merge the PR in.

Thanks a lot!

Hi, @tqchen & all, any more suggestions? If not, could you please help merge the PR? We need this improvement soon, thanks :)

The bfloat PR has now been merged.

Thanks, @tqchen :)