/*
 * Decompiled with CFR 0.152.
 */
package net.zomis.machlearn.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import net.zomis.machlearn.clustering.KMeansResult;

/**
 * K-means clustering over points given as {@code double[]} rows.
 *
 * <p>{@link #cluster(double[][], int, int, Random)} runs the clustering several times from
 * random starting centroids and keeps the run with the lowest total squared distance between
 * every point and the centroid of the cluster it was assigned to.
 */
public class KMeans {
    /**
     * Small demo: generates 12 pseudo-random 2-D points (seed 42), clusters them into two
     * groups using 100 restarts and prints the points and the cluster assignments.
     * The printed text looks like it is meant to be pasted into Octave/MATLAB (assumption
     * based on the output syntax only).
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        Random random = new Random(42L);
        double[][] inputs = new double[12][2];
        for (int i = 0; i < inputs.length; ++i) {
            inputs[i] = new double[]{random.nextDouble(), random.nextDouble()};
        }
        System.out.println("a = [");
        for (double[] row : inputs) {
            System.out.println(Arrays.toString(row) + ";");
        }
        System.out.println(']');
        KMeansResult clusters = KMeans.cluster(inputs, 2, 100, random);
        System.out.println("clusters = " + Arrays.toString(clusters.getClusters()) + ';');
        System.out.println("a(:,4) = clusters'");
    }

    /**
     * Clusters {@code inputs} into {@code clusterCount} groups, restarting from random initial
     * centroids {@code repetitions} times and returning the run with the lowest total cost
     * (sum of squared distances from each point to its assigned centroid). On equal cost the
     * earliest run wins.
     *
     * @param inputs       data points, one row per point; all rows must have the same length
     * @param clusterCount number of clusters; must be between 1 and {@code inputs.length}
     * @param repetitions  number of independent runs to perform
     * @param random       source of randomness used to pick the initial centroids
     * @return the best result found, or {@code null} if {@code repetitions} is not positive
     * @throws IllegalArgumentException if {@code inputs} is empty, if {@code clusterCount} is
     *         outside {@code [1, inputs.length]}, or if the rows differ in length
     */
    public static KMeansResult cluster(double[][] inputs, int clusterCount, int repetitions, Random random) {
        if (repetitions <= 0) {
            // No run is performed at all, so there is nothing to validate or return.
            return null;
        }
        if (inputs.length == 0) {
            throw new IllegalArgumentException("inputs must contain at least one data point");
        }
        if (clusterCount < 1 || clusterCount > inputs.length) {
            // More clusters than points would make the distinct-start selection loop forever.
            throw new IllegalArgumentException("clusterCount must be between 1 and " + inputs.length
                    + " (the number of inputs) but was " + clusterCount);
        }
        KMeansResult bestClusters = null;
        double bestCost = 0.0;
        for (int iteration = 0; iteration < repetitions; ++iteration) {
            KMeansResult result = performClustering(inputs, clusterCount, random);
            double totalCost = totalCost(inputs, result);
            if (bestClusters == null || totalCost < bestCost) {
                bestCost = totalCost;
                bestClusters = result;
            }
        }
        return bestClusters;
    }

    /** Sums the squared distance from every input to the centroid of its assigned cluster. */
    private static double totalCost(double[][] inputs, KMeansResult result) {
        int[] clusters = result.getClusters();
        double[][] centroids = result.getCentroids();
        double totalCost = 0.0;
        for (int i = 0; i < inputs.length; ++i) {
            totalCost += eucledianDistanceSquared(inputs[i], centroids[clusters[i]]);
        }
        return totalCost;
    }

    /**
     * Performs one clustering run: picks {@code clusterCount} distinct inputs as starting
     * centroids, then alternates between assigning points to their closest centroid and
     * moving the centroids until an assignment pass changes nothing.
     *
     * <p>The caller must guarantee {@code 1 <= clusterCount <= inputs.length}.
     */
    private static KMeansResult performClustering(double[][] inputs, int clusterCount, Random random) {
        int[] clusters = new int[inputs.length];
        // Start from "unassigned" so the very first assignment pass always counts as a change;
        // with a zero-filled array an all-zero first assignment would end the run prematurely.
        Arrays.fill(clusters, -1);
        double[][] centroids = new double[clusterCount][];
        int[] chosenInputs = new int[clusterCount];
        for (int i = 0; i < clusterCount; ++i) {
            int candidate;
            do {
                candidate = random.nextInt(inputs.length);
            } while (isTaken(chosenInputs, i, candidate));
            chosenInputs[i] = candidate;
            // Copy, so that the input row is never aliased by a centroid.
            centroids[i] = Arrays.copyOf(inputs[candidate], inputs[candidate].length);
        }
        boolean changed;
        do {
            changed = changeClusters(centroids, clusters, inputs);
            moveCentroids(centroids, clusters, inputs);
        } while (changed);
        return new KMeansResult(clusters, centroids);
    }

    /**
     * Moves every centroid to the mean of the inputs currently assigned to it. A centroid
     * whose cluster is empty is left where it is, because dividing by a member count of zero
     * would turn it into NaN and break all later distance comparisons.
     */
    private static void moveCentroids(double[][] centroids, int[] clusters, double[][] inputs) {
        double[][] sums = new double[centroids.length][inputs[0].length];
        int[] memberCounts = new int[centroids.length];
        for (int i = 0; i < inputs.length; ++i) {
            int cluster = clusters[i];
            ++memberCounts[cluster];
            for (int j = 0; j < inputs[i].length; ++j) {
                sums[cluster][j] += inputs[i][j];
            }
        }
        for (int c = 0; c < centroids.length; ++c) {
            if (memberCounts[c] == 0) {
                continue;
            }
            double[] mean = sums[c];
            for (int j = 0; j < mean.length; ++j) {
                mean[j] /= (double) memberCounts[c];
            }
            centroids[c] = mean;
        }
    }

    /**
     * Reassigns every input to its closest centroid, writing the result into {@code clusters}.
     *
     * @return {@code true} if at least one input ended up in a different cluster than before
     */
    private static boolean changeClusters(double[][] centroids, int[] clusters, double[][] inputs) {
        boolean changed = false;
        for (int i = 0; i < inputs.length; ++i) {
            int oldCluster = clusters[i];
            clusters[i] = findClosestCluster(inputs[i], centroids);
            changed |= oldCluster != clusters[i];
        }
        return changed;
    }

    /**
     * Returns the index of the centroid with the smallest squared distance to {@code input}.
     * On ties the lowest index wins.
     */
    private static int findClosestCluster(double[] input, double[][] centroids) {
        double minDistance = eucledianDistanceSquared(input, centroids[0]);
        int closestIndex = 0;
        for (int i = 1; i < centroids.length; ++i) {
            double distance = eucledianDistanceSquared(input, centroids[i]);
            if (distance < minDistance) {
                minDistance = distance;
                closestIndex = i;
            }
        }
        return closestIndex;
    }

    /**
     * Computes the squared Euclidean distance between two points, i.e. the sum of the squared
     * per-coordinate differences (no square root is taken).
     *
     * @param input    first point
     * @param centroid second point; must have the same length as {@code input}
     * @return the squared distance
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public static double eucledianDistanceSquared(double[] input, double[] centroid) {
        if (input.length != centroid.length) {
            throw new IllegalArgumentException("Values must be of same length. Input has length " + input.length + " while centroid has length " + centroid.length);
        }
        double sum = 0.0;
        for (int i = 0; i < input.length; ++i) {
            double diff = input[i] - centroid[i];
            sum += diff * diff;
        }
        return sum;
    }

    /**
     * Tells whether {@code current} already occurs among the first {@code upToIndex} entries
     * of {@code centroids} (entries at {@code upToIndex} and beyond are ignored).
     */
    private static boolean isTaken(int[] centroids, int upToIndex, int current) {
        for (int i = 0; i < upToIndex; ++i) {
            if (centroids[i] == current) {
                return true;
            }
        }
        return false;
    }
}

