/*
 * Decompiled with CFR 0.152.
 */
package uk.co.cablepost.bodkin_boats.ai.train_find_racing_line.bills;

import java.util.ArrayList;
import java.util.List;
import javax.annotation.concurrent.GuardedBy;
import net.fabricmc.fabric.api.client.rendering.v1.WorldRenderContext;
import net.minecraft.class_1799;
import net.minecraft.class_1802;
import net.minecraft.class_1935;
import net.minecraft.class_1937;
import net.minecraft.class_2338;
import net.minecraft.class_241;
import net.minecraft.class_4587;
import net.minecraft.server.MinecraftServer;
import uk.co.cablepost.bodkin_boats.BodkinBoatsClient;
import uk.co.cablepost.bodkin_boats.ai.basic_nerual_network.NeuralNetwork;
import uk.co.cablepost.bodkin_boats.ai.train_find_racing_line.bills.BillsRacingLineScoringFunc;
import uk.co.cablepost.bodkin_boats.ai.world_data.GetWorldData;
import uk.co.cablepost.bodkin_boats.ai.world_data.GetWorldDataCache;

/**
 * Searches for a racing line by evolving a {@link NeuralNetwork} that steers a chain of evenly
 * spaced waypoints, appending one waypoint at a time.
 *
 * <p>Training runs on {@link #trainingThread}, which the constructor starts. Each episode copies
 * the best network found so far and mutates the copy. The exception is the first episode of a
 * phase, which scores the unmutated network as a baseline. The episode predicts a line with the
 * copy and keeps the copy only if the line scores strictly higher than the current best.
 *
 * <p>Training runs a fixed sequence of phases; each phase uses its own scoring function and line
 * length and ends when its target score is reached or the thread is interrupted.
 *
 * <p>The best line so far is published via {@link #setBestPredictedLine(List)} and drawn by
 * {@link #onRender(WorldRenderContext)}.
 */
public class TrainFindRacingLineBills {
    /**
     * Number of network inputs: every cell of two minimaps followed by one rotation value.
     * 243 = 2 * 11 * 11 + 1, which fits two 11x11 minimaps if the {@code 5} passed to
     * {@code GetWorldData.getMinimap} is a radius (TODO confirm).
     */
    private static final int INPUT_SIZE = 243;
    /** Distance between consecutive waypoints, in world (block) units. */
    private static final float DISTANCE_BETWEEN_WAYPOINTS = 0.5f;
    /** Sentinel for "nothing scored yet in this phase"; equal to -3.4028235E38f. */
    private static final float NO_SCORE = -Float.MAX_VALUE;
    /** Half-width, in blocks, of the square around the start that is passed to GetWorldDataCache. */
    private static final int WORLD_DATA_RADIUS = 500;
    /** Each 1.0 of network output away from 0.5 turns the line by this many multiples of pi radians. */
    private static final float TURN_SCALE = 0.1f;
    /** Self-intersection check: how many earlier waypoints are inspected at most. */
    private static final int SELF_INTERSECTION_LOOKBACK = 1000;
    /** Self-intersection check: the most recent waypoints that are never compared against. */
    private static final int SELF_INTERSECTION_MIN_GAP = 50;
    /** Self-intersection check: only every Nth earlier waypoint is compared. */
    private static final int SELF_INTERSECTION_STRIDE = 20;
    /** Self-intersection check: threshold compared against {@code method_35589} (TODO confirm whether this is a squared distance). */
    private static final float SELF_INTERSECTION_THRESHOLD = 3.0f;
    /** Penalty subtracted for every earlier waypoint that is too close. */
    private static final float SELF_INTERSECTION_PENALTY = 10000.0f;

    private final int yLevel;
    private final float startX;
    private final float startZ;
    /** Starting heading in degrees, stored with a +90 degree offset relative to the constructor argument. */
    private final float startingRot;
    private final GetWorldDataCache getWorldDataCache;
    /** Best network so far. Not synchronized: intended to be touched only by the training thread. */
    private NeuralNetwork bestNeuralNetwork;
    /** Best score in the current phase. Not synchronized: intended to be touched only by the training thread. */
    private float bestScore = NO_SCORE;
    /**
     * Best line found so far. NOTE(review): this field is public, so direct field access bypasses
     * the lock; prefer {@link #getBestPredictedLine()} / {@link #setBestPredictedLine(List)}.
     */
    @GuardedBy(value = "this")
    public List<class_241> bestPredictedLine = new ArrayList<class_241>();
    /** Thread running {@link #train()}; interrupting it stops training at the next episode boundary. */
    public Thread trainingThread;

    /**
     * Creates the trainer and immediately starts training on a new background thread.
     *
     * @param world world the track is read from
     * @param startingPos block the line starts from (the line starts at its centre); its Y level
     *     is used for every world lookup
     * @param startingRot starting heading in degrees
     */
    public TrainFindRacingLineBills(class_1937 world, class_2338 startingPos, float startingRot) {
        this.yLevel = startingPos.method_10264();
        this.startX = (float)startingPos.method_10263() + 0.5f;
        this.startZ = (float)startingPos.method_10260() + 0.5f;
        this.startingRot = startingRot + 90.0f;
        int centreX = (int)Math.floor(this.startX);
        int centreZ = (int)Math.floor(this.startZ);
        this.getWorldDataCache = new GetWorldDataCache(world,
                centreX - WORLD_DATA_RADIUS, centreZ - WORLD_DATA_RADIUS,
                centreX + WORLD_DATA_RADIUS, centreZ + WORLD_DATA_RADIUS,
                this.yLevel);
        // NOTE(review): presumably two hidden layers of 32 and a single output; confirm in NeuralNetwork.
        this.bestNeuralNetwork = new NeuralNetwork(INPUT_SIZE, new int[]{32, 32}, 1);
        // Clear the debug minimap so getInputs() publishes one from this run.
        BodkinBoatsClient.AI_MINIMAP_DEBUG = null;
        // Every field is assigned above; Thread.start() makes those writes visible to the new thread.
        this.trainingThread = new Thread(this::train, "TrainFindRacingLineBills");
        this.trainingThread.start();
    }

    /** Tick hook; currently a no-op because training runs on {@link #trainingThread}. */
    public void onTick(MinecraftServer server) {
    }

    /**
     * Draws the current best line as a trail of item models, one per waypoint, cycling through
     * three item types. Does nothing if there is no matrix stack or no line.
     */
    public void onRender(WorldRenderContext worldRenderContext) {
        class_4587 matrixStack = worldRenderContext.matrixStack();
        List<class_241> line = this.getBestPredictedLine();
        if (matrixStack == null || line == null) {
            return;
        }
        class_1799[] itemStacks = new class_1799[]{
                new class_1799((class_1935)class_1802.field_8353),
                new class_1799((class_1935)class_1802.field_8672),
                new class_1799((class_1935)class_1802.field_8455)};
        int i = 0;
        for (class_241 pos : line) {
            // Paired method_22903/method_22909 calls (presumably push/pop) bracket the per-waypoint offset.
            matrixStack.method_22903();
            matrixStack.method_46416(pos.field_1343 - 0.5f, (float)this.yLevel - 0.6f, pos.field_1342 - 0.5f);
            BodkinBoatsClient.renderItem(worldRenderContext, new class_2338(0, 0, 0), itemStacks[i % itemStacks.length], 1.5f, 0.0f, true);
            matrixStack.method_22909();
            ++i;
        }
    }

    /**
     * Returns the most recently published best line. The list is not copied; callers must not
     * modify it. Each training episode builds a new list, so a published list is not changed by
     * the training thread afterwards.
     */
    public synchronized List<class_241> getBestPredictedLine() {
        return this.bestPredictedLine;
    }

    /** Publishes {@code bestPredictedLine} as the best line; readers see it via {@link #getBestPredictedLine()}. */
    public synchronized void setBestPredictedLine(List<class_241> bestPredictedLine) {
        this.bestPredictedLine = bestPredictedLine;
    }

    /**
     * Builds the network input vector for a waypoint.
     *
     * <p>The vector holds every cell of two minimaps sampled around (x, z), copied row by row,
     * followed by {@code lastRelativeRotation}.
     *
     * @param x waypoint X coordinate
     * @param z waypoint Z coordinate
     * @param rotation heading in radians; converted to degrees (plus 90) for the minimap lookup
     * @param lastRelativeRotation last turn, already shifted into roughly [0, 1]
     * @return an array of length {@link #INPUT_SIZE}
     */
    private double[] getInputs(float x, float z, float rotation, float lastRelativeRotation) {
        class_2338 centre = new class_2338((int)Math.floor(x), this.yLevel, (int)Math.floor(z));
        float minimapRotationDegrees = (float)Math.toDegrees(rotation) + 90.0f;
        // NOTE(review): the meaning of the trailing (5, 0) / (5, 7) arguments is defined by GetWorldData.
        float[][] minimap = GetWorldData.getMinimap(this.getWorldDataCache, centre, minimapRotationDegrees, 5, 0);
        float[][] minimap2 = GetWorldData.getMinimap(this.getWorldDataCache, centre, minimapRotationDegrees, 5, 7);
        if (BodkinBoatsClient.AI_MINIMAP_DEBUG == null) {
            // Publish the first second-minimap computed after the constructor cleared the debug slot.
            BodkinBoatsClient.AI_MINIMAP_DEBUG = minimap2;
        }
        double[] inputs = new double[INPUT_SIZE];
        int next = appendCells(minimap, inputs, 0);
        next = appendCells(minimap2, inputs, next);
        inputs[next] = lastRelativeRotation;
        return inputs;
    }

    /**
     * Copies every cell of {@code grid}, row by row, into {@code target} starting at {@code offset}.
     *
     * @return the index just after the last cell written
     */
    private static int appendCells(float[][] grid, double[] target, int offset) {
        int index = offset;
        for (float[] row : grid) {
            for (float cell : row) {
                target[index++] = cell;
            }
        }
        return index;
    }

    /**
     * Appends one waypoint to {@code line}, steered by {@code neuralNetwork}.
     * Requires at least three points in {@code line}; {@link #predictLine} seeds three.
     */
    private void predictNext(List<class_241> line, NeuralNetwork neuralNetwork) {
        class_241 last = line.getLast();
        // Direction from the last waypoint back towards the previous one.
        float rotation = TrainFindRacingLineBills.getRadAngleFromAToB(last, line.get(line.size() - 2));
        // getRelativeRotation is roughly in [-0.5, 0.5]; shift it into [0, 1] for the network.
        float lastRelativeRotation = TrainFindRacingLineBills.getRelativeRotation(line, line.size() - 1) + 0.5f;
        double[] nnInputs = this.getInputs(last.field_1343, last.field_1342, rotation, lastRelativeRotation);
        double[] nnOutputs = neuralNetwork.calc(nnInputs);
        // An output of 0.5 goes straight; each 1.0 away from it turns by TURN_SCALE * pi radians.
        float predictedAngle = (float)((nnOutputs[0] - 0.5) * Math.PI * TURN_SCALE);
        // `rotation` points backwards, so add pi to continue forwards.
        line.add(TrainFindRacingLineBills.getPosAtAngle(last, rotation + predictedAngle + (float)Math.PI, DISTANCE_BETWEEN_WAYPOINTS));
    }

    /**
     * Predicts a racing line with {@code neuralNetwork}.
     *
     * @param length number of waypoints to predict after the three seed points
     * @param neuralNetwork network that steers each new waypoint
     * @return a new mutable list. It holds two seed points behind the start (two spacings and one
     *     spacing back along the starting heading), then the start itself, then {@code length}
     *     predicted waypoints (none if {@code length <= 0}).
     */
    public List<class_241> predictLine(int length, NeuralNetwork neuralNetwork) {
        class_241 start = new class_241(this.startX, this.startZ);
        float behindStart = (float)Math.toRadians(this.startingRot + 180.0f);
        ArrayList<class_241> line = new ArrayList<class_241>();
        line.add(TrainFindRacingLineBills.getPosAtAngle(start, behindStart, 2.0f * DISTANCE_BETWEEN_WAYPOINTS));
        line.add(TrainFindRacingLineBills.getPosAtAngle(start, behindStart, DISTANCE_BETWEEN_WAYPOINTS));
        line.add(start);
        for (int i = 0; i < length; ++i) {
            this.predictNext(line, neuralNetwork);
        }
        return line;
    }

    /**
     * Wraps {@code angle} (radians) into approximately [-pi, pi] by repeatedly adding or
     * subtracting 2*pi. NaN is returned unchanged.
     */
    private static float wrapRadAngle(float angle) {
        while ((double)angle < -Math.PI) {
            angle += (float)Math.PI * 2;
        }
        while ((double)angle > Math.PI) {
            angle -= (float)Math.PI * 2;
        }
        return angle;
    }

    /** Returns the point {@code distance} away from {@code pos} at {@code angleRad} (cos applied to X, sin to Z). */
    private static class_241 getPosAtAngle(class_241 pos, float angleRad, float distance) {
        return new class_241(pos.field_1343 + (float)(Math.cos(angleRad) * (double)distance), pos.field_1342 + (float)(Math.sin(angleRad) * (double)distance));
    }

    /** Returns the direction from {@code a} to {@code b} in radians, as computed by {@link Math#atan2}. */
    private static float getRadAngleFromAToB(class_241 a, class_241 b) {
        float dx = b.field_1343 - a.field_1343;
        float dz = b.field_1342 - a.field_1342;
        return (float)Math.atan2(dz, dx);
    }

    /**
     * Returns the change of direction at waypoint {@code i}. The wrapped difference between
     * segments (i, i-1) and (i-1, i-2) is divided by 2*pi, giving roughly [-0.5, 0.5].
     * For {@code i < 2} it prints a warning and returns 0.
     */
    private static float getRelativeRotation(List<class_241> line, int i) {
        if (i < 2) {
            System.out.println("Warning - Tried to calculate getRelativeRotation on an index lower than 2");
            return 0.0f;
        }
        float rotationCurrent = TrainFindRacingLineBills.getRadAngleFromAToB(line.get(i), line.get(i - 1));
        float rotationPrevious = TrainFindRacingLineBills.getRadAngleFromAToB(line.get(i - 1), line.get(i - 2));
        float angleDifferenceWrapped = TrainFindRacingLineBills.wrapRadAngle(rotationCurrent - rotationPrevious);
        return angleDifferenceWrapped / ((float)Math.PI * 2);
    }

    /**
     * Runs one training episode.
     *
     * <p>If a score already exists in this phase, a mutated copy of the best network is scored;
     * otherwise the unmutated network is scored as the phase baseline. A strictly higher score
     * publishes the line and records the score. For a non-baseline episode it also replaces the
     * best network.
     *
     * @param lineLength number of waypoints to predict
     * @param scoreFunc scores the predicted line; higher is better
     */
    public void runEpisode(int lineLength, BillsRacingLineScoringFunc scoreFunc) {
        boolean hasBaseline = this.bestScore > NO_SCORE;
        NeuralNetwork candidate = new NeuralNetwork(this.bestNeuralNetwork);
        if (hasBaseline) {
            candidate.mutate();
        }
        List<class_241> line = this.predictLine(lineLength, candidate);
        float score = scoreFunc.score(line);
        // Written as `>` (not a negated `<=`) so a NaN score is never accepted.
        if (score > this.bestScore) {
            this.setBestPredictedLine(line);
            if (hasBaseline) {
                System.out.println("New best score: " + score);
                this.bestNeuralNetwork = candidate;
            } else {
                System.out.println("Baseline score: " + score);
            }
            this.bestScore = score;
        }
    }

    /**
     * Runs every training phase in order. Each phase ends when its target score is reached or
     * the current thread is interrupted. Once interrupted, the remaining phases are skipped.
     */
    public void train() {
        System.out.println("Started training...");
        this.runPhase("PHASE 1 - Learning to make a straight line", -0.15f, 10, TrainFindRacingLineBills::scorePhase1);
        this.runPhase("PHASE 2 - Learning to stay in the middle of the track (short)", -65.0f, 100, this::scorePhase2);
        this.runPhase("PHASE 2.1 - Learning to stay in the middle of the track (medium)", -800.0f, 1000, this::scorePhase2);
        this.runPhase("PHASE 2.2 - Learning to stay in the middle of the track (long)", -1450.0f, 2000, this::scorePhase2);
        this.runPhase("PHASE 2.3 - Learning to stay in the middle of the track (mega)", -3000.0f, 10000, this::scorePhase2);
        this.runPhase("PHASE 3 - Learning to stay away from the edge of the track, and smooth", 1000000.0f, 10000, this::scorePhase3);
        System.out.println("Training interrupted");
    }

    /**
     * Resets the phase score, prints {@code banner}, then runs episodes until the best score
     * reaches {@code targetScore} or the current thread is interrupted.
     */
    private void runPhase(String banner, float targetScore, int lineLength, BillsRacingLineScoringFunc scoreFunc) {
        this.bestScore = NO_SCORE;
        System.out.println(banner);
        while (this.bestScore < targetScore && !Thread.currentThread().isInterrupted()) {
            this.runEpisode(lineLength, scoreFunc);
        }
    }

    /** Phase 1 score: penalises every turn, normalised by line length (0 for a perfectly straight line). */
    private static float scorePhase1(List<class_241> line) {
        float score = 0.0f;
        for (int i = 2; i < line.size(); ++i) {
            float lastRelativeRotation = TrainFindRacingLineBills.getRelativeRotation(line, i);
            score -= Math.abs(lastRelativeRotation) * 20.0f / (float)line.size();
        }
        return score;
    }

    /**
     * Phase 2 score. Each waypoint is penalised for turning and for the largest
     * {@code GetWorldData.raycast01} value over 12 directions. A waypoint where that value is at
     * least 0.99 gets an extra 1000 penalty. Self-intersection is also penalised.
     * NOTE(review): the name {@code closestTrackEdge} assumes raycast01 grows as the edge gets
     * closer; confirm.
     */
    private float scorePhase2(List<class_241> line) {
        float score = 0.0f;
        for (int i = 2; i < line.size(); ++i) {
            float lastRelativeRotation = TrainFindRacingLineBills.getRelativeRotation(line, i);
            score -= Math.abs(lastRelativeRotation) * 2.0f;
            class_2338 blockPos = this.blockPosAt(line.get(i));
            float closestTrackEdge = 0.0f;
            for (int angle = 0; angle < 360; angle += 30) {
                float trackEdge = GetWorldData.raycast01(this.getWorldDataCache, blockPos, angle, 16);
                if (trackEdge > closestTrackEdge) {
                    closestTrackEdge = trackEdge;
                }
            }
            score -= closestTrackEdge;
            if (closestTrackEdge >= 0.99f) {
                score -= 1000.0f;
            }
            score = TrainFindRacingLineBills.penaliseSelfIntersection(line, i, score);
        }
        return score;
    }

    /**
     * Phase 3 score:
     * <ul>
     *   <li>a heavy penalty for turns sharper than 0.25 (a quarter turn);</li>
     *   <li>a squared penalty on the change in turn between consecutive waypoints;</li>
     *   <li>a bonus of slip * 10 per waypoint, plus a 1000 penalty where slip is below 0.7;</li>
     *   <li>the self-intersection penalty.</li>
     * </ul>
     */
    private float scorePhase3(List<class_241> line) {
        float score = 0.0f;
        float previousRelativeRotation = 0.0f;
        for (int i = 2; i < line.size(); ++i) {
            float lastRelativeRotation = TrainFindRacingLineBills.getRelativeRotation(line, i);
            if (Math.abs(lastRelativeRotation) > 0.25f) {
                score -= 10000.0f;
            }
            if (i == 2) {
                // The first scored waypoint has no predecessor turn, so its smoothness penalty is zero.
                previousRelativeRotation = lastRelativeRotation;
            }
            score -= (float)Math.pow(Math.abs(lastRelativeRotation - previousRelativeRotation) * 10.0f, 2.0);
            previousRelativeRotation = lastRelativeRotation;
            float slip = this.getWorldDataCache.getBlockSlip(this.blockPosAt(line.get(i)));
            score += slip * 10.0f;
            if (slip < 0.7f) {
                score -= 1000.0f;
            }
            score = TrainFindRacingLineBills.penaliseSelfIntersection(line, i, score);
        }
        return score;
    }

    /** Returns the block position containing {@code point} at this trainer's Y level. */
    private class_2338 blockPosAt(class_241 point) {
        return new class_2338((int)Math.floor(point.field_1343), this.yLevel, (int)Math.floor(point.field_1342));
    }

    /**
     * Subtracts {@link #SELF_INTERSECTION_PENALTY} from {@code score} once for every sampled
     * earlier waypoint that is too close to waypoint {@code i}. Sampled waypoints are every
     * {@link #SELF_INTERSECTION_STRIDE}th one, from up to {@link #SELF_INTERSECTION_LOOKBACK} back
     * to more than {@link #SELF_INTERSECTION_MIN_GAP} back. "Too close" means
     * {@code method_35589} (presumably a distance measure) is below
     * {@link #SELF_INTERSECTION_THRESHOLD}. The penalty is subtracted one hit at a time to keep
     * float rounding identical to a per-hit loop.
     *
     * @return the penalised score
     */
    private static float penaliseSelfIntersection(List<class_241> line, int i, float score) {
        float penalised = score;
        class_241 point = line.get(i);
        for (int j = Math.max(0, i - SELF_INTERSECTION_LOOKBACK); j < i - SELF_INTERSECTION_MIN_GAP; j += SELF_INTERSECTION_STRIDE) {
            if (point.method_35589(line.get(j)) < SELF_INTERSECTION_THRESHOLD) {
                penalised -= SELF_INTERSECTION_PENALTY;
            }
        }
        return penalised;
    }
}

