FC = mx.sym.FullyConnected(data = x_3,flatten = False,num_hidden = n_class)x = mx.sym.softmax(data = FC)
sm_label = mx.sym.Reshape(data = label,shape =( - 1,))sm_label = mx.sym.Cast(data = sm_label,dtype =’ …
MXNet repo有一个WarpCTC示例 这里 。您可以使用运行培训 python lstm_ocr_train.py --gpu 1 --num_proc 4 --loss warpctc font/Ubuntu-M.ttf 。在示例中,以下是WarpCTC运算符使用的预测和标签的形状:
python lstm_ocr_train.py --gpu 1 --num_proc 4 --loss warpctc font/Ubuntu-M.ttf
Prediction is (10240, 11) Label is (512,) label_length: 4 input_length: 80 batch_size = 128 seq_length = 80
在上述情况下,
按照示例的说明,我建议在您的情况下调用具有预测形状=(1120,27),标签形状=(672,),label_length = 21,input_length = 35的WarpCTC。