/*
 * Decompiled with CFR 0.152.
 */
package net.zomis.machlearn.images;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import net.zomis.machlearn.images.ImageAnalysis;
import net.zomis.machlearn.images.ImageNetwork;
import net.zomis.machlearn.neural.Backpropagation;
import net.zomis.machlearn.neural.LearningData;
import net.zomis.machlearn.neural.NeuralNetwork;
import net.zomis.machlearn.neural.NeuronLayer;

/**
 * Collects labelled example inputs and builds an {@link ImageNetwork} from them.
 *
 * <p>The constructor creates the input layer and any hidden layers of the underlying
 * {@link NeuralNetwork}. Examples are registered with {@link #classify(Object, double[])}
 * and {@link #classifyNone(double[])}; {@link #learn(Backpropagation, Random)} then adds
 * the output layer (one neuron per distinct non-null label) and runs the training.
 */
public class ImageNetworkBuilder {
    private final NeuralNetwork network;
    /** Example inputs grouped by label; the {@code null} key holds the "none" examples. */
    private final Map<Object, List<double[]>> classifications = new HashMap<>();
    private final ImageAnalysis analysis;

    /**
     * Creates the builder together with the input and hidden layers of its network.
     *
     * @param analysis         analysis object that is passed on unchanged to the resulting
     *                         {@link ImageNetwork}
     * @param inputSize        number of neurons created in the "INPUT" layer
     * @param hiddenLayerSizes neuron count of each hidden layer, in order; may be empty,
     *                         in which case no hidden layer is created
     */
    public ImageNetworkBuilder(ImageAnalysis analysis, int inputSize, int... hiddenLayerSizes) {
        this.analysis = analysis;
        this.network = new NeuralNetwork();
        NeuronLayer layer = this.network.createLayer("INPUT");
        for (int i = 0; i < inputSize; ++i) {
            layer.createNeuron();
        }
        // Hidden layers are named "HIDDEN 1", "HIDDEN 2", ...; every neuron of a hidden
        // layer receives the whole preceding layer as its inputs.
        int hiddenIndex = 1;
        for (int layerSize : hiddenLayerSizes) {
            NeuronLayer parentLayer = layer;
            layer = this.network.createLayer("HIDDEN " + hiddenIndex++);
            for (int i = 0; i < layerSize; ++i) {
                layer.createNeuron().addInputs(parentLayer);
            }
        }
    }

    /**
     * Registers one example input for the given label.
     *
     * <p>A {@code null} label is stored under the same key that
     * {@link #classifyNone(double[])} uses.
     *
     * @param result label the input belongs to
     * @param input  input values; the array is copied, so later changes by the caller
     *               do not affect the stored example
     * @return this builder, to allow chaining
     */
    public ImageNetworkBuilder classify(Object result, double[] input) {
        this.addExample(result, input);
        return this;
    }

    /**
     * Adds the output layer, trains the network on all registered examples and wraps
     * the result in an {@link ImageNetwork}.
     *
     * <p>One output neuron is created per distinct non-null label. The target vector of
     * an example has {@code 1.0} at the index assigned to its label and {@code 0.0}
     * elsewhere; examples registered without a label get an all-zero target vector.
     * Output indices are assigned in the iteration order of the internal
     * {@link HashMap}, which is not a specified order; the label array handed to the
     * {@link ImageNetwork} records which label belongs to which index.
     *
     * <p>NOTE(review): every call creates another "OUTPUT" layer on the same network,
     * so calling this more than once on one builder looks unintended - confirm.
     *
     * @param backprop learner used to train the network
     * @param random   random source for the weight initialization; when {@code null},
     *                 a new {@link Random} is used
     * @return an {@link ImageNetwork} built from the analysis, the network and the
     *         labels indexed by output position
     */
    public ImageNetwork learn(Backpropagation backprop, Random random) {
        // The null key collects the "none" examples and gets no output neuron of its own.
        int outputNodes = this.classifications.size();
        if (this.classifications.containsKey(null)) {
            --outputNodes;
        }

        NeuronLayer parentLayer = this.network.getLastLayer();
        NeuronLayer layer = this.network.createLayer("OUTPUT");
        for (int i = 0; i < outputNodes; ++i) {
            layer.createNeuron().addInputs(parentLayer);
        }

        List<LearningData> learningData = new ArrayList<>();
        Object[] objects = new Object[outputNodes];
        int outputIndex = 0;
        for (Map.Entry<Object, List<double[]>> entry : this.classifications.entrySet()) {
            // A fresh array starts as all zeros; it stays that way for the null key.
            double[] outputs = new double[outputNodes];
            if (entry.getKey() != null) {
                outputs[outputIndex] = 1.0;
                objects[outputIndex] = entry.getKey();
                ++outputIndex;
            }
            // All examples of one label share the same target array instance.
            learningData.addAll(entry.getValue().stream()
                    .map(inputs -> new LearningData(inputs, outputs))
                    .collect(Collectors.toList()));
        }

        Random weightRandom = random != null ? random : new Random();
        backprop.backPropagationLearning(learningData, this.network, Backpropagation.initializeRandom(weightRandom));
        return new ImageNetwork(this.analysis, this.network, objects);
    }

    /**
     * Registers one example input that belongs to none of the labels.
     *
     * @param input input values; the array is copied, so later changes by the caller
     *              do not affect the stored example
     * @return this builder, to allow chaining
     */
    public ImageNetworkBuilder classifyNone(double[] input) {
        this.addExample(null, input);
        return this;
    }

    /**
     * Returns the analysis object this builder was created with.
     *
     * @return the analysis passed to the constructor
     */
    public ImageAnalysis getAnalysis() {
        return this.analysis;
    }

    /** Stores a defensive copy of {@code input} under {@code key}; {@code key} may be null. */
    private void addExample(Object key, double[] input) {
        // computeIfAbsent only allocates a list the first time a key is seen.
        this.classifications
                .computeIfAbsent(key, ignored -> new ArrayList<>())
                .add(Arrays.copyOf(input, input.length));
    }
}

