/*
 * Decompiled with CFR 0.152.
 */
package net.zomis.machlearn.regression;

import java.util.Arrays;
import java.util.function.Predicate;
import net.zomis.machlearn.regression.ModelFunction;

public class GradientDescent {
    /** Forward-difference step used by {@link #partialDerivatives(ModelFunction, double[])}. */
    private static final double DEFAULT_STEP = 1.0E-7;

    /**
     * Approximates the partial derivatives of {@code function} at {@code x} using forward
     * differences with a step of {@code 1.0E-7}.
     *
     * @param function the function to differentiate; assumed deterministic and free of side
     *     effects on its argument
     * @param x the point at which to evaluate the derivatives; not modified
     * @return a new array whose element {@code i} approximates {@code df/dx[i]}
     * @throws IllegalArgumentException if {@code x} is {@code null} or empty
     */
    public static double[] partialDerivatives(ModelFunction function, double[] x) {
        return partialDerivatives(function, x, DEFAULT_STEP);
    }

    /**
     * Approximates the partial derivatives of {@code function} at {@code x} using forward
     * differences: {@code (f(x + step * e_i) - f(x)) / step} for each coordinate {@code i}.
     *
     * @param function the function to differentiate; assumed deterministic and free of side
     *     effects on its argument
     * @param x the point at which to evaluate the derivatives; not modified
     * @param step the finite-difference step; must be positive and finite
     * @return a new array whose element {@code i} approximates {@code df/dx[i]}
     * @throws IllegalArgumentException if {@code x} is {@code null} or empty, or if
     *     {@code step} is not a positive finite number
     */
    public static double[] partialDerivatives(ModelFunction function, double[] x, double step) {
        if (x == null || x.length == 0) {
            throw new IllegalArgumentException("Cannot calculate derivative without parameters");
        }
        // "!(step > 0)" also rejects NaN.
        if (!(step > 0.0) || Double.isInfinite(step)) {
            throw new IllegalArgumentException("Step must be positive and finite: " + step);
        }
        // f(x) is the same for every coordinate, so evaluate it once rather than per iteration.
        double fx = function.apply(x);
        double[] result = new double[x.length];
        // Work on a copy so the caller's array is never perturbed.
        double[] shifted = Arrays.copyOf(x, x.length);
        for (int i = 0; i < x.length; i++) {
            shifted[i] += step;
            double fxh = function.apply(shifted);
            // Restore the coordinate so only one dimension is perturbed at a time.
            shifted[i] = x[i];
            result[i] = (fxh - fx) / step;
        }
        return result;
    }

    /**
     * Minimizes {@code costFunction} by plain gradient descent:
     * {@code theta := theta - alpha * grad(theta)}, where the gradient is approximated with
     * {@link #partialDerivatives(ModelFunction, double[])}.
     *
     * <p>The loop runs until {@code convergenceCondition} returns {@code true}; there is no
     * iteration cap. The predicate is therefore responsible for guaranteeing termination,
     * including when the parameters diverge. The array passed to the predicate is an
     * internal buffer that is overwritten on later iterations. Copy it if you need to retain
     * it.
     *
     * @param costFunction the function to minimize
     * @param convergenceCondition tested before every step with the current parameters;
     *     descent stops as soon as it returns {@code true}
     * @param initialTheta the starting parameters; copied, never modified
     * @param alpha the learning rate
     * @return the parameters for which {@code convergenceCondition} first returned {@code true}
     * @throws IllegalArgumentException if {@code initialTheta} is empty and the condition is
     *     not immediately satisfied
     */
    public static double[] gradientDescent(ModelFunction costFunction, Predicate<double[]> convergenceCondition, double[] initialTheta, double alpha) {
        double[] theta = Arrays.copyOf(initialTheta, initialTheta.length);
        double[] next = new double[theta.length];
        while (!convergenceCondition.test(theta)) {
            double[] gradient = partialDerivatives(costFunction, theta);
            for (int i = 0; i < theta.length; i++) {
                next[i] = theta[i] - alpha * gradient[i];
            }
            // Double-buffer: swap the two arrays instead of allocating one per iteration.
            double[] previous = theta;
            theta = next;
            next = previous;
        }
        return theta;
    }
}

