Project author: santhoshkolloju

Project description: tfrecord
Language: Python
Repository: git://github.com/santhoshkolloju/TfRecordPytorch.git
Created: 2019-11-28T15:06:11Z
Project page: https://github.com/santhoshkolloju/TfRecordPytorch




TfRecordPytorch

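Usage:

    import tensorflow as tf
    import torch
    from tfrecord_pytorch import TFRecordPytorch

    file_name = "train.tfrecord"

    # Map each feature name stored in the TFRecord file to its TensorFlow parsing spec.
    col_mapping = {
        "input_ids": tf.io.VarLenFeature(tf.int64),
        "label_ids": tf.io.VarLenFeature(tf.int64),
    }

    # Note: PyTorch's DataLoader does not accept shuffle=True for an IterableDataset,
    # so shuffling is done inside the dataset through shuffle and buffer_size.
    dataset = TFRecordPytorch(file_name, col_mapping, shuffle=True, buffer_size=10000)

    # pad_and_sort is not defined in this snippet: supply your own collate function.
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=pad_and_sort)
    iterator = iter(loader)
    print(next(iterator))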
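The usage example passes `collate_fn=pad_and_sort` without defining it. Below is a minimal, hypothetical version, assuming each sample yielded by `TFRecordPytorch` is a dict mapping feature names (`"input_ids"`, `"label_ids"`) to sequences of ints; the real sample format may differ (for example, tensors). It sorts the batch by `input_ids` length, longest first, right-pads every feature with `pad_value` (assumed to be 0), and adds a hypothetical `"lengths"` entry holding the unpadded lengths.

```python
from typing import Dict, List, Sequence


def pad_and_sort(
    batch: List[Dict[str, Sequence[int]]], pad_value: int = 0
) -> Dict[str, list]:
    """Hypothetical collate_fn: sort samples by input_ids length and pad each feature."""
    if not batch:
        return {}
    # Longest input_ids first (assumes every sample has an "input_ids" key).
    ordered = sorted(batch, key=lambda sample: len(sample["input_ids"]), reverse=True)
    collated: Dict[str, list] = {}
    for key in ordered[0]:
        # Pad this feature to the longest sequence in the batch.
        max_len = max(len(sample[key]) for sample in ordered)
        collated[key] = [
            list(sample[key]) + [pad_value] * (max_len - len(sample[key]))
            for sample in ordered
        ]
    # Unpadded lengths, in the same (sorted) order as the rows above.
    collated["lengths"] = [len(sample["input_ids"]) for sample in ordered]
    return collated


batch = [
    {"input_ids": [1, 2], "label_ids": [3, 4]},
    {"input_ids": [5, 6, 7], "label_ids": [8, 9, 10]},
]
print(pad_and_sort(batch))
```

To get tensors instead of lists, wrap each padded list with `torch.tensor(...)` before returning. Sorting longest-first matches what `torch.nn.utils.rnn.pack_padded_sequence` expects when `enforce_sorted=True`.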
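The note explains why `shuffle` and `buffer_size` are arguments of the dataset rather than of the `DataLoader`. A common way for an iterable (streaming) dataset to shuffle is a tf.data-style shuffle buffer; the sketch below shows that technique in plain Python. It is an assumption about how such a buffer typically works, not `TFRecordPytorch`'s actual code, and `buffered_shuffle` is a hypothetical helper name.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def buffered_shuffle(
    items: Iterable[T], buffer_size: int, seed: Optional[int] = None
) -> Iterator[T]:
    """Approximately shuffle a stream while holding at most buffer_size items in memory."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            # Fill the buffer before emitting anything.
            buffer.append(item)
            continue
        # Buffer full: emit a random element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Stream exhausted: drain the remaining elements in random order.
    rng.shuffle(buffer)
    yield from buffer


print(list(buffered_shuffle(range(10), buffer_size=4, seed=0)))
```

A dataset's `__iter__` could `yield from buffered_shuffle(record_stream, buffer_size)`. The buffer trades memory for randomness: with `buffer_size=1` the order is unchanged, and once `buffer_size` is at least the number of records the result is a full uniform shuffle.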