package net.shasankp000.GameAI;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import net.minecraft.class_1799;
import net.shasankp000.GameAI.StateActions;

/* loaded from: input_file:net/shasankp000/GameAI/RLAgent.class */
/**
 * Tabular Q-learning agent with an epsilon-greedy exploration policy.
 *
 * <p>Q-values are kept per {@code State} in a nested map from action to value. The exploration
 * probability {@code epsilon} starts at 1.0. {@link #decayEpsilon()} multiplies it by
 * {@code EPSILON_DECAY_RATE}, and it never drops below {@code MIN_EPSILON}.
 *
 * <p>NOTE(review): {@code State} is used as a {@link HashMap} key. It must implement
 * {@code equals}/{@code hashCode} consistently and must not be mutated after insertion — confirm.
 * This class has no synchronization; presumably it is driven from a single thread — verify.
 */
public class RLAgent {
    /** Learning rate applied to the temporal-difference error in {@code updateQValue}. */
    private static final double ALPHA = 0.1d;
    /** Discount factor applied to the best Q-value of the next state. */
    private static final double GAMMA = 0.9d;
    /** Lower bound for {@code epsilon} after decay. */
    private static final double MIN_EPSILON = 0.1d;
    /** Multiplicative factor applied to {@code epsilon} on each decay. */
    private static final double EPSILON_DECAY_RATE = 0.95d;

    /** Probability of choosing a uniformly random action instead of the greedy one. */
    private double epsilon = 1.0d;
    private final Map<State, Map<StateActions.Action, Double>> qTable = new HashMap<>();
    private final Random random = new Random();

    /**
     * Picks an action for {@code state} using an epsilon-greedy policy.
     *
     * <p>With probability {@code epsilon} a uniformly random action is returned. Otherwise the action
     * with the highest stored Q-value for the state is returned. A state with no recorded Q-values
     * yields {@code STAY}.
     *
     * @param state the current state
     * @return the chosen action
     */
    public StateActions.Action chooseAction(State state) {
        double roll = random.nextDouble();
        System.out.println("Generated random value: " + roll);
        if (roll < epsilon) {
            System.out.println("Exploring with epsilon: " + epsilon);
            StateActions.Action[] actions = StateActions.Action.values();
            return actions[random.nextInt(actions.length)];
        }
        return greedyAction(state);
    }

    /** Returns the highest-valued action recorded for {@code state}, or {@code STAY} if none. */
    private StateActions.Action greedyAction(State state) {
        Map<StateActions.Action, Double> actionValues = qTable.get(state);
        if (actionValues == null) {
            return StateActions.Action.STAY;
        }
        // Stream.max keeps the first of equal maxima, so ties resolve in map iteration order.
        return actionValues.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElse(StateActions.Action.STAY);
    }

    /** Multiplies {@code epsilon} by the decay rate, clamped to {@code MIN_EPSILON}, and logs it. */
    public void decayEpsilon() {
        epsilon = Math.max(MIN_EPSILON, epsilon * EPSILON_DECAY_RATE);
        System.out.println("Updated epsilon: " + epsilon);
    }

    /**
     * Returns the live Q-table.
     *
     * <p>The internal map itself is returned, not a copy. Changes made by the caller are visible to
     * this agent.
     *
     * @return the mutable state -> (action -> Q-value) table
     */
    public Map<State, Map<StateActions.Action, Double>> getQTable() {
        return qTable;
    }

    /**
     * Computes a heuristic integer reward for taking {@code action} in the described situation.
     *
     * <p>The reward is the sum of independent terms:
     * <ul>
     *   <li>hostile proximity</li>
     *   <li>health</li>
     *   <li>danger-zone proximity</li>
     *   <li>selected and offhand items</li>
     *   <li>time of day</li>
     *   <li>dimension</li>
     *   <li>hotbar supplies</li>
     *   <li>hunger and oxygen</li>
     *   <li>armor</li>
     * </ul>
     * A bonus is added on top for attacking a close hostile at low health without a totem.
     *
     * <p>Parameter meanings below are inferred from how each value is scored — TODO confirm against
     * the caller.
     *
     * @param x presumably the bot's x coordinate; currently unused
     * @param y presumably the bot's y coordinate; currently unused
     * @param z presumably the bot's z coordinate; currently unused
     * @param distanceToHostile distance to the nearest hostile entity
     * @param health bot health (scored with thresholds at 5, 10 and 15)
     * @param distanceToDanger distance to the nearest danger zone
     * @param hotbarItems hotbar stacks, scanned for survival items
     * @param selectedItem held item name, compared case-insensitively to "sword", "bow", "shield";
     *     {@code null} scores as "anything else"
     * @param timeOfDay "day" or "night" (exact, case-sensitive); anything else, including
     *     {@code null}, scores 0
     * @param dimension dimension id such as "minecraft:overworld"; unknown ids and {@code null}
     *     score -20
     * @param hunger hunger level (thresholds at 6 and 16)
     * @param oxygen remaining air (thresholds at 60 and 150)
     * @param offhandItem the offhand stack
     * @param armorItems equipped armor stacks keyed by slot name
     * @param action the action whose outcome is being rewarded
     * @return the summed reward
     */
    public static int calculateReward(int x, int y, int z, double distanceToHostile, int health,
            double distanceToDanger, List<class_1799> hotbarItems, String selectedItem,
            String timeOfDay, String dimension, int hunger, int oxygen, class_1799 offhandItem,
            Map<String, class_1799> armorItems, StateActions.Action action) {
        int reward = hostileProximityReward(distanceToHostile, action)
                + healthReward(health, action)
                + dangerZoneReward(distanceToDanger, action)
                + selectedItemReward(selectedItem)
                + offhandReward(offhandItem)
                + timeOfDayReward(timeOfDay)
                + dimensionReward(dimension)
                + hotbarReward(hotbarItems)
                + survivalStatsReward(hunger, oxygen)
                + armorReward(armorItems);

        // +20 for attacking a hostile within 5 blocks at health <= 10 with no totem in the hotbar.
        // NOTE(review): this rewards a risky choice; confirm it is intended rather than a penalty.
        boolean hasTotem = containsItem(hotbarItems, "minecraft:totem_of_undying");
        if (health <= 10 && !hasTotem && action == StateActions.Action.ATTACK
                && distanceToHostile <= 5.0d) {
            reward += 20;
        }
        return reward;
    }

    /** Scores distance to the nearest hostile: far is good; when close, attacking helps and staying hurts. */
    private static int hostileProximityReward(double distance, StateActions.Action action) {
        if (distance > 10.0d) {
            return 10;
        }
        if (distance <= 5.0d) {
            if (action == StateActions.Action.ATTACK) {
                return -10 + 15;
            }
            if (action == StateActions.Action.STAY) {
                return -10 - 5;
            }
            return -10;
        }
        // Between 5 (exclusive) and 10 (inclusive), or NaN: neutral.
        return 0;
    }

    /** Scores health; at very low health, staying or using an item softens the penalty. */
    private static int healthReward(int health, StateActions.Action action) {
        if (health > 15) {
            return 10;
        }
        if (health <= 5) {
            boolean mitigating = action == StateActions.Action.STAY
                    || action == StateActions.Action.USE_ITEM;
            return mitigating ? -20 + 10 : -20;
        }
        return 5;
    }

    /** Scores distance to a danger zone; when close, moving back or turning softens the penalty. */
    private static int dangerZoneReward(double distance, StateActions.Action action) {
        if (distance > 10.0d) {
            return 10;
        }
        if (distance <= 5.0d) {
            boolean retreating = action == StateActions.Action.MOVE_BACKWARD
                    || action == StateActions.Action.TURN_LEFT
                    || action == StateActions.Action.TURN_RIGHT;
            return retreating ? -15 + 10 : -15;
        }
        return 0;
    }

    /** Scores the selected item name: weapons +10, shield +5, anything else -5. */
    private static int selectedItemReward(String selectedItem) {
        if ("sword".equalsIgnoreCase(selectedItem) || "bow".equalsIgnoreCase(selectedItem)) {
            return 10;
        }
        if ("shield".equalsIgnoreCase(selectedItem)) {
            return 5;
        }
        return -5;
    }

    /** Returns +10 when the offhand stack's item id is "minecraft:shield", else 0. */
    private static int offhandReward(class_1799 offhandItem) {
        return itemId(offhandItem).equalsIgnoreCase("minecraft:shield") ? 10 : 0;
    }

    /** Scores the time of day: "day" +5, "night" -10, otherwise 0. */
    private static int timeOfDayReward(String timeOfDay) {
        if ("day".equals(timeOfDay)) {
            return 5;
        }
        if ("night".equals(timeOfDay)) {
            return -10;
        }
        return 0;
    }

    /** Scores the dimension id: overworld 0, nether -5, end -10, anything else (or null) -20. */
    private static int dimensionReward(String dimension) {
        if (dimension == null) {
            return -20;
        }
        switch (dimension) {
            case "minecraft:overworld":
                return 0;
            case "minecraft:nether":
                return -5;
            case "minecraft:end":
                return -10;
            default:
                return -20;
        }
    }

    /** Scores hotbar supplies: totem, golden apple and water bucket add; having no food subtracts. */
    private static int hotbarReward(List<class_1799> hotbarItems) {
        int reward = 0;
        if (containsItem(hotbarItems, "minecraft:totem_of_undying")) {
            reward += 15;
        }
        if (containsItem(hotbarItems, "minecraft:golden_apple")) {
            reward += 10;
        }
        if (containsItem(hotbarItems, "minecraft:water_bucket")) {
            reward += 5;
        }
        reward += containsItem(hotbarItems, "minecraft:bread", "minecraft:cooked_beef") ? 5 : -10;
        return reward;
    }

    /** Scores hunger (<= 6 bad, > 16 good) and oxygen (< 60 bad, >= 150 good). */
    private static int survivalStatsReward(int hunger, int oxygen) {
        int reward = 0;
        if (hunger <= 6) {
            reward -= 10;
        } else if (hunger > 16) {
            reward += 5;
        }
        if (oxygen < 60) {
            reward -= 20;
        } else if (oxygen >= 150) {
            reward += 10;
        }
        return reward;
    }

    /** Adds 5 for every armor stack whose {@code method_7960()} (presumably isEmpty) is false. */
    private static int armorReward(Map<String, class_1799> armorItems) {
        int reward = 0;
        for (class_1799 piece : armorItems.values()) {
            if (!piece.method_7960()) {
                reward += 5;
            }
        }
        return reward;
    }

    /** Returns true if any stack's item id equals (ignoring case) any of {@code itemIds}. */
    private static boolean containsItem(List<class_1799> stacks, String... itemIds) {
        for (class_1799 stack : stacks) {
            String id = itemId(stack);
            for (String wanted : itemIds) {
                if (id.equalsIgnoreCase(wanted)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Returns the string form of the stack's item, used as its id in comparisons.
     *
     * <p>NOTE(review): callers compare this against namespaced ids like "minecraft:shield". In some
     * Minecraft versions {@code Item#toString()} returns only the path ("shield"), in which case
     * those checks never match — verify for the targeted version.
     */
    private static String itemId(class_1799 stack) {
        return stack.method_7909().toString();
    }

    /**
     * Applies one Q-learning update for the transition (state, action, reward, nextState).
     *
     * <p>Computes Q(s,a) += ALPHA * (reward + GAMMA * max_a' Q(s',a') - Q(s,a)). Missing entries count
     * as 0. A state with no recorded actions also contributes a max of 0.
     *
     * @param state the state the action was taken in
     * @param action the action taken
     * @param reward the observed reward
     * @param nextState the resulting state
     */
    public void updateQValue(State state, StateActions.Action action, double reward, State nextState) {
        Map<StateActions.Action, Double> actionValues = qTable.computeIfAbsent(state, k -> new HashMap<>());
        double currentQ = actionValues.getOrDefault(action, 0.0d);
        // Read after computeIfAbsent, so nextState == state sees the (possibly new, empty) entry.
        double target = reward + GAMMA * maxQValue(nextState);
        actionValues.put(action, currentQ + ALPHA * (target - currentQ));
    }

    /** Returns the largest stored Q-value for {@code state}, or 0 when none is recorded. */
    private double maxQValue(State state) {
        Map<StateActions.Action, Double> actionValues = qTable.get(state);
        if (actionValues == null) {
            return 0.0d;
        }
        return actionValues.values().stream().max(Double::compare).orElse(0.0d);
    }

    /** Finishes an episode; currently this only decays {@code epsilon}. */
    public void endEpisode() {
        decayEpsilon();
    }
}
