假设我有3个tfrecord文件,即neg.tfrecord,pos1.tfrecord,pos2.tfrecord。
我用
dataset = tf.data.TFRecordDataset(tfrecord_file)此代码创建3个数据集对象。
我的批量是400,…
您可以使用 interleave
interleave
filenames = [tfrecord_file1, tfrecord_file2] dataset = (Dataset.from_tensor_slices(filenames).interleave(lambda x:TFRecordDataset(x) dataset = dataset.map(parse_fn) ...
或者你甚至可以尝试并行交错。看到 https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#interleave https://www.tensorflow.org/api_docs/python/tf/data/experimental/parallel_interleave