我有一个如下所示的数据集。这是第一项是用户ID,后跟用户点击的项目集。
0 24104 27359 66840 24104 273591 16742 31529 ……
正如@Jatentaki建议的那样,我编写了自定义整理功能,并且工作正常。
def get_max_length(x): return len(max(x, key=len)) def pad_sequence(seq): def _pad(_it, _max_len): return [0] * (_max_len - len(_it)) + _it return [_pad(it, get_max_length(seq)) for it in seq] def custom_collate(batch): transposed = zip(*batch) lst = [] for samples in transposed: if isinstance(samples[0], int): lst.append(torch.LongTensor(samples)) elif isinstance(samples[0], float): lst.append(torch.DoubleTensor(samples)) elif isinstance(samples[0], collections.Sequence): lst.append(torch.LongTensor(pad_sequence(samples))) return lst stream_dataset = StreamDataset(data_path) stream_data_loader = torch.utils.data.dataloader.DataLoader(dataset=stream_dataset, batch_size=batch_size, collate_fn=custom_collate, shuffle=False)
那么你如何处理样品长度不同的事实呢? torch.utils.data.DataLoader 有个 collate_fn 用于将样本列表转换为批处理的参数。通过 默认 它确实 这个 到列表。你可以自己写 collate_fn ,例如 0 -pads输入,将其截断为某个预定义的长度或应用您选择的任何其他操作。
torch.utils.data.DataLoader
collate_fn
0