/*
 * Decompiled with CFR 0.152.
 */
package net.zomis.machlearn.neural;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.function.Consumer;
import net.zomis.machlearn.neural.LearningData;
import net.zomis.machlearn.neural.NeuralNetwork;
import net.zomis.machlearn.neural.Neuron;
import net.zomis.machlearn.neural.NeuronLayer;
import net.zomis.machlearn.neural.NeuronLink;

/**
 * Trains a {@link NeuralNetwork} with backpropagation.
 *
 * <p>Two modes are offered:
 * <ul>
 *   <li>full-batch gradient descent ({@link #backPropagationLearning}), which accumulates
 *       gradients over all examples before updating weights;</li>
 *   <li>stochastic gradient descent ({@link #stohasticGradientDescent}), which updates weights
 *       after every example.</li>
 * </ul>
 *
 * <p>The reported cost is the negated mean of {@link #logisticCost} over all examples and
 * output neurons.
 */
public class Backpropagation {
    private final double learningRate;
    /** Number of epochs (full passes over the examples) performed by batch learning. */
    private final int iterations;
    /** Progress is printed every {@code logRate} epochs; the default effectively disables logging. */
    private int logRate = Integer.MAX_VALUE;
    /** When set, batch learning throws if the cost increases (or becomes NaN) between epochs. */
    private boolean costFunctionCheck;

    /**
     * Creates a trainer.
     *
     * @param learningRate step size applied to every weight update; must be in (0, 1]
     * @param iterations   number of epochs run by {@link #backPropagationLearning};
     *                     values {@code <= 0} run no epochs
     * @throws IllegalArgumentException if {@code learningRate} is outside (0, 1] or is NaN
     */
    public Backpropagation(double learningRate, int iterations) {
        // Written in positive form so that NaN (for which every comparison is false) is rejected.
        if (!(learningRate > 0.0 && learningRate <= 1.0)) {
            throw new IllegalArgumentException("Learning rate must be in range (0..1]");
        }
        this.learningRate = learningRate;
        this.iterations = iterations;
    }

    /**
     * Runs full-batch learning after initialising every weight with
     * {@link #initializeRandom(Random)} backed by a fresh, unseeded {@link Random}.
     *
     * @param examples training examples; must not be empty
     * @param network  the network to train in place
     * @return {@code network}, after training
     */
    public NeuralNetwork backPropagationLearning(Collection<LearningData> examples, NeuralNetwork network) {
        return this.backPropagationLearning(examples, network, Backpropagation.initializeRandom(new Random()));
    }

    /**
     * Returns an initialiser that sets every link weight to a uniform value in [-0.25, 0.25).
     *
     * @param random source of randomness
     * @return the weight initialiser
     */
    public static Consumer<NeuralNetwork> initializeRandom(Random random) {
        return network -> network.links().forEach(it -> it.setWeight(random.nextDouble() / 2.0 - 0.25));
    }

    /**
     * Returns an initialiser that, for every non-input layer, sets each input-link weight to a
     * uniform value in [-epsilon, epsilon), with {@code epsilon = sqrt(6) / sqrt(in + out)}.
     * Here {@code in} and {@code out} are the {@code size()} of the previous and current layer.
     * This is the Glorot/Xavier uniform range.
     *
     * @param random source of randomness
     * @return the weight initialiser
     */
    public static Consumer<NeuralNetwork> initializeLayerSpecific(Random random) {
        return network -> {
            for (int l = 1; l < network.getLayerCount(); ++l) {
                NeuronLayer layer = network.getLayer(l);
                int layerIn = network.getLayer(l - 1).size();
                int layerOut = layer.size();
                double epsilon = Math.sqrt(6.0) / Math.sqrt(layerIn + layerOut);
                double width = epsilon * 2.0;
                for (Neuron neuron : layer.getNeurons()) {
                    for (NeuronLink link : neuron.getInputs()) {
                        link.setWeight(random.nextDouble() * width - epsilon);
                    }
                }
            }
        };
    }

    /**
     * Trains {@code network} with stochastic gradient descent: weights are updated after every
     * single example rather than once per epoch.
     *
     * <p>The examples are shuffled once, before the first epoch, into a private copy. The
     * caller's list is never reordered, so unmodifiable lists are accepted. The method keeps its
     * historical (misspelled) name for compatibility.
     *
     * @param examples              training examples; must not be empty
     * @param network               the network to train in place
     * @param weightsInitialization applied to the network before training; may be {@code null}
     *                              to keep the current weights
     * @param random                source of randomness for the shuffle
     * @param targetIterations      number of epochs to run; values {@code <= 0} run none
     * @return {@code network}, after training
     * @throws IllegalArgumentException if {@code examples} is empty
     */
    public NeuralNetwork stohasticGradientDescent(List<LearningData> examples, NeuralNetwork network, Consumer<NeuralNetwork> weightsInitialization, Random random, int targetIterations) {
        requireNonEmpty(examples);
        int[] layerSizes = layerSizes(network);
        if (weightsInitialization != null) {
            weightsInitialization.accept(network);
        }
        List<LearningData> shuffled = new ArrayList<>(examples);
        Collections.shuffle(shuffled, random);
        double[][] deltas = newDeltaBuffers(network);

        for (int epoch = 0; epoch < targetIterations; ++epoch) {
            int iteration = epoch + 1;
            Backpropagation.zero(deltas);
            double cost = 0.0;
            for (LearningData data : shuffled) {
                forwardPass(network, data);
                computeDeltas(network, data, deltas);
                // Apply this example's gradient immediately (the "stochastic" part).
                for (int l = 0; l < network.getLayerCount() - 1; ++l) {
                    NeuronLayer layer = network.getLayer(l);
                    NeuronLayer nextLayer = network.getLayer(l + 1);
                    for (int i = 0; i < nextLayer.getNeurons().size(); ++i) {
                        double delta = deltas[l][i];
                        Neuron neuron = nextLayer.getNeurons().get(i);
                        for (int j = 0; j < layer.getNeurons().size() + 1; ++j) {
                            NeuronLink link = neuron.getInputs().get(j);
                            double gradient = inputValue(layer, j) * delta;
                            link.setWeight(link.getWeight() - this.learningRate * gradient);
                        }
                    }
                }
                cost += Backpropagation.costFunction(network, data);
            }
            cost = -cost / (double) shuffled.size();
            if (iteration % this.logRate == 0) {
                System.out.printf("Stochastic BackPropagation %s iteration %d : cost %f%n", Arrays.toString(layerSizes), iteration, cost);
            }
        }
        return network;
    }

    /**
     * Trains {@code network} with full-batch gradient descent. Each epoch accumulates the
     * gradient over all examples, then updates every weight once by
     * {@code learningRate * gradient / examples.size()}. Exactly {@code iterations} epochs are run.
     *
     * @param examples              training examples; must not be empty
     * @param network               the network to train in place
     * @param weightsInitialization applied to the network before training; may be {@code null}
     *                              to keep the current weights
     * @return {@code network}, after training
     * @throws IllegalArgumentException if {@code examples} is empty
     * @throws IllegalStateException    if cost checking is enabled and the cost increases or
     *                                  becomes NaN between epochs
     */
    public NeuralNetwork backPropagationLearning(Collection<LearningData> examples, NeuralNetwork network, Consumer<NeuralNetwork> weightsInitialization) {
        // Without this guard, 1.0 / 0 * 0.0 == NaN would overwrite every weight.
        requireNonEmpty(examples);
        int[] layerSizes = layerSizes(network);
        if (weightsInitialization != null) {
            weightsInitialization.accept(network);
        }
        double[][] deltas = newDeltaBuffers(network);
        double[][][] capitalDeltas = newGradientBuffers(network);
        double scale = 1.0 / (double) examples.size();
        double previousCost = Double.MAX_VALUE;

        for (int epoch = 0; epoch < this.iterations; ++epoch) {
            int iteration = epoch + 1;
            Backpropagation.zero(deltas);
            Backpropagation.zero(capitalDeltas);
            double cost = 0.0;
            for (LearningData data : examples) {
                forwardPass(network, data);
                computeDeltas(network, data, deltas);
                for (int l = 0; l < network.getLayerCount() - 1; ++l) {
                    NeuronLayer layer = network.getLayer(l);
                    NeuronLayer nextLayer = network.getLayer(l + 1);
                    for (int i = 0; i < nextLayer.getNeurons().size(); ++i) {
                        double delta = deltas[l][i];
                        double[] gradients = capitalDeltas[l][i];
                        for (int j = 0; j < layer.getNeurons().size() + 1; ++j) {
                            gradients[j] += inputValue(layer, j) * delta;
                        }
                    }
                }
                cost += Backpropagation.costFunction(network, data);
            }

            double totalChange = 0.0;
            for (int l = 0; l < capitalDeltas.length; ++l) {
                NeuronLayer nextLayer = network.getLayer(l + 1);
                for (int i = 0; i < capitalDeltas[l].length; ++i) {
                    Neuron neuron = nextLayer.getNeurons().get(i);
                    for (int j = 0; j < capitalDeltas[l][i].length; ++j) {
                        NeuronLink link = neuron.getInputs().get(j);
                        double capitalD = scale * capitalDeltas[l][i][j];
                        totalChange += Math.abs(capitalD);
                        link.setWeight(link.getWeight() - this.learningRate * capitalD);
                    }
                }
            }

            cost = -cost / (double) examples.size();
            if (this.costFunctionCheck) {
                // NaN must fail explicitly: otherwise it becomes previousCost and every later
                // "cost > previousCost" comparison is false, silently disabling the check.
                if (Double.isNaN(cost) || cost > previousCost) {
                    throw new IllegalStateException("Cost function increased from " + previousCost + " to " + cost + ": Network not learning correctly.");
                }
                previousCost = cost;
            }
            if (iteration % this.logRate == 0) {
                System.out.printf("BackPropagation %s iteration %d : change %f, cost %f%n", Arrays.toString(layerSizes), iteration, totalChange, cost);
            }
        }
        return network;
    }

    /** Rejects an empty example set, for which the per-example averages are undefined. */
    private static void requireNonEmpty(Collection<LearningData> examples) {
        if (examples.isEmpty()) {
            throw new IllegalArgumentException("At least one learning example is required");
        }
    }

    /** Returns the {@code size()} of every layer, used only in progress messages. */
    private static int[] layerSizes(NeuralNetwork network) {
        return network.getLayers().stream().mapToInt(it -> it.size()).toArray();
    }

    /** Allocates one delta slot per neuron of every non-input layer (index 0 = layer 1). */
    private static double[][] newDeltaBuffers(NeuralNetwork network) {
        double[][] deltas = new double[network.getLayerCount() - 1][];
        for (int l = 0; l < deltas.length; ++l) {
            deltas[l] = new double[network.getLayer(l + 1).size()];
        }
        return deltas;
    }

    /** Allocates one gradient accumulator per input link of every non-input neuron. */
    private static double[][][] newGradientBuffers(NeuralNetwork network) {
        double[][][] gradients = new double[network.getLayerCount() - 1][][];
        for (int l = 0; l < gradients.length; ++l) {
            NeuronLayer layer = network.getLayer(l + 1);
            gradients[l] = new double[layer.size()][];
            for (int i = 0; i < layer.getNeurons().size(); ++i) {
                gradients[l][i] = new double[layer.getNeurons().get(i).getInputs().size()];
            }
        }
        return gradients;
    }

    /** Loads the example's inputs into the input layer, then processes every later layer in order. */
    private static void forwardPass(NeuralNetwork network, LearningData data) {
        int index = 0;
        for (Neuron neuron : network.getInputLayer()) {
            neuron.setOutput(data.getInput(index++));
        }
        for (int l = 1; l < network.getLayerCount(); ++l) {
            Iterator<Neuron> nodes = network.getLayer(l).iterator();
            while (nodes.hasNext()) {
                nodes.next().process();
            }
        }
    }

    /**
     * Fills {@code deltas} with the error term of every non-input neuron for the example most
     * recently passed through {@link #forwardPass}.
     *
     * <p>Output deltas are {@code (output - expected) * data.weight}. Hidden deltas are the
     * weighted sum of downstream deltas times {@code output * (1 - output)}, the logistic-sigmoid
     * derivative. NOTE(review): this assumes {@code Neuron.process} applies a sigmoid; confirm.
     */
    private static void computeDeltas(NeuralNetwork network, LearningData data, double[][] deltas) {
        double[] expectedOutput = data.getOutputs();
        double[] outputDeltas = deltas[network.getLayerCount() - 2];
        int index = 0;
        for (Neuron neuron : network.getOutputLayer()) {
            outputDeltas[index] = (neuron.getOutput() - expectedOutput[index]) * data.weight;
            ++index;
        }
        for (int l = network.getLayerCount() - 2; l >= 1; --l) {
            NeuronLayer layer = network.getLayer(l);
            double[] downstream = deltas[l];
            double[] current = deltas[l - 1];
            for (int n = 0; n < layer.size(); ++n) {
                Neuron neuron = layer.getNeurons().get(n);
                // DoubleStream.sum is kept (rather than a plain loop) to preserve its summation behaviour.
                double sum = neuron.getOutputs().stream()
                        .mapToDouble(link -> link.getWeight() * downstream[link.getTo().indexInLayer])
                        .sum();
                double output = neuron.getOutput();
                current[n] = sum * (output * (1.0 - output));
            }
        }
    }

    /**
     * Returns the value carried by input link {@code j} of a neuron fed by {@code layer}.
     * Link 0 is treated as a bias with constant input 1.0; link {@code j > 0} carries the
     * output of neuron {@code j - 1}.
     */
    private static double inputValue(NeuronLayer layer, int j) {
        return j == 0 ? 1.0 : layer.getNeurons().get(j - 1).getOutput();
    }

    /**
     * Numerically estimates d(cost)/d(weight) for one link with a central difference. The
     * link's weight is always restored, even if a cost evaluation throws.
     *
     * @return {@code (cost(w + e) - cost(w - e)) / (2e)} with {@code e = 1e-4}
     */
    static double gradientCheck(NeuronLink link, NeuralNetwork network, Collection<LearningData> datas) {
        final double epsilon = 1.0E-4;
        double originalWeight = link.getWeight();
        try {
            link.setWeight(originalWeight + epsilon);
            double costPlus = Backpropagation.costFunction(network, datas);
            link.setWeight(originalWeight - epsilon);
            double costMinus = Backpropagation.costFunction(network, datas);
            return (costPlus - costMinus) / (2.0 * epsilon);
        } finally {
            link.setWeight(originalWeight);
        }
    }

    /** Sets every element of a two-level array to 0.0. */
    static void zero(double[][] doubles) {
        for (double[] layer : doubles) {
            Arrays.fill(layer, 0.0);
        }
    }

    /** Sets every element of a three-level array to 0.0. */
    static void zero(double[][][] doubles) {
        for (double[][] layer : doubles) {
            Backpropagation.zero(layer);
        }
    }

    /**
     * Runs the network on every example and returns the negated mean per-example logistic cost.
     * For an empty collection the result is NaN.
     */
    static double costFunction(NeuralNetwork network, Collection<LearningData> datas) {
        double sum = 0.0;
        for (LearningData data : datas) {
            network.run(data.getInputs());
            sum += Backpropagation.costFunction(network, data);
        }
        return -1.0 / (double) datas.size() * sum;
    }

    /**
     * Sums {@link #logisticCost} over the output neurons, using their current outputs.
     * The network is not re-run.
     */
    static double costFunction(NeuralNetwork network, LearningData learningData) {
        double sum = 0.0;
        double[] out = learningData.getOutputs();
        NeuronLayer outputLayer = network.getOutputLayer();
        for (int i = 0; i < out.length; ++i) {
            sum += Backpropagation.logisticCost(out[i], outputLayer.getNeurons().get(i).getOutput());
        }
        return sum;
    }

    /**
     * Returns {@code y*log(a) + (1-y)*log(1-a)} for target {@code y} and activation {@code a}.
     * A term whose coefficient is exactly zero is skipped. Otherwise a saturated activation
     * (a == 0.0 or a == 1.0) would evaluate {@code 0 * -Infinity} and yield NaN, even for a
     * perfect prediction.
     */
    static double logisticCost(double expected, double actual) {
        double cost = 0.0;
        if (expected != 0.0) {
            cost += expected * Math.log(actual);
        }
        if (expected != 1.0) {
            cost += (1.0 - expected) * Math.log(1.0 - actual);
        }
        return cost;
    }

    /**
     * Prints progress every {@code logRate} epochs.
     *
     * @param logRate positive logging interval, in epochs
     * @return this trainer, for chaining
     * @throws IllegalArgumentException if {@code logRate <= 0}; zero would otherwise cause an
     *                                  {@code ArithmeticException} during training
     */
    public Backpropagation setLogRate(int logRate) {
        if (logRate <= 0) {
            throw new IllegalArgumentException("Log rate must be positive: " + logRate);
        }
        this.logRate = logRate;
        return this;
    }

    /** Returns the logging interval, in epochs. */
    public int getLogRate() {
        return this.logRate;
    }

    /**
     * Returns an initialiser that adds uniform noise in [-v, v) to every existing link weight.
     *
     * @param random source of randomness
     * @param v      half-width of the noise interval
     * @return the weight perturbation
     */
    public static Consumer<NeuralNetwork> initializeRandomOffset(Random random, double v) {
        return network -> network.links().forEach(it -> it.setWeight(it.getWeight() + random.nextDouble() * v * 2.0 - v));
    }

    /**
     * Enables or disables the per-epoch cost check in batch learning.
     *
     * @param costFunctionCheck when {@code true}, an increasing or NaN cost throws
     *                          {@link IllegalStateException}
     */
    public void setCostFunctionCheck(boolean costFunctionCheck) {
        this.costFunctionCheck = costFunctionCheck;
    }
}

