/*
 * Decompiled with CFR 0.152.
 */
package net.zomis.machlearn.common;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.zomis.machlearn.common.ClassifierFunction;
import net.zomis.machlearn.common.PartitionedDataSet;
import net.zomis.machlearn.common.PrecisionRecallF1;
import net.zomis.machlearn.neural.LearningData;

/**
 * A mutable collection of {@link LearningData} examples. Each example pairs an input vector with an
 * output vector. The class also provides helpers to extract the inputs/outputs as arrays, score a
 * binary classifier, and split the examples into training, cross-validation and test sets.
 */
public class LearningDataSet {
    private final List<LearningData> data;

    /** Creates an empty data set. */
    public LearningDataSet() {
        this(new ArrayList<>());
    }

    /**
     * Creates a data set containing the given examples. The list itself is copied, so later
     * structural changes to {@code data} do not affect this data set. The contained
     * {@link LearningData} objects are shared, not copied.
     *
     * @param data the initial examples
     */
    public LearningDataSet(List<LearningData> data) {
        this.data = new ArrayList<>(data);
    }

    /**
     * Adds an example with a single output value.
     *
     * @param representation an arbitrary object describing the example (stored as-is)
     * @param x the input vector
     * @param y the single output value; it is wrapped into a one-element output vector
     */
    public void add(Object representation, double[] x, double y) {
        this.add(representation, x, new double[]{y});
    }

    /**
     * Adds an example.
     *
     * @param representation an arbitrary object describing the example (stored as-is)
     * @param x the input vector
     * @param y the output vector
     */
    public void add(Object representation, double[] x, double[] y) {
        this.data.add(new LearningData(representation, x, y));
    }

    /**
     * Returns the input vectors of all examples, one row per example, in insertion order.
     * The rows are the arrays returned by {@code LearningData.getInputs()}, not copies.
     *
     * @return a {@code double[size][]} matrix of inputs
     */
    public double[][] getXs() {
        return this.data.stream()
                .map(LearningData::getInputs)
                .toArray(double[][]::new);
    }

    /**
     * Returns the first output value of every example, in insertion order.
     *
     * @return one value per example
     */
    public double[] getY() {
        return this.data.stream().map(LearningData::getOutputs).mapToDouble(d -> d[0]).toArray();
    }

    /**
     * Returns the number of input features of the first example plus one.
     * The extra slot is presumably the bias/intercept term (feature 0) — confirm against callers.
     *
     * @return input length of the first example + 1
     * @throws IndexOutOfBoundsException if the data set is empty
     */
    public int numFeaturesWithZero() {
        return this.data.get(0).getInputs().length + 1;
    }

    /**
     * Scores a binary classifier against every example in this data set.
     * An example counts as actually positive when its first output value is {@code >= 0.5}.
     *
     * @param theta the parameter vector passed to the classifier
     * @param hypothesis the classifier producing a boolean prediction per example
     * @return the accumulated precision/recall/F1 statistics
     */
    public PrecisionRecallF1 precisionRecallF1(double[] theta, ClassifierFunction hypothesis) {
        PrecisionRecallF1 score = new PrecisionRecallF1();
        for (LearningData ld : this.data) {
            boolean prediction = hypothesis.classify(theta, ld.getInputs());
            boolean actual = ld.getOutputs()[0] >= 0.5;
            score.add(actual, prediction);
        }
        return score;
    }

    /**
     * Returns the live internal list of examples. Modifying it modifies this data set.
     *
     * @return the backing list
     */
    public List<LearningData> getData() {
        return this.data;
    }

    /**
     * Returns a sequential stream over the examples, backed by the internal list.
     *
     * @return a stream of the examples
     */
    public Stream<LearningData> stream() {
        return this.data.stream();
    }

    /**
     * Shuffles a copy of the examples and splits it into three consecutive parts.
     *
     * The ratios are relative weights and are normalized by their sum. Part sizes are truncated
     * toward zero, so any remainder from rounding goes to the test set. This data set is not
     * modified.
     *
     * @param trainingSetRatio weight of the training set, finite and {@code >= 0}
     * @param crossValidationSetRatio weight of the cross-validation set, finite and {@code >= 0}
     * @param testSetRatio weight of the test set, finite and {@code >= 0}
     * @param random source of randomness for the shuffle
     * @return the three partitions
     * @throws IllegalArgumentException if a ratio is negative, NaN or infinite, or if the ratios
     *     do not sum to a positive finite value
     */
    public PartitionedDataSet partition(double trainingSetRatio, double crossValidationSetRatio, double testSetRatio, Random random) {
        requireValidRatio("trainingSetRatio", trainingSetRatio);
        requireValidRatio("crossValidationSetRatio", crossValidationSetRatio);
        requireValidRatio("testSetRatio", testSetRatio);
        double sum = trainingSetRatio + crossValidationSetRatio + testSetRatio;
        if (!(sum > 0.0) || Double.isInfinite(sum)) {
            // All-zero weights would otherwise divide 0/0 and silently put everything in the test set.
            throw new IllegalArgumentException("ratios must sum to a positive finite value, got " + sum);
        }

        List<LearningData> shuffledData = new ArrayList<>(this.data);
        Collections.shuffle(shuffledData, random);
        int size = shuffledData.size();
        int indexSplit1 = (int) (trainingSetRatio / sum * size);
        // Clamp as a guard against floating-point rounding pushing the split past the end.
        int indexSplit2 = Math.min(size, indexSplit1 + (int) (crossValidationSetRatio / sum * size));

        List<LearningData> trainingSet = new ArrayList<>(shuffledData.subList(0, indexSplit1));
        List<LearningData> crossValidationSet = new ArrayList<>(shuffledData.subList(indexSplit1, indexSplit2));
        List<LearningData> testSet = new ArrayList<>(shuffledData.subList(indexSplit2, size));
        return new PartitionedDataSet(trainingSet, crossValidationSet, testSet);
    }

    /** Rejects a partition ratio that is negative, NaN or infinite. */
    private static void requireValidRatio(String name, double ratio) {
        if (!(ratio >= 0.0) || Double.isInfinite(ratio)) {
            throw new IllegalArgumentException(name + " must be finite and >= 0, got " + ratio);
        }
    }
}

