/*
 * Decompiled with CFR 0.152.
 */
package net.zomis.machlearn.neural;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import net.zomis.machlearn.neural.Neuron;
import net.zomis.machlearn.neural.NeuronLayer;
import net.zomis.machlearn.neural.NeuronLink;

/**
 * A network of {@link NeuronLayer}s kept in an ordered list.
 *
 * <p>Layer 0 is the input layer and the last layer is the output layer. {@link #run(double[])}
 * evaluates the layers in list order.
 */
public class NeuralNetwork {
    /** The layers in evaluation order; index 0 is the input layer. */
    public List<NeuronLayer> layers = new ArrayList<NeuronLayer>();

    /** Returns the live (mutable) list of layers. */
    public List<NeuronLayer> getLayers() {
        return this.layers;
    }

    /** Returns the first layer. Throws {@link IndexOutOfBoundsException} if there are no layers. */
    public NeuronLayer getInputLayer() {
        return this.getLayer(0);
    }

    /** Returns the last layer. Throws {@link IndexOutOfBoundsException} if there are no layers. */
    public NeuronLayer getOutputLayer() {
        return this.getLayer(this.layers.size() - 1);
    }

    /** Returns the layer at {@code layerIndex}. */
    public NeuronLayer getLayer(int layerIndex) {
        return this.layers.get(layerIndex);
    }

    /** Creates a new, empty layer with the given name and appends it after the current last layer. */
    public NeuronLayer createLayer(String name) {
        NeuronLayer layer = new NeuronLayer(name);
        this.layers.add(layer);
        return layer;
    }

    /** Returns the number of layers. */
    public int getLayerCount() {
        return this.layers.size();
    }

    /**
     * Streams the input links of every neuron in every layer except the input layer
     * (layer 0), in layer order.
     */
    public Stream<NeuronLink> links() {
        return this.layers.stream()
                .skip(1)
                .flatMap(layer -> layer.neurons.stream())
                .flatMap(neuron -> neuron.inputs.stream());
    }

    /** Prints the layer count, then each layer's nodes followed by a blank line, to stdout. */
    public void printAll() {
        System.out.println(this.getLayerCount() + " layers:");
        for (NeuronLayer layer : this.layers) {
            layer.printNodes();
            System.out.println();
        }
        System.out.println();
    }

    /**
     * Writes this network in the binary format read by {@link #load(InputStream)}.
     *
     * <p>Format (all values big-endian, as written by {@link DataOutputStream}):
     * <ol>
     *   <li>{@code int} layer count;</li>
     *   <li>for each layer: {@code int} neuron count, {@code int} name length, then the name
     *       as one byte per character;</li>
     *   <li>for each layer, for each neuron, for each input link: the {@code double} weight.</li>
     * </ol>
     *
     * <p>NOTE: {@link DataOutputStream#writeBytes(String)} keeps only the low 8 bits of each
     * character, so layer names containing characters above U+00FF are not preserved.
     *
     * <p>The given stream is closed when this method returns.
     *
     * @param output destination stream; closed by this method
     * @throws UncheckedIOException if writing fails
     */
    void save(OutputStream output) {
        try (DataOutputStream out = new DataOutputStream(output)) {
            out.writeInt(this.getLayerCount());
            for (NeuronLayer layer : this.layers) {
                out.writeInt(layer.size());
                out.writeInt(layer.name.length());
                out.writeBytes(layer.name);
            }
            for (NeuronLayer layer : this.layers) {
                for (Neuron neuron : layer) {
                    for (NeuronLink link : neuron.inputs) {
                        out.writeDouble(link.getWeight());
                    }
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Reads a network in the format written by {@link #save(OutputStream)}.
     *
     * <p>First all layers and their neurons are created. Then each layer after the first gets
     * its neurons linked to the previous layer via {@code addInputs}, and every neuron's input
     * link weights are read in layer order.
     *
     * <p>The given stream is closed when this method returns.
     *
     * @param input source stream; closed by this method
     * @return the reconstructed network
     * @throws UncheckedIOException if reading fails (including a truncated stream)
     */
    static NeuralNetwork load(InputStream input) {
        NeuralNetwork network = new NeuralNetwork();
        try (DataInputStream in = new DataInputStream(input)) {
            int layerCount = in.readInt();
            for (int i = 0; i < layerCount; i++) {
                int size = in.readInt();
                int nameLength = in.readInt();
                StringBuilder name = new StringBuilder();
                for (int c = 0; c < nameLength; c++) {
                    // readUnsignedByte (0..255), not readByte: casting a negative byte to char
                    // sign-extends it (0xE9 -> U+FFE9) and corrupts non-ASCII names.
                    name.append((char) in.readUnsignedByte());
                }
                NeuronLayer layer = network.createLayer(name.toString());
                for (int j = 0; j < size; j++) {
                    layer.createNeuron();
                }
            }
            for (int i = 0; i < layerCount; i++) {
                NeuronLayer layer = network.getLayer(i);
                if (i > 0) {
                    NeuronLayer previous = network.getLayer(i - 1);
                    layer.neurons.forEach(neuron -> neuron.addInputs(previous));
                }
                for (Neuron neuron : layer) {
                    for (NeuronLink link : neuron.inputs) {
                        link.setWeight(in.readDouble());
                    }
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return network;
    }

    /**
     * Feeds {@code input} into the network and returns the resulting output values.
     *
     * <p>Each input neuron's {@code output} is set to the matching input value. Then
     * {@code process()} is called on every neuron of every non-input layer, in layer order.
     * Finally the output layer's {@code output} values are collected.
     *
     * @param input one value per input-layer neuron
     * @return one value per output-layer neuron
     * @throws IllegalArgumentException if {@code input.length} differs from the input layer size
     */
    public double[] run(double[] input) {
        NeuronLayer inputLayer = this.getInputLayer();
        int inputSize = inputLayer.size();
        // Explicit check: an assert is disabled by default and would let a mismatch through.
        if (input.length != inputSize) {
            throw new IllegalArgumentException(
                    "Expected " + inputSize + " input values but got " + input.length);
        }
        for (int i = 0; i < inputSize; i++) {
            inputLayer.neurons.get(i).output = input[i];
        }
        for (int layerIndex = 1; layerIndex < this.layers.size(); layerIndex++) {
            for (Neuron node : this.layers.get(layerIndex)) {
                node.process();
            }
        }
        NeuronLayer outputLayer = this.getOutputLayer();
        double[] output = new double[outputLayer.size()];
        for (int i = 0; i < output.length; i++) {
            output[i] = outputLayer.neurons.get(i).output;
        }
        return output;
    }

    /** Same as {@link #getOutputLayer()}. */
    public NeuronLayer getLastLayer() {
        return this.getOutputLayer();
    }

    /**
     * Builds a network with the given layer sizes.
     *
     * <p>The first layer is named {@code "INPUT"}, the middle ones {@code "HIDDEN 1"},
     * {@code "HIDDEN 2"}, ..., and the last one {@code "OUTPUT"}. Every neuron in a non-input
     * layer gets its inputs from the previous layer via {@code addInputs}.
     *
     * @param layerSizes neuron count per layer, input first and output last
     * @return the new network
     * @throws IllegalArgumentException if fewer than two sizes are given
     */
    public static NeuralNetwork createNetwork(int... layerSizes) {
        if (layerSizes.length < 2) {
            throw new IllegalArgumentException("Network layers must be at least 2");
        }
        NeuralNetwork network = new NeuralNetwork();
        NeuronLayer previous = network.createLayer("INPUT");
        for (int i = 0; i < layerSizes[0]; i++) {
            previous.createNeuron();
        }
        int outputIndex = layerSizes.length - 1;
        for (int layerIndex = 1; layerIndex < outputIndex; layerIndex++) {
            previous = network.createLinkedLayer("HIDDEN " + layerIndex, layerSizes[layerIndex], previous);
        }
        network.createLinkedLayer("OUTPUT", layerSizes[outputIndex], previous);
        return network;
    }

    /** Appends a layer of {@code size} neurons, each taking its inputs from {@code parent}. */
    private NeuronLayer createLinkedLayer(String name, int size, NeuronLayer parent) {
        NeuronLayer layer = this.createLayer(name);
        for (int i = 0; i < size; i++) {
            layer.createNeuron().addInputs(parent);
        }
        return layer;
    }
}

