是的,您可以在MXNet中的多个GPU(和多台机器)上训练RNN。我刚刚确认下面的代码适用于具有4个GPU的机器上的MXNet v1.3.0。
import mxnet as mx GPU_COUNT = 4 context = [mx.gpu(i) for i in range(GPU_COUNT)] model = mx.gluon.rnn.RNN(hidden_size=10, num_layers=1) model.collect_params().initialize(mx.init.Xavier(), ctx=context)
您可能需要仔细检查是否有任何内容覆盖您的上下文,因为看起来您在这里使用的是空上下文(即 string[] )。尝试同时在多个上下文中创建数组时,您也会遇到类似的错误。
string[]
mx.nd.zeros(shape=(10,10), ctx=context)
给出以下错误(注意上下文包含多个设备):
MXNetError: [20:15:03] include/mxnet/./base.h:388: Invalid context string [gpu(0), gpu(1), gpu(2), gpu(3)]