我想出了这个问题。实际上,mxnet所期望的形状取决于数据集(它实际上取决于数据集中的最大值)。训练在单个gpu盒上进行,并具有整个数据集。然而,预测适用于单节点设置,因为它具有训练中使用的所有数据。但是,当使用多节点集群时,数据集会被拆分,这使得每个节点的max-value不同。这导致了错误。
我现在已经制作了与数据集无关的预期形状,并且此错误不再发生。我希望这能澄清事情。