我想按标签中的数值对数据集进行排序。
有没有来自pytorch的功能来有效地处理这个问题?
我的数据集类型()来自:
< class’torchvision ….
没有通用的方法可以有效地执行此操作,因为数据集类只实现了一个 __getitem__ 和 __len__ 方法,并且不一定具有关于标签的任何“存储”信息。
__getitem__
__len__
在的情况下 MNIST数据集 但是,您可以从标签列表中对数据集进行排序。
例如,当您要列出标签为5的索引时。
mnist = torchvision.datasets.mnist.MNIST("/") labels = mnist.train_labels fives = (labels == 5).nonzero()