Since your batch size is 1, either use a lower learning rate (e.g. 1e-4) or increase the batch size.
I would suggest a batch size of 16 or larger.
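To illustrate why batch size and learning rate are linked, here is a rough sketch in plain Python (no PyTorch needed). It assumes a toy one-parameter regression problem that I made up for the demonstration, and the names `batch_gradient` and `gradient_spread` are hypothetical helpers, not part of any library. It measures how much the mini-batch gradient varies from batch to batch: with a batch size of 1 the estimate is much noisier, which is why a smaller step size tends to be safer there.

```python
import random
import statistics

random.seed(0)

# Toy regression data: y = 3*x + noise (hypothetical, for illustration only).
xs = [random.gauss(0.0, 1.0) for _ in range(2000)]
ys = [3.0 * x + random.gauss(0.0, 0.5) for x in xs]

def batch_gradient(w, batch_indices):
    """Gradient of the mean squared error w.r.t. w over one mini-batch."""
    grads = [2.0 * (w * xs[i] - ys[i]) * xs[i] for i in batch_indices]
    return sum(grads) / len(grads)

def gradient_spread(w, batch_size, n_batches=500):
    """Standard deviation of the mini-batch gradient across random batches."""
    estimates = [
        batch_gradient(w, random.sample(range(len(xs)), batch_size))
        for _ in range(n_batches)
    ]
    return statistics.stdev(estimates)

spread_1 = gradient_spread(w=0.0, batch_size=1)
spread_16 = gradient_spread(w=0.0, batch_size=16)
print(f"batch_size=1:  gradient std = {spread_1:.3f}")
print(f"batch_size=16: gradient std = {spread_16:.3f}")
```

The exact numbers depend on the data, but the spread for a batch size of 16 should come out clearly smaller than for a batch size of 1.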
Edit: To create batches of data, you can do the following.
    import torch

    N = input.shape[0]  # total number of samples in input
    for i in range(n_epochs):
        # shuffle the sample indices at the start of each epoch
        indices = torch.randperm(N)
        for idx in range(0, N, batch_size):
            # take the next batch_size shuffled indices (the last batch may be smaller)
            batch_indices = indices[idx:idx + batch_size]
            batch_input = input[batch_indices]  # input of size batch_size
            # do whatever you want with the batch_input
            # ....
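
As an alternative to slicing by hand, PyTorch's built-in `DataLoader` can do the shuffling and batching for you. The following is only a sketch under assumptions: the tensors `inputs` and `targets`, their shapes, and the values of `batch_size` and `n_epochs` are made up for illustration, so substitute your own data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data for illustration: 100 samples with 8 features each.
inputs = torch.randn(100, 8)
targets = torch.randn(100, 1)

batch_size = 16
n_epochs = 2

dataset = TensorDataset(inputs, targets)
# shuffle=True draws a new random order at the start of every epoch.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

batch_sizes_seen = []
for epoch in range(n_epochs):
    for batch_input, batch_target in loader:
        # batch_input has batch_size rows, except possibly the last batch
        batch_sizes_seen.append(batch_input.shape[0])
        # forward pass, loss, backward pass and optimizer step would go here
```

If the smaller final batch is a problem for your model, `DataLoader` also accepts `drop_last=True` to skip it.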