I found a solution: `tf.data.Dataset.from_generator`. It basically solves the problem:
```python
def generate_train_data(
    self, batch_size: int
) -> typing.Iterable[typing.Tuple[np.ndarray, np.ndarray]]:
    features = self.get_features()
    targets = self.get_targets()
    test_amount = self.get_test_data_amount()
    # Iterate only over the training portion; the last
    # `test_amount` rows are reserved for testing.
    row_id = 0
    while row_id < features.shape[0] - test_amount:
        limit = min(features.shape[0] - test_amount, row_id + batch_size)
        feature_batch = features[row_id:limit, :]
        target_batch = targets[row_id:limit, :]
        yield (feature_batch, target_batch)
        del feature_batch, target_batch
        row_id += batch_size
```
and created the `tf.data.Dataset` like this:
```python
train_data = tf.data.Dataset.from_generator(
    data.generate_train_data,
    output_types=(tf.bool, tf.bool),
    output_shapes=(
        (None, data.get_feature_amount()),
        (None, data.get_target_amount()),
    ),
    args=(batch_size,),
).repeat()
```
This of course doesn't change the data itself, but it was very easy to adapt...
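To see the batching behaviour in isolation, here is a minimal NumPy-only sketch of the same generator logic (the class context is stripped out; `generate_batches` and the toy array shapes are illustrative assumptions, not part of the original code):

```python
import numpy as np

def generate_batches(features, targets, batch_size, test_amount):
    """Yield (feature_batch, target_batch) pairs over the training
    portion, holding back the last `test_amount` rows for testing."""
    train_rows = features.shape[0] - test_amount
    row_id = 0
    while row_id < train_rows:
        limit = min(train_rows, row_id + batch_size)
        yield features[row_id:limit, :], targets[row_id:limit, :]
        row_id += batch_size

# Toy data: 10 rows, 4 feature columns, 2 target columns; reserve 2 rows for test.
features = np.arange(40).reshape(10, 4)
targets = np.arange(20).reshape(10, 2)
batches = list(generate_batches(features, targets, batch_size=3, test_amount=2))
print([b[0].shape for b in batches])  # → [(3, 4), (3, 4), (2, 4)]
```

Note the final batch is smaller than `batch_size`, which is why the `output_shapes` above use `None` for the batch dimension.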