/*
 * Decompiled with CFR 0.152.
 */
package net.zomis.machlearn.regressionvectorized;

import net.zomis.machlearn.regression.ModelFunction;
import net.zomis.machlearn.regressionvectorized.LinearRegression;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

public class LogisticRegression {
    public static double sigmoid(double z) {
        double n = 1.0 + Math.exp(-z);
        return 1.0 / n;
    }

    public static DoubleMatrix sigmoid(DoubleMatrix z) {
        DoubleMatrix n = DoubleMatrix.ones((int)z.rows, (int)z.columns).div(MatrixFunctions.expi((DoubleMatrix)z.neg()).add(1.0));
        return n;
    }

    public static double hypothesis(double[] theta, double[] x) {
        double thetaX = LinearRegression.linearHypothesis(theta, x);
        return LogisticRegression.sigmoid(thetaX);
    }

    public static DoubleMatrix hypothesis(DoubleMatrix T, DoubleMatrix X) {
        return LogisticRegression.sigmoid(X.mmul(T));
    }

    public static ModelFunction costFunction(final double[][] x, final double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException();
        }
        return new ModelFunction(){

            @Override
            public double apply(double[] theta) {
                DoubleMatrix X = DoubleMatrix.concatHorizontally((DoubleMatrix)DoubleMatrix.ones((int)x.length, (int)1), (DoubleMatrix)new DoubleMatrix(x));
                DoubleMatrix Y = new DoubleMatrix(y);
                DoubleMatrix T = new DoubleMatrix(theta);
                DoubleMatrix H = LogisticRegression.hypothesis(T, X);
                int m = x.length;
                DoubleMatrix result1 = Y.transpose().neg().mul(MatrixFunctions.log((DoubleMatrix)H));
                DoubleMatrix result2 = DoubleMatrix.ones((int)Y.columns, (int)Y.rows).sub(Y.transpose()).mul(MatrixFunctions.log((DoubleMatrix)DoubleMatrix.ones((int)H.rows, (int)H.columns).sub(H)));
                Double result = result1.sub(result2).sum() / (double)m;
                return result;
            }
        };
    }

    public static ModelFunction costFunctionOld(final double[][] x, final double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException();
        }
        return new ModelFunction(){

            @Override
            public double apply(double[] theta) {
                double sum = 0.0;
                for (int i = 0; i < x.length; ++i) {
                    double current;
                    double yValue = y[i];
                    double hypValue = LogisticRegression.hypothesis(theta, x[i]);
                    if (yValue == 1.0) {
                        current = -Math.log(hypValue);
                    } else if (yValue == 0.0) {
                        current = -Math.log(1.0 - hypValue);
                    } else {
                        throw new IllegalArgumentException("y must be either 0 or 1 but was " + yValue);
                    }
                    sum += current;
                }
                return sum / (double)x.length;
            }
        };
    }
}

