/*
 * Decompiled with CFR 0.152.
 */
package net.zomis.gameai;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import net.zomis.gameai.Feature;
import net.zomis.gameai.FeatureExtractor;
import net.zomis.gameai.FeatureExtractors;
import net.zomis.gameai.FeatureScaling;
import net.zomis.gameai.GameMove;
import net.zomis.gameai.TrainingData;
import net.zomis.gameai.features.IntegerFeature;
import net.zomis.machlearn.neural.Backpropagation;
import net.zomis.machlearn.neural.LearningData;
import net.zomis.machlearn.neural.NeuralNetwork;
import net.zomis.machlearn.neural.NeuronLayer;

/**
 * A game-playing agent that records feature vectors while games are played and
 * trains a neural network on them in order to score candidate moves.
 *
 * <p>Usage as visible from this class: register feature extractors with
 * {@code addFeatureExtractor}, call {@link #inform(Object)} with the current game
 * state before each move, call {@link #makeMove(Random, GameMove[])} to act, call
 * {@link #endGameWithScore(int)} when a game is over, and call {@link #learn()} to
 * (re)train the network from all finished games. Until {@link #learn()} has produced
 * a network, moves are chosen at random.
 *
 * <p>This class keeps mutable per-game state and performs no synchronization.
 */
public class GameAI {
    /** Training weight given to the first recorded sample of a game. */
    private static final double MIN_SAMPLE_WEIGHT = 0.6;
    /** Training weight given to the last recorded sample of a game. */
    private static final double MAX_SAMPLE_WEIGHT = 1.0;

    private final String name;
    /** Trained network, or {@code null} while no training has taken place. */
    private NeuralNetwork network;
    /** Feature extractors keyed by the class of object they extract from. */
    private final Map<Class<?>, FeatureExtractors<?>> extractors = new HashMap<>();
    /** Samples of every finished game, one inner list per game. */
    private final List<List<TrainingData>> featureValues = new ArrayList<>();
    /** Fixed seed so that weight initialization and training are reproducible. */
    private final Random random = new Random(42L);
    /** Samples of the game in progress, or {@code null} when no game is in progress. */
    private List<TrainingData> currentGame;

    /**
     * Creates an untrained AI.
     *
     * @param name the name returned by {@link #toString()}
     */
    public GameAI(String name) {
        this.name = name;
    }

    /**
     * Chooses one of the allowed moves, performs it and records it on the latest sample.
     *
     * <p>Without a trained network this delegates to
     * {@link #makeRandomMove(Random, GameMove[])}. Otherwise every allowed move is scored
     * with {@link #evaluateMove(GameMove[], int)} and the highest-scoring one is performed;
     * on equal scores the move with the lowest index wins.
     *
     * @param random source of randomness, only used while no network is available
     * @param moves  all candidate moves, allowed or not
     * @return the move that was performed
     * @throws IllegalStateException if no move is allowed, or if a network exists but no
     *                               state has been passed to {@link #inform(Object)}
     */
    public GameMove makeMove(Random random, GameMove[] moves) {
        if (this.network == null) {
            return this.makeRandomMove(random, moves);
        }
        GameMove bestMove = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestMoveIndex = -1;
        for (int i = 0; i < moves.length; ++i) {
            GameMove move = moves[i];
            if (!move.isAllowed()) {
                continue;
            }
            double score = this.evaluateMove(moves, i);
            // The first allowed move is always a candidate, so the choice does not
            // depend on the network's output staying above some fixed sentinel value.
            if (bestMove == null || score > bestScore) {
                bestMove = move;
                bestScore = score;
                bestMoveIndex = i;
            }
        }
        if (bestMove == null) {
            throw new IllegalStateException("No valid moves found");
        }
        bestMove.perform();
        this.storeAction(moves.length, bestMoveIndex);
        return bestMove;
    }

    /**
     * Scores a single move with the trained network.
     *
     * <p>The network input is the feature vector of the most recently informed state
     * followed by a one-hot encoding (length {@code moves.length}) of {@code index}.
     * The input and output are printed to standard output.
     *
     * @param moves all candidate moves; only the array length is used
     * @param index index of the move to score
     * @return the first output value of the network
     * @throws IllegalStateException          if no state has been passed to
     *                                        {@link #inform(Object)} for the current game
     * @throws ArrayIndexOutOfBoundsException if {@code index} is outside {@code moves}
     * @throws NullPointerException           if no network has been trained yet
     */
    public double evaluateMove(GameMove[] moves, int index) {
        if (this.currentGame == null || this.currentGame.isEmpty()) {
            throw new IllegalStateException("No game state available; call inform() before evaluating moves");
        }
        double[] oldX = this.currentGame.get(this.currentGame.size() - 1).getX();
        double[] moveEncoding = GameAI.oneHot(moves.length, index);
        double[] x = Arrays.copyOf(oldX, oldX.length + moveEncoding.length);
        System.arraycopy(moveEncoding, 0, x, oldX.length, moveEncoding.length);
        double[] output = this.network.run(x);
        System.out.println("RUN: " + index + " INPUT: " + Arrays.toString(x) + " OUTPUT: " + Arrays.toString(output));
        return output[0];
    }

    /**
     * Performs a uniformly random move among the allowed ones and records it on the
     * latest sample.
     *
     * @param random source of randomness
     * @param moves  all candidate moves, allowed or not
     * @return the move that was performed
     * @throws IllegalStateException if no move is allowed
     */
    public GameMove makeRandomMove(Random random, GameMove[] moves) {
        List<Integer> allowedIndices = new ArrayList<>();
        for (int i = 0; i < moves.length; ++i) {
            if (moves[i].isAllowed()) {
                allowedIndices.add(i);
            }
        }
        if (allowedIndices.isEmpty()) {
            throw new IllegalStateException("No move allowed");
        }
        // Draw from the allowed moves only; indexing the full array with this random
        // number could select a move that is not allowed.
        int moveIndex = allowedIndices.get(random.nextInt(allowedIndices.size()));
        GameMove action = moves[moveIndex];
        action.perform();
        this.storeAction(moves.length, moveIndex);
        return action;
    }

    /**
     * Appends a one-hot encoding of the chosen move to the latest sample of the current
     * game. Does nothing when no sample has been recorded.
     */
    private void storeAction(int moveCount, int moveIndex) {
        double[] moveEncoding = GameAI.oneHot(moveCount, moveIndex);
        if (this.currentGame != null && !this.currentGame.isEmpty()) {
            this.currentGame.get(this.currentGame.size() - 1).expandX(moveEncoding);
        }
    }

    /** Returns an array of {@code size} zeros with a single {@code 1.0} at {@code index}. */
    private static double[] oneHot(int size, int index) {
        double[] encoding = new double[size];
        encoding[index] = 1.0;
        return encoding;
    }

    /**
     * Records the current game state as a new sample, starting a new game if none is in
     * progress.
     *
     * <p>Extractors are looked up by the runtime class of {@code object}; when none are
     * registered for that exact class an empty extractor set is created for it.
     *
     * @param object the state to extract features from
     * @param <E>    the type of the state
     * @throws NullPointerException if {@code object} is {@code null}
     */
    public <E> void inform(E object) {
        @SuppressWarnings("unchecked") // getClass() of an E is a Class of E or of a subtype of E
        Class<E> clazz = (Class<E>) object.getClass();
        List<Feature<?>> features = this.extractorsFor(clazz).extract(object);
        int size = features.stream().mapToInt(Feature::getSize).sum();
        double[] values = new double[size];
        int next = 0;
        for (Feature<?> feature : features) {
            for (int f = 0; f < feature.getSize(); ++f) {
                values[next++] = feature.toDouble(f);
            }
        }
        // Start the game only once the sample exists, so that a failing extractor
        // cannot leave an empty game behind.
        this.initializeCurrentGame();
        this.currentGame.add(new TrainingData(values));
    }

    /** Creates the sample list for a new game unless a game is already in progress. */
    private void initializeCurrentGame() {
        if (this.currentGame == null) {
            this.currentGame = new ArrayList<>();
        }
    }

    /**
     * Ends the game in progress and keeps its samples for training.
     *
     * <p>Every sample gets {@code score} as its expected output. Sample weights rise
     * linearly from {@value #MIN_SAMPLE_WEIGHT} for the first sample to
     * {@value #MAX_SAMPLE_WEIGHT} for the last one; a game with a single sample gives
     * that sample the maximum weight. The samples are then passed to
     * {@code FeatureScaling.scale}. Does nothing when no game is in progress.
     *
     * @param score the final score of the game
     */
    public void endGameWithScore(int score) {
        if (this.currentGame == null) {
            return;
        }
        if (this.currentGame.isEmpty()) {
            this.currentGame = null;
            return;
        }
        int lastIndex = this.currentGame.size() - 1;
        // With a single sample there is no slope to compute; dividing by zero here would
        // turn the weight into NaN.
        double increase = lastIndex > 0 ? (MAX_SAMPLE_WEIGHT - MIN_SAMPLE_WEIGHT) / (double) lastIndex : 0.0;
        for (int i = 0; i <= lastIndex; ++i) {
            TrainingData data = this.currentGame.get(i);
            // A separate array per sample, so that changing one sample's expected output
            // can never affect the others.
            data.setY(new double[]{score});
            double weight = lastIndex > 0 ? (double) i * increase + MIN_SAMPLE_WEIGHT : MAX_SAMPLE_WEIGHT;
            data.setWeight(weight);
        }
        FeatureScaling.scale(this.currentGame);
        this.featureValues.add(this.currentGame);
        this.currentGame = null;
    }

    /**
     * Trains a new network on the samples of all finished games and makes it the network
     * used by {@link #makeMove(Random, GameMove[])}.
     *
     * <p>The network has as many inputs as the longest recorded sample, two hidden layers
     * of one half and one third of that size (rounded up) and a single output. When no
     * samples have been recorded this method returns without changing the current network.
     */
    public void learn() {
        List<LearningData> data = this.featureValues.stream()
                .flatMap(Collection::stream)
                .map(td -> new LearningData(td.getX(), td.getY()))
                .collect(Collectors.toList());
        if (data.isEmpty()) {
            // Nothing to learn from; a network with zero inputs would be unusable.
            return;
        }
        int inputs = this.featureValues.stream()
                .flatMap(Collection::stream)
                .mapToInt(td -> td.getX().length)
                .max()
                .orElse(0);
        int hidden1 = (int) Math.ceil((double) inputs / 2.0);
        int hidden2 = (int) Math.ceil((double) inputs / 3.0);
        NeuralNetwork nn = NeuralNetwork.createNetwork(inputs, hidden1, hidden2, 1);
        // NOTE(review): the meaning of these training parameters is defined by
        // Backpropagation, which is not visible here; the values are kept as they were.
        Backpropagation backprop = new Backpropagation(0.1, 100000);
        backprop.setLogRate(10);
        int[] layers = nn.layers.stream().mapToInt(NeuronLayer::size).toArray();
        System.out.println("Learning using " + data.size() + " training examples. " + "Layer sizes are " + Arrays.toString(layers));
        Consumer<NeuralNetwork> initialization = Backpropagation.initializeRandom(this.random);
        backprop.stohasticGradientDescent(data, nn, initialization, this.random, 100);
        this.network = nn;
    }

    /**
     * Registers a feature computed from a state object by {@code valueRetriever}.
     *
     * <p>Only {@code Integer} features are supported; for any other
     * {@code featureClass} this method does nothing.
     *
     * @param clazz          the class of state objects the feature applies to
     * @param name           the name of the feature
     * @param featureClass   the type of value the feature produces
     * @param valueRetriever computes the feature value from a state object
     * @param <E>            the type of the state objects
     * @param <F>            the type of the feature value
     */
    public <E, F> void addFeatureExtractor(Class<E> clazz, String name, Class<F> featureClass, Function<E, F> valueRetriever) {
        if (featureClass == Integer.class) {
            // NOTE(review): the trailing arguments (6, false) are IntegerFeature settings
            // whose meaning is not visible in this file; they are kept unchanged.
            this.addFeatureExtractor(clazz, new IntegerFeature<E>(name, e -> (Integer) valueRetriever.apply(e), 6, false));
        }
    }

    /** Returns the name this AI was created with. */
    @Override
    public String toString() {
        return this.name;
    }

    /**
     * Registers a feature extractor for state objects of the given class.
     *
     * @param clazz   the class of state objects the extractor applies to
     * @param feature the extractor to register
     * @param <E>     the type of the state objects
     */
    public <E> void addFeatureExtractor(Class<E> clazz, FeatureExtractor<E> feature) {
        this.extractorsFor(clazz).add(feature);
    }

    /**
     * Returns the extractor set registered for {@code clazz}, creating it on first use.
     */
    @SuppressWarnings("unchecked") // entries are only created here, typed by their own key
    private <E> FeatureExtractors<E> extractorsFor(Class<E> clazz) {
        return (FeatureExtractors<E>) this.extractors.computeIfAbsent(clazz, key -> new FeatureExtractors<E>(clazz));
    }
}

