You can do it like this:
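```python
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Lambda, Concatenate
from tensorflow.keras.models import Model

# k inputs of the original model's shape, stacked along axis 1
input_ = Input((k, model.input.shape[1]))
# Split into a list of k tensors, each of shape (batch, features)
input_as_list = Lambda(lambda x: tf.unstack(x, axis=1))(input_)
# Apply the same model (shared weights) to every slice
model_outputs = [model(x) for x in input_as_list]
# Re-add axis 1 to each output so they can be joined back together
model_outputs = [Lambda(lambda x: K.expand_dims(x, axis=1))(y) for y in model_outputs]
concat_output = Concatenate(axis=1)(model_outputs)
new_model = Model(input_, concat_output)
```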
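To make the shape bookkeeping concrete, here is a plain NumPy sketch (not Keras) of what the graph computes, assuming the wrapped model maps a `(batch, features)` array to a `(batch, outputs)` array. `apply_per_slice` and `toy_model` are hypothetical names; `toy_model` is just a fixed linear map standing in for a real model.

```python
import numpy as np

def apply_per_slice(f, x):
    """Apply f to each slice x[:, i, :] and stack the results along axis 1.

    Mirrors the Keras graph: unstack on axis 1, call the model on each
    slice, expand_dims on axis 1, then concatenate on axis 1.
    """
    slices = [x[:, i, :] for i in range(x.shape[1])]        # like tf.unstack(x, axis=1)
    outputs = [f(s) for s in slices]                        # like model(x) per slice
    outputs = [np.expand_dims(o, axis=1) for o in outputs]  # like K.expand_dims(o, axis=1)
    return np.concatenate(outputs, axis=1)                  # like Concatenate(axis=1)

# Hypothetical stand-in for `model`: a fixed linear map from 3 features to 2 outputs.
rng = np.random.default_rng(0)
W = rng.standard_normal((3, 2))
toy_model = lambda s: s @ W

batch, k, n_features = 4, 5, 3
x = rng.standard_normal((batch, k, n_features))
y = apply_per_slice(toy_model, x)
print(y.shape)  # (4, 5, 2)
```

Each position `i` along axis 1 of the result is exactly the model's output for input slice `i`, so `new_model` returns shape `(batch, k, outputs)`.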