/*
 * Decompiled with CFR 0.152.
 */
package uk.co.cablepost.bodkin_boats.ai.train_ev_to_bucket;

import com.google.gson.Gson;
import com.google.gson.JsonParseException;
import com.google.gson.reflect.TypeToken;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Reader;
import java.lang.reflect.Type;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import javax.annotation.concurrent.GuardedBy;
import net.fabricmc.fabric.api.client.rendering.v1.WorldRenderContext;
import net.minecraft.class_156;
import net.minecraft.server.MinecraftServer;
import uk.co.cablepost.bodkin_boats.ai.VirtualBoat;
import uk.co.cablepost.bodkin_boats.ai.basic_nerual_network.NeuralNetwork;
import uk.co.cablepost.bodkin_boats.ai.train_ev_to_bucket.BucketHelper;
import uk.co.cablepost.bodkin_boats.ai.train_ev_to_bucket.BucketValue;
import uk.co.cablepost.bodkin_boats.ai.train_ev_to_bucket.State;

/**
 * Trains small neural networks that steer a {@link VirtualBoat} onto a waypoint located at the
 * origin, one starting-state "bucket" at a time, and merges the per-scenario results into a single
 * bucket table saved as JSON.
 *
 * <p>Per-scenario results are written to {@code train_ev_to_bucket_data_partial/<bucket>.json};
 * the merged table is written to {@code train_ev_to_bucket_data/}.
 *
 * <p>Threading: training may run on {@link #trainingThread} while {@link #onTick} replays the most
 * recently published successful episode into {@link #virtualBoatPreview}. The published episode
 * ({@code actionsToReplay} together with {@code bucketIndexInTraining}) is guarded by {@code this}.
 */
public class TrainEvToBucket {
    private static final String PARTIAL_DATA_FOLDER = "train_ev_to_bucket_data_partial";
    private static final String DATA_FOLDER = "train_ev_to_bucket_data";
    private static final String JSON_SUFFIX = ".json";
    /** Gson type of a partial-data file: bucket index -> value learned for that bucket. */
    private static final Type PARTIAL_DATA_TYPE = new TypeToken<Map<Integer, BucketValue>>() {}.getType();

    public VirtualBoat virtualBoatPreview = new VirtualBoat();
    public Thread trainingThread;
    @GuardedBy(value="this")
    boolean[][] actionsToReplay = null;
    @GuardedBy(value="this")
    int bucketIndexInTraining = 0;
    // Replay cursor state; within this class it is only written by onTick.
    boolean[][] actionsBeingReplayed = null;
    int bucketIndexBeingReplayed = 0;
    int actionBeingReplayedIndex = 0;
    public static Random RANDOM = new Random();
    private static final Gson GSON = new Gson();

    /**
     * @param runThread when true, immediately starts {@link #train()} on a new background thread
     *     stored in {@link #trainingThread}
     */
    public TrainEvToBucket(boolean runThread) {
        if (runThread) {
            this.trainingThread = new Thread(this::train);
            this.trainingThread.start();
        }
    }

    public synchronized boolean[][] getActionsToReplay() {
        return this.actionsToReplay;
    }

    public synchronized void setActionsToReplay(boolean[][] actionsToReplay) {
        this.actionsToReplay = actionsToReplay;
    }

    public synchronized int getBucketIndexInTraining() {
        return this.bucketIndexInTraining;
    }

    public synchronized void setBucketIndexInTraining(int bucketIndexInTraining) {
        this.bucketIndexInTraining = bucketIndexInTraining;
    }

    /**
     * Publishes an episode and the bucket it started from in one atomic step, so a concurrent
     * {@link #onTick} never pairs one episode's actions with another episode's start state.
     */
    private synchronized void setReplay(boolean[][] actions, int bucketIndex) {
        this.actionsToReplay = actions;
        this.bucketIndexInTraining = bucketIndex;
    }

    /**
     * Advances the preview replay by one action. When the current replay is exhausted (or none has
     * started), picks up the latest published episode and resets the preview boat to that episode's
     * starting state. The {@code server} argument is not used.
     */
    public void onTick(MinecraftServer server) {
        if (this.actionsBeingReplayed == null || this.actionBeingReplayedIndex >= this.actionsBeingReplayed.length) {
            boolean[][] actions;
            int bucketIndex;
            // Read both halves of the published episode under one lock so they belong together.
            synchronized (this) {
                actions = this.actionsToReplay;
                bucketIndex = this.bucketIndexInTraining;
            }
            this.actionsBeingReplayed = actions;
            this.bucketIndexBeingReplayed = bucketIndex;
            this.actionBeingReplayedIndex = 0;
            this.virtualBoatPreview = TrainEvToBucket.createVirtualBoatForStartingState(BucketHelper.getStateFromBucketIndex(this.bucketIndexBeingReplayed));
        }
        // The length check guards against an empty published episode.
        if (this.actionsBeingReplayed != null && this.actionBeingReplayedIndex < this.actionsBeingReplayed.length) {
            boolean[] action = this.actionsBeingReplayed[this.actionBeingReplayedIndex];
            this.virtualBoatPreview.tick(action[0], action[1], action[2], action[3]);
            ++this.actionBeingReplayedIndex;
        }
    }

    /**
     * Builds a boat whose observed state (via {@link #getStateOfVirtualBoat}) matches
     * {@code startingState}. The waypoint is the origin, so the boat is placed {@code distance}
     * away from it. Its yaw is chosen to reproduce the requested angle to the waypoint, and the
     * boat-local velocities are rotated into world space.
     */
    public static VirtualBoat createVirtualBoatForStartingState(State startingState) {
        float distance = startingState.distanceToWaypoint();
        // Undo the -pi/2 offset applied in getStateOfVirtualBoat.
        float angleToWaypoint = startingState.angleToWaypoint() + 1.5707964f;
        double yawRadians = Math.atan2(0.0, -distance) - (double)angleToWaypoint;
        double x = (double)(-distance) * Math.cos(yawRadians + (double)angleToWaypoint);
        double y = (double)(-distance) * Math.sin(yawRadians + (double)angleToWaypoint);
        double cos = Math.cos(yawRadians);
        double sin = Math.sin(yawRadians);
        // Local -> world is the inverse of getStateOfVirtualBoat's world -> local rotation
        // (rotation by -yaw there), so rotate by +yaw here. The previous code used the same matrix
        // in both directions, which applied the rotation twice instead of cancelling it.
        double velX = (double)startingState.velX() * cos - (double)startingState.velY() * sin;
        double velY = (double)startingState.velX() * sin + (double)startingState.velY() * cos;
        VirtualBoat virtualBoat = new VirtualBoat();
        virtualBoat.x = (float)x;
        virtualBoat.y = (float)y;
        virtualBoat.yaw = (float)Math.toDegrees(yawRadians);
        virtualBoat.velX = (float)velX;
        virtualBoat.velY = (float)velY;
        virtualBoat.velYaw = (float)Math.toDegrees(startingState.velYaw());
        return virtualBoat;
    }

    /**
     * Observes a boat relative to the waypoint at the origin.
     *
     * @return distance to the origin, signed angle between the boat's heading and the direction to
     *     the origin (offset by -pi/2, normalised to [-pi, pi]; 0 when the boat is at the origin),
     *     boat-local velocity, and yaw velocity in radians
     */
    public static State getStateOfVirtualBoat(VirtualBoat virtualBoat) {
        double rotationRadians = Math.toRadians(virtualBoat.yaw);
        double length = Math.hypot(virtualBoat.x, virtualBoat.y);
        float angleToWaypoint = 0.0f;
        if (length > 0.0) {
            double fx = Math.cos(rotationRadians);
            double fy = Math.sin(rotationRadians);
            // Unit vector from the boat towards the waypoint at the origin.
            double dx = -virtualBoat.x / length;
            double dy = -virtualBoat.y / length;
            double dot = fx * dx + fy * dy;
            double cross = fx * dy - fy * dx;
            angleToWaypoint = (float)Math.atan2(cross, dot) - 1.5707964f;
            while ((double)angleToWaypoint < -Math.PI) {
                angleToWaypoint += (float)Math.PI * 2;
            }
            while ((double)angleToWaypoint > Math.PI) {
                angleToWaypoint -= (float)Math.PI * 2;
            }
        }
        // World -> boat-local velocity: rotate by -yaw.
        double cos = Math.cos(-rotationRadians);
        double sin = Math.sin(-rotationRadians);
        float localVelX = (float)(virtualBoat.velX * cos - virtualBoat.velY * sin);
        float localVelY = (float)(virtualBoat.velX * sin + virtualBoat.velY * cos);
        return new State((float)length, angleToWaypoint, localVelX, localVelY, (float)Math.toRadians(virtualBoat.velYaw));
    }

    /** Runs {@link #train(int, int, boolean)} over all buckets, skipping ones already covered. */
    public void train() {
        this.train(0, Integer.MAX_VALUE, true);
    }

    /**
     * Reports which buckets already have partial data, then merges all partial-data files into one
     * bucket table and saves it.
     *
     * @param minIndex lower bucket index of the requested range (currently only logged)
     * @param maxIndex upper bucket index of the requested range (currently only logged)
     * @param skipTrainedInOtherScenarios when true, buckets reached successfully inside another
     *     scenario's partial data also count as already trained
     */
    public void train(int minIndex, int maxIndex, boolean skipTrainedInOtherScenarios) {
        System.out.println("Starting AI training for TRAIN_EV_TO_BUCKET.");
        System.out.println("Min index: " + minIndex);
        System.out.println("Max index: " + maxIndex);
        System.out.println("Total buckets: " + BucketHelper.TOTAL_BUCKETS);
        System.out.println("Total buckets (accounting for indexes): " + Math.min(maxIndex - minIndex, BucketHelper.TOTAL_BUCKETS));
        boolean[] alreadyTrainedBuckets = TrainEvToBucket.findAlreadyTrainedBuckets(skipTrainedInOtherScenarios);
        int alreadyTrainedCount = 0;
        for (boolean alreadyTrained : alreadyTrainedBuckets) {
            if (alreadyTrained) {
                ++alreadyTrainedCount;
            }
        }
        System.out.println("Already trained: " + alreadyTrainedCount);
        // NOTE(review): nothing here calls trainEvForScenario for indexes in [minIndex, maxIndex);
        // the per-scenario training loop appears to be missing or was removed. Confirm intent.
        System.out.println("Training of each scenario complete! Now combining data into bucket table...");
        BucketValue[] trainedBuckets = TrainEvToBucket.combinePartialTrainingData();
        this.savaTrainingData(trainedBuckets);
        System.out.println("Completed training!");
    }

    /**
     * Scans the partial-data folder and marks every bucket that has its own partial file. When
     * {@code includeReachedInOtherScenarios} is set, it also marks every bucket recorded as
     * reaching the target in any file.
     */
    private static boolean[] findAlreadyTrainedBuckets(boolean includeReachedInOtherScenarios) {
        boolean[] alreadyTrained = new boolean[BucketHelper.TOTAL_BUCKETS];
        File[] files = TrainEvToBucket.dataFolder(PARTIAL_DATA_FOLDER).listFiles();
        if (files == null) {
            return alreadyTrained;
        }
        for (File file : files) {
            String name = file.getName();
            if (!name.endsWith(JSON_SUFFIX)) {
                continue;
            }
            int index;
            try {
                index = Integer.parseInt(name.substring(0, name.length() - JSON_SUFFIX.length()));
            }
            catch (NumberFormatException e) {
                System.err.println("Skipping non-numeric JSON file: " + name);
                continue;
            }
            if (!TrainEvToBucket.isBucketIndexInRange(index, alreadyTrained.length)) {
                continue;
            }
            alreadyTrained[index] = true;
            if (!includeReachedInOtherScenarios) {
                continue;
            }
            Map<Integer, BucketValue> bucketValues = TrainEvToBucket.loadPartialTrainingData(index);
            if (bucketValues == null) {
                continue;
            }
            for (Map.Entry<Integer, BucketValue> entry : bucketValues.entrySet()) {
                BucketValue value = entry.getValue();
                if (value != null && value.reachesTarget() && TrainEvToBucket.isBucketIndexInRange(entry.getKey(), alreadyTrained.length)) {
                    alreadyTrained[entry.getKey()] = true;
                }
            }
        }
        return alreadyTrained;
    }

    /** Merges every partial-data file into one table, keeping the best value per bucket. */
    private static BucketValue[] combinePartialTrainingData() {
        BucketValue[] trainedBuckets = new BucketValue[BucketHelper.TOTAL_BUCKETS];
        for (int bucketIndex = 0; bucketIndex < trainedBuckets.length; ++bucketIndex) {
            Map<Integer, BucketValue> bucketValues = TrainEvToBucket.loadPartialTrainingData(bucketIndex);
            if (bucketValues == null) {
                continue;
            }
            for (Map.Entry<Integer, BucketValue> entry : bucketValues.entrySet()) {
                Integer key = entry.getKey();
                BucketValue candidate = entry.getValue();
                if (candidate == null || !TrainEvToBucket.isBucketIndexInRange(key, trainedBuckets.length)) {
                    continue;
                }
                if (TrainEvToBucket.isBetter(candidate, trainedBuckets[key])) {
                    trainedBuckets[key] = candidate;
                }
            }
        }
        return trainedBuckets;
    }

    /**
     * Ranks bucket values: reaching the target beats not reaching it; otherwise fewer steps win.
     * Ties keep the current value.
     */
    private static boolean isBetter(BucketValue candidate, BucketValue current) {
        if (current == null) {
            return true;
        }
        if (candidate.reachesTarget() != current.reachesTarget()) {
            return candidate.reachesTarget();
        }
        return candidate.stepsToTarget() < current.stepsToTarget();
    }

    private static boolean isBucketIndexInRange(Integer index, int bucketCount) {
        return index != null && index >= 0 && index < bucketCount;
    }

    /** Resolves a data folder name against the working directory. */
    private static File dataFolder(String folderName) {
        return Paths.get(folderName).toAbsolutePath().toFile();
    }

    /**
     * Serialises the merged table to JSON and writes it asynchronously on the executor returned by
     * {@code class_156.method_27958()}. The file is named from {@code class_156.method_44893()}.
     * Write failures are logged, not thrown.
     */
    public void savaTrainingData(BucketValue[] trainedBuckets) {
        String trainedData = GSON.toJson(trainedBuckets);
        File folder = TrainEvToBucket.dataFolder(DATA_FOLDER);
        folder.mkdirs();
        File file = new File(folder, class_156.method_44893() + JSON_SUFFIX);
        class_156.method_27958().execute(() -> {
            try (PrintWriter out = new PrintWriter(file);){
                out.println(trainedData);
                System.out.println("Saved training data");
            }
            catch (IOException ex) {
                System.out.println("Failed to save training data: " + ex.getMessage());
            }
        });
    }

    /**
     * Serialises one scenario's results to {@code <bucketIndex>.json} in the partial-data folder.
     * The write runs asynchronously on the executor returned by {@code class_156.method_27958()};
     * failures are logged, not thrown.
     */
    public void savePartialTrainingData(int bucketIndex, Map<Integer, BucketValue> trainedBuckets) {
        String trainedData = GSON.toJson(trainedBuckets);
        File folder = TrainEvToBucket.dataFolder(PARTIAL_DATA_FOLDER);
        folder.mkdirs();
        File file = new File(folder, bucketIndex + JSON_SUFFIX);
        class_156.method_27958().execute(() -> {
            try (PrintWriter out = new PrintWriter(file);){
                out.println(trainedData);
            }
            catch (IOException ex) {
                System.out.println("Failed to save partial training data: " + ex.getMessage());
            }
        });
    }

    /**
     * Loads one scenario's partial results.
     *
     * @return the bucket-index -> value map, or {@code null} if the file does not exist, cannot be
     *     read, is not valid JSON, or is empty
     */
    public static Map<Integer, BucketValue> loadPartialTrainingData(int index) {
        File file = new File(TrainEvToBucket.dataFolder(PARTIAL_DATA_FOLDER), index + JSON_SUFFIX);
        if (!file.exists()) {
            return null;
        }
        try (FileReader reader = new FileReader(file)) {
            return GSON.fromJson(reader, PARTIAL_DATA_TYPE);
        }
        catch (IOException | JsonParseException e) {
            // A half-written or corrupt file must not kill the training thread; report and skip it.
            System.err.println("Error reading file: " + file.getAbsolutePath());
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Hill-climbs a neural network (by repeated mutation of the best one so far) to drive the boat
     * from the given bucket's starting state to the waypoint. It then replays the best network once
     * and records, for every bucket visited, the action taken and the remaining steps.
     *
     * @param bucketToTrainFrom bucket whose state is the episode start
     * @param alreadyTrainedBuckets skip flags; buckets visited in a successful replay are set
     * @return {@code null} if skipped; otherwise the visited-bucket map (also saved as partial data
     *     when the replay reaches the target)
     */
    public Map<Integer, BucketValue> trainEvForScenario(int bucketToTrainFrom, boolean[] alreadyTrainedBuckets) {
        if (alreadyTrainedBuckets[bucketToTrainFrom]) {
            System.out.println("Skipping " + bucketToTrainFrom + " as already trained in another scenario");
            return null;
        }
        State stateToStartFrom = BucketHelper.getStateFromBucketIndex(bucketToTrainFrom);
        String skipReason = TrainEvToBucket.describeUntrainableStart(TrainEvToBucket.createVirtualBoatForStartingState(stateToStartFrom));
        if (skipReason != null) {
            System.out.println("Skipping " + bucketToTrainFrom + " as " + skipReason);
            return null;
        }

        double bestScore = -1.0;
        NeuralNetwork bestNeuralNetwork = new NeuralNetwork(5, new int[]{8, 8}, 4);
        // Acceptance threshold that slowly relaxes (faster below 1930 and 1000) so the search ends.
        double scoreToMoveOn = 2000.0;
        int bestStepCount = Integer.MAX_VALUE;
        while (bestScore < scoreToMoveOn) {
            scoreToMoveOn -= 9.0E-5;
            if (scoreToMoveOn <= 1930.0) {
                scoreToMoveOn -= (double)0.001f;
            }
            if (scoreToMoveOn <= 1000.0) {
                scoreToMoveOn -= (double)0.1f;
            }
            VirtualBoat boat = TrainEvToBucket.createVirtualBoatForStartingState(stateToStartFrom);
            NeuralNetwork candidate = new NeuralNetwork(bestNeuralNetwork);
            candidate.mutate();
            int steps = 0;
            double closestDistance = Float.MAX_VALUE;
            double score;
            // One episode. It ends on success (score 1000..2000, fewer steps is better) or on
            // failure (score based on closest approach minus a spin penalty; may be negative).
            while (true) {
                State state = TrainEvToBucket.getStateOfVirtualBoat(boat);
                boolean[] action = TrainEvToBucket.chooseAction(candidate, state);
                float xBefore = boat.x;
                float yBefore = boat.y;
                boat.tick(action[0], action[1], action[2], action[3]);
                ++steps;
                double distance = Math.hypot(boat.x, boat.y);
                if (distance < closestDistance) {
                    closestDistance = distance;
                }
                if (TrainEvToBucket.reachedTarget(distance, xBefore, yBefore, boat, 0.25f)) {
                    score = 1000 + (1000 - steps);
                    if (steps < bestStepCount) {
                        bestStepCount = steps;
                    }
                    break;
                }
                // Abort when flung too far, too slow overall, or already slower than the best success.
                if (distance > 500.0 || steps > 1000 || steps > bestStepCount) {
                    float spin = Math.abs(boat.velYaw);
                    score = 500.0 - Math.clamp(closestDistance, 0.0, 500.0) - (double)(spin > 0.3f ? spin * 5.0f : 0.0f);
                    break;
                }
            }
            if (score > bestScore) {
                bestScore = score;
                bestNeuralNetwork = candidate;
            }
        }

        // Replay the best network once, recording the step at which each bucket was (last) visited.
        boolean bestReachesTarget = bestScore >= 1000.0;
        HashMap<Integer, BucketValue> visitedBuckets = new HashMap<Integer, BucketValue>();
        VirtualBoat trainedVirtualBoat = TrainEvToBucket.createVirtualBoatForStartingState(stateToStartFrom);
        ArrayList<boolean[]> actionsThisEpisode = new ArrayList<boolean[]>();
        int stepsFromStart = 0;
        boolean successInReplay;
        float distance;
        do {
            State state = TrainEvToBucket.getStateOfVirtualBoat(trainedVirtualBoat);
            boolean[] action = TrainEvToBucket.chooseAction(bestNeuralNetwork, state);
            actionsThisEpisode.add(action);
            visitedBuckets.put(BucketHelper.getBucketIndexForState(state), new BucketValue(stepsFromStart, action, bestReachesTarget));
            float xBefore = trainedVirtualBoat.x;
            float yBefore = trainedVirtualBoat.y;
            trainedVirtualBoat.tick(action[0], action[1], action[2], action[3]);
            distance = (float)Math.hypot(trainedVirtualBoat.x, trainedVirtualBoat.y);
            ++stepsFromStart;
            successInReplay = TrainEvToBucket.reachedTarget(distance, xBefore, yBefore, trainedVirtualBoat, 0.35f);
        } while (!successInReplay && !(distance > 500.0f) && stepsFromStart <= 1000);
        if (successInReplay) {
            this.setReplay(actionsThisEpisode.toArray(new boolean[0][]), bucketToTrainFrom);
        }
        // Convert "step at visit" into "steps remaining until the end of the replay".
        HashMap<Integer, BucketValue> finalBucketValues = new HashMap<Integer, BucketValue>();
        for (Map.Entry<Integer, BucketValue> entry : visitedBuckets.entrySet()) {
            BucketValue visit = entry.getValue();
            finalBucketValues.put(entry.getKey(), new BucketValue(stepsFromStart - visit.stepsToTarget(), visit.action(), visit.reachesTarget()));
            if (successInReplay) {
                alreadyTrainedBuckets[entry.getKey()] = true;
            }
        }
        System.out.println("Completed training for scenario " + bucketToTrainFrom + (bestScore >= 1000.0 ? " successfully" : " unsuccessfully") + " (with score: " + bestScore + ")");
        if (successInReplay) {
            this.savePartialTrainingData(bucketToTrainFrom, finalBucketValues);
        }
        return finalBucketValues;
    }

    /** Returns why a start state is not worth training from, or {@code null} if it is. */
    private static String describeUntrainableStart(VirtualBoat boat) {
        if (Math.abs(boat.velYaw) > 35.0f) {
            return "spinning too fast";
        }
        if (boat.x > 0.0f && boat.velX > 4.0f) {
            return "being flung away (x+)";
        }
        if (boat.x < 0.0f && boat.velX < -4.0f) {
            return "being flung away (x-)";
        }
        if (boat.y > 0.0f && boat.velY > 4.0f) {
            return "being flung away (y+)";
        }
        if (boat.y < 0.0f && boat.velY < -4.0f) {
            return "being flung away (y-)";
        }
        return null;
    }

    /** Runs the network on a state and thresholds its four outputs at 0.5 into {forward, back, left, right}. */
    private static boolean[] chooseAction(NeuralNetwork network, State state) {
        double[] prediction = network.calc(new double[]{state.distanceToWaypoint(), state.angleToWaypoint(), state.velX(), state.velY(), state.velYaw()});
        return new boolean[]{prediction[0] > 0.5, prediction[1] > 0.5, prediction[2] > 0.5, prediction[3] > 0.5};
    }

    /**
     * True when the boat ended a tick within {@code radius} of the origin. Also true when, being
     * closer than 14 units, the segment it just travelled passed within {@code radius} of it.
     */
    private static boolean reachedTarget(double distance, float xBefore, float yBefore, VirtualBoat boat, float radius) {
        return distance <= radius || distance < 14.0 && TrainEvToBucket.distanceToTarget(xBefore, yBefore, boat.x, boat.y) < radius;
    }

    /** Returns the shortest distance from the origin to the line segment from (aX, aY) to (bX, bY). */
    public static float distanceToTarget(float aX, float aY, float bX, float bY) {
        float abX = bX - aX;
        float abY = bY - aY;
        float apX = -aX;
        float apY = -aY;
        float abLengthSquared = abX * abX + abY * abY;
        float dotProduct = apX * abX + apY * abY;
        // Projection parameter of the origin onto the segment, clamped to [0, 1]; 0 for a point segment.
        float t = abLengthSquared == 0.0f ? 0.0f : dotProduct / abLengthSquared;
        t = Math.max(0.0f, Math.min(1.0f, t));
        float closestX = aX + t * abX;
        float closestY = aY + t * abY;
        return (float)Math.sqrt(closestX * closestX + closestY * closestY);
    }

    /** Renders the preview boat, if any. */
    public void onRender(WorldRenderContext worldRenderContext) {
        if (this.virtualBoatPreview != null) {
            this.virtualBoatPreview.render(worldRenderContext);
        }
    }
}

