我发现如何正确地确定如何正确定义mxnet网络以便我可以将此模型序列化/转换为json文件。
管道由CNN + biLSTM + CTC组成。
我现在必须使用……
Gluon中的模型可以输出到json的要求很少:
它需要是可混合的,这意味着每个子块也应该是可混合的,并且模型在两种模式下都可以工作
应初始化所有参数。由于Gluon使用延迟参数初始化,这意味着您应该至少进行一次前向传递,然后才能保存模型。
我为你的代码做了一些修复,并在我需要的时候引入了新的常量。最重要的变化是:
如果可以避免使用拆分,请不要使用拆分,因为它返回NDArrays列表。使用reshape,它也可以与Symbol一起使用。
从1.0.0版本的MXNet开始,LSTM也可以混合使用,因此您可以将其包装在HybridBlock中而不仅仅是块中。
使用HybridSequential。
以下是调整后的代码,底部有一个示例,如何保存模型以及如何加载模型。您可以在中找到更多信息 本教程 。
import mxnet as mx from mxnet import gluon from mxnet import nd BATCH_SIZE = 1 CHANNELS = 100 ALPHABET_SIZE = 1000 NUM_HIDDEN = 200 NUM_CLASSES = 13550 NUM_LSTM_LAYER = 1 p_dropout = 0.5 SEQ_LEN = 32 HEIGHT = 100 WIDTH = 100 def get_featurizer(): featurizer = gluon.nn.HybridSequential() featurizer.add( gluon.nn.Conv2D(kernel_size=(3, 3), padding=(1, 1), channels=32, activation="relu")) featurizer.add(gluon.nn.BatchNorm()) return featurizer class EncoderLayer(gluon.HybridBlock): def __init__(self, **kwargs): super(EncoderLayer, self).__init__(**kwargs) with self.name_scope(): self.lstm = mx.gluon.rnn.LSTM(NUM_HIDDEN, NUM_LSTM_LAYER, bidirectional=True) def hybrid_forward(self, F, x): x = x.transpose((0, 3, 1, 2)) x = x.flatten() x = x.reshape(shape=(SEQ_LEN, -1, CHANNELS)) #x.split(num_outputs=SEQ_LEN, axis=1) # (SEQ_LEN, N, CHANNELS) x = self.lstm(x) x = x.transpose((1, 0, 2)) # (N, SEQ_LEN, HIDDEN_UNITS) return x def get_encoder(): encoder = gluon.nn.HybridSequential() encoder.add(EncoderLayer()) encoder.add(gluon.nn.Dropout(p_dropout)) return encoder def get_decoder(): decoder = mx.gluon.nn.Dense(units=ALPHABET_SIZE, flatten=False) return decoder def get_net(): net = gluon.nn.HybridSequential() with net.name_scope(): net.add(get_featurizer()) net.add(get_encoder()) net.add(get_decoder()) return net if __name__ == '__main__': net = get_net() net.initialize() net.hybridize() fake_data = mx.random.uniform(shape=(BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)) out = net(fake_data) net.export("mymodel") deserialized_net = gluon.nn.SymbolBlock.imports("mymodel-symbol.json", ['data'], "mymodel-0000.params", ctx=mx.cpu()) out2 = deserialized_net(fake_data) # just to check that we get the same results assert (out - out2).sum().asscalar() == 0