Here is the solution I came up with.
public static void main(String[] args) throws IOException, InterruptedException {
    // Load the XOR rows from train.csv and use the whole data set as one batch
    CSVDataSet dataSet = new CSVDataSet(new File("./train.csv"));
    CSVDataSetIterator trainingSetIterator = new CSVDataSetIterator(dataSet, dataSet.size());

    // 2 inputs -> 3 hidden units -> 1 output, trained with plain SGD
    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1))
            .iterations(1150)
            .learningRate(1)
            .seed(1)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD)
            .list(2)
            .backprop(true).pretrain(false)
            .layer(0, new DenseLayer.Builder().nIn(2).nOut(3).updater(Updater.SGD).build())
            .layer(1, new OutputLayer.Builder().nIn(3).nOut(1).build())
            .build();

    MultiLayerNetwork network = new MultiLayerNetwork(configuration);
    network.setListeners(new HistogramIterationListener(10), new ScoreIterationListener(100));
    network.init();

    // Train and print the elapsed time in milliseconds
    long start = System.currentTimeMillis();
    network.fit(trainingSetIterator);
    System.out.println(System.currentTimeMillis() - start);

    // Save the learned parameters (binary) and the configuration (JSON) separately
    try (DataOutputStream dos = new DataOutputStream(Files.newOutputStream(Paths.get("xor-coefficients.bin")))) {
        Nd4j.write(network.params(), dos);
    }
    FileUtils.write(new File("xor-network-conf.json"), network.getLayerWiseConfigurations().toJson());
}
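To make the configuration above less of a black box, here is a minimal plain-Java sketch of the same 2-3-1 shape (2 inputs, 3 hidden units, 1 output) trained by full-batch gradient descent, with no DL4J involved. It assumes sigmoid activations and a squared-error loss, which the DL4J configuration does not set explicitly, so DL4J's defaults may differ. The class name XorSketch, the uniform [-1, 1) initialization (the DL4J code uses UniformDistribution(0, 1)), the learning rate and the step count are illustrative choices, not values from the answer.

```java
import java.util.Random;

// Hypothetical plain-Java illustration of a 2-3-1 XOR network; this is not DL4J code.
public class XorSketch {
    // Same four rows as train.csv: inputs (x1, x2) and the XOR label
    static final double[][] X = {{0, 1}, {1, 0}, {1, 1}, {0, 0}};
    static final double[] Y = {1, 1, 0, 0};
    static final int HIDDEN = 3;

    final double[][] w1 = new double[HIDDEN][2]; // input -> hidden weights
    final double[] b1 = new double[HIDDEN];      // hidden biases (start at 0)
    final double[] w2 = new double[HIDDEN];      // hidden -> output weights
    double b2;                                   // output bias (starts at 0)

    XorSketch(long seed) {
        Random rnd = new Random(seed);
        for (int j = 0; j < HIDDEN; j++) {
            w1[j][0] = rnd.nextDouble() * 2 - 1;
            w1[j][1] = rnd.nextDouble() * 2 - 1;
            w2[j] = rnd.nextDouble() * 2 - 1;
        }
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    double hidden(int j, double[] x) {
        return sigmoid(w1[j][0] * x[0] + w1[j][1] * x[1] + b1[j]);
    }

    double predict(double[] x) {
        double z = b2;
        for (int j = 0; j < HIDDEN; j++) z += w2[j] * hidden(j, x);
        return sigmoid(z);
    }

    /** One full-batch gradient step on 0.5 * squared error; returns the MSE measured before the update. */
    double step(double lr) {
        double[][] gw1 = new double[HIDDEN][2];
        double[] gb1 = new double[HIDDEN];
        double[] gw2 = new double[HIDDEN];
        double gb2 = 0, sse = 0;
        for (int n = 0; n < X.length; n++) {
            double[] x = X[n];
            double[] h = new double[HIDDEN];
            double z = b2;
            for (int j = 0; j < HIDDEN; j++) {
                h[j] = hidden(j, x);
                z += w2[j] * h[j];
            }
            double out = sigmoid(z);
            double err = out - Y[n];
            sse += err * err;
            double dOut = err * out * (1 - out); // gradient w.r.t. the output pre-activation
            gb2 += dOut;
            for (int j = 0; j < HIDDEN; j++) {
                gw2[j] += dOut * h[j];
                double dHid = dOut * w2[j] * h[j] * (1 - h[j]); // backpropagated to hidden unit j
                gb1[j] += dHid;
                gw1[j][0] += dHid * x[0];
                gw1[j][1] += dHid * x[1];
            }
        }
        // Apply the accumulated gradients only after all four rows were processed
        for (int j = 0; j < HIDDEN; j++) {
            w1[j][0] -= lr * gw1[j][0];
            w1[j][1] -= lr * gw1[j][1];
            b1[j] -= lr * gb1[j];
            w2[j] -= lr * gw2[j];
        }
        b2 -= lr * gb2;
        return sse / X.length;
    }

    /** Runs the given number of steps and returns the final MSE. */
    double train(double lr, int steps) {
        for (int s = 0; s < steps; s++) step(lr);
        return step(0.0); // lr = 0: measure only, weights stay unchanged
    }

    public static void main(String[] args) {
        XorSketch net = new XorSketch(1);
        System.out.println("MSE before: " + net.step(0.0));
        System.out.println("MSE after:  " + net.train(0.5, 20000));
        for (double[] x : X) {
            System.out.printf("%s -> %.3f%n", java.util.Arrays.toString(x), net.predict(x));
        }
    }
}
```

Calling step(0.0) is a cheap way to read the current error without changing any weights, which is how main reports the loss before and after training.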
To test it:
import static java.util.stream.Collectors.toList;
import com.google.common.collect.ImmutableList; // Guava

public static void main(String[] args) throws IOException {
    // Rebuild the network from the saved JSON configuration and binary parameters
    MultiLayerConfiguration configuration = MultiLayerConfiguration.fromJson(
            FileUtils.readFileToString(new File("xor-network-conf.json")));
    try (DataInputStream dis = new DataInputStream(new FileInputStream("xor-coefficients.bin"))) {
        INDArray parameters = Nd4j.read(dis);
        MultiLayerNetwork network = new MultiLayerNetwork(configuration, parameters);
        network.init();

        // Run all four XOR input combinations through the trained network
        List<INDArray> inputs = ImmutableList.of(
                Nd4j.create(new double[]{1, 0}),
                Nd4j.create(new double[]{0, 1}),
                Nd4j.create(new double[]{1, 1}),
                Nd4j.create(new double[]{0, 0}));
        List<INDArray> networkResults = inputs.stream().map(network::output).collect(toList());
        System.out.println(networkResults);
    }
}
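The save/load round trip follows a simple pattern: the architecture goes to JSON, the learned numbers go to a binary stream, and the network is rebuilt from both. The sketch below shows only the binary half, using plain java.io on a double[] instead of an INDArray. The class name ParamsIO and its file layout (an int length prefix followed by the values) are assumptions of this sketch; this is not the format Nd4j.write produces.

```java
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical helper: stores a flat parameter vector as [int length][double values...].
public class ParamsIO {
    static void save(Path file, double[] params) throws IOException {
        try (DataOutputStream out = new DataOutputStream(
                new BufferedOutputStream(Files.newOutputStream(file)))) {
            out.writeInt(params.length);   // length prefix, 4 bytes
            for (double p : params) {
                out.writeDouble(p);        // 8 bytes each, big-endian
            }
        }
    }

    static double[] load(Path file) throws IOException {
        try (DataInputStream in = new DataInputStream(
                new BufferedInputStream(Files.newInputStream(file)))) {
            int n = in.readInt();
            double[] params = new double[n];
            for (int i = 0; i < n; i++) {
                params[i] = in.readDouble();
            }
            return params;
        }
    }
}
```

For the 2-3-1 network in this answer the flattened parameter vector holds 2×3 weights plus 3 biases for the hidden layer and 3×1 weights plus 1 bias for the output layer, 13 values in total, so a file in this sketch's layout would be 4 + 13 × 8 = 108 bytes.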
The training data (train.csv):
0,1,1
1,0,1
1,1,0
0,0,0
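Each row holds the two XOR inputs followed by the expected output. In the DL4J code this parsing is left to the CSV data set classes; the following is a minimal plain-Java sketch of splitting such rows into a feature matrix and a label vector, assuming exactly this three-column layout with no header row. The class name XorCsv is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: splits "x1,x2,label" rows into features and labels.
public class XorCsv {
    final double[][] features; // one {x1, x2} pair per row
    final double[] labels;     // expected XOR output per row

    private XorCsv(double[][] features, double[] labels) {
        this.features = features;
        this.labels = labels;
    }

    static XorCsv parse(List<String> lines) {
        List<double[]> feats = new ArrayList<>();
        List<Double> labs = new ArrayList<>();
        for (String line : lines) {
            String row = line.trim();
            if (row.isEmpty()) continue; // skip blank lines
            String[] cols = row.split(",");
            if (cols.length != 3) {
                throw new IllegalArgumentException("expected 3 columns, got: " + line);
            }
            feats.add(new double[]{Double.parseDouble(cols[0].trim()), Double.parseDouble(cols[1].trim())});
            labs.add(Double.parseDouble(cols[2].trim()));
        }
        double[] labels = new double[labs.size()];
        for (int i = 0; i < labels.length; i++) labels[i] = labs.get(i);
        return new XorCsv(feats.toArray(new double[0][]), labels);
    }
}
```

With the file from this answer it could be called as XorCsv.parse(Files.readAllLines(Paths.get("./train.csv"))).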
I believe there is an XOR example right in their git repository!
The code is well documented; you can find the repository here: https://github.com/deeplearning4j/dl4j-0.4-examples.git