santhoshkolloju / TfRecordPytorch

TfRecordPytorch

Usage:
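import tensorflow as tf
import torch
from tfrecord_pytorch import TFRecordPytorch

file_name = "train.tfrecord"
# Map each column name to the TensorFlow feature spec used to parse it.
col_mapping = {
    "input_ids": tf.io.VarLenFeature(tf.int64),
    "label_ids": tf.io.VarLenFeature(tf.int64),
}
# Note: PyTorch's DataLoader does not allow shuffle=True with an IterableDataset,
# so shuffling is requested on the dataset itself (shuffle, buffer_size).
dataset = TFRecordPytorch(file_name, col_mapping, shuffle=True, buffer_size=10000)
# pad_and_sort is a collate function; it must be defined or imported before this line.
loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=pad_and_sort)
iterator = iter(loader)
print(next(iterator))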
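
The usage above relies on two things it does not show: how a buffer-based shuffle can work for a streaming dataset, and what `pad_and_sort` might look like. The sketch below illustrates both in plain Python. It is not taken from this repository. `shuffle_buffer` is a hypothetical helper, the `pad_and_sort` body is a guess based only on its name, and it assumes each sample is a dict mapping column names to lists of ints and that `0` is a usable padding id.

```python
import random
from typing import Dict, Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def shuffle_buffer(stream: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Yield items from `stream` in approximately random order.

    Hypothetical helper: holds up to `buffer_size` items in memory and, once
    the buffer is full, emits a randomly chosen item for each new arrival.
    """
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]      # emit a random buffered item
        buffer[idx] = item     # and put the new item in its slot
    # The stream is exhausted: shuffle and drain whatever is left.
    rng.shuffle(buffer)
    yield from buffer


def pad_and_sort(
    batch: Sequence[Dict[str, List[int]]], pad_id: int = 0
) -> Dict[str, List[List[int]]]:
    """Sort samples longest-first by `input_ids`, then right-pad every column.

    Assumed sample layout: {"input_ids": [...], "label_ids": [...]}.
    """
    ordered = sorted(batch, key=lambda s: len(s["input_ids"]), reverse=True)
    out: Dict[str, List[List[int]]] = {}
    for key in ordered[0]:
        width = max(len(s[key]) for s in ordered)
        out[key] = [list(s[key]) + [pad_id] * (width - len(s[key])) for s in ordered]
    return out
```

A larger `buffer_size` gives an order closer to a full shuffle at the cost of memory, while `buffer_size=1` keeps the original order. To get tensors out of the loader, the padded lists returned by `pad_and_sort` can be wrapped in `torch.tensor(...)` before returning.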