package org.codeberg.zenxarch.zombies.difficulty;

import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import net.fabricmc.fabric.api.attachment.v1.AttachmentType;
import net.minecraft.class_2791;
import net.minecraft.class_2960;
import net.minecraft.class_3218;
import net.minecraft.class_3532;
import org.codeberg.zenxarch.zombies.Zombies;
import org.codeberg.zenxarch.zombies.difficulty.entry.DifficultyEntry;

/**
 * Static registry of difficulty categories and calculation stages.
 *
 * <p>A category is an id mapped to a list of weighted {@link DifficultyEntry}s; its difficulty
 * is the weighted average of the entries' values, each clamped to [0, 1]. Stages are applied in
 * registration order. Each stage folds the difficulty of the category sharing its id into a
 * running result that starts at 1.0. The final result is clamped to [0, 1].
 */
public final class DifficultyCategory {
  private DifficultyCategory() {
    throw new IllegalStateException("Only static members");
  }

  /** Category id -> entries registered for that category. Lists are never null or absent-empty. */
  private static final Map<class_2960, List<DifficultyEntry>> categories =
      new Object2ObjectOpenHashMap<>();

  /** Stages in registration order; applied in this order by the uncached calculation. */
  private static final List<DifficultyCalculationStage> stages = new ObjectArrayList<>();

  /**
   * One step of the overall difficulty calculation.
   *
   * @param id category whose difficulty is fed to {@code mapper}
   * @param mapper combines (running result, category difficulty) into the new running result
   */
  public static record DifficultyCalculationStage(
      class_2960 id, BiFunction<Double, Double, Double> mapper) {}

  /**
   * Registers {@code entry} under the category {@code id}, creating the category if needed.
   *
   * @param id category id
   * @param entry entry to append to that category
   */
  public static void addDifficultyEntry(class_2960 id, DifficultyEntry entry) {
    categories.computeIfAbsent(id, unused -> new ObjectArrayList<>()).add(entry);
  }

  /**
   * Appends a calculation stage; stages run in the order they were added.
   *
   * @param id category whose difficulty the stage consumes
   * @param mapper combines (running result, category difficulty) into the new running result
   */
  public static void addDifficultyStage(class_2960 id, BiFunction<Double, Double, Double> mapper) {
    stages.add(new DifficultyCalculationStage(id, mapper));
  }

  /** Attachment type used as the cache slot for {@link #calculateDifficulty(class_3218, class_2791)}. */
  public static final AttachmentType<CachedValue> DIFFICULTY =
      CachedValue.createAttachmentType(Zombies.id("difficulty"));

  /**
   * Returns the overall difficulty for {@code chunk}, going through the {@link #DIFFICULTY}
   * cache.
   *
   * <p>NOTE(review): the caching/invalidation policy lives in {@code CachedValue}, which is not
   * shown here.
   *
   * @param world the world the chunk belongs to
   * @param chunk the chunk; {@code null} yields 1.0
   * @return difficulty in [0, 1], or 1.0 when {@code chunk} is null
   */
  public static double calculateDifficulty(class_3218 world, class_2791 chunk) {
    if (chunk == null) {
      return 1.0;
    }
    return CachedValue.getOrUpdateValue(
        world, chunk, DIFFICULTY, DifficultyCategory::calculateDifficultyWOCache);
  }

  /** Folds every stage over its category difficulty, starting at 1.0, and clamps to [0, 1]. */
  private static double calculateDifficultyWOCache(class_3218 world, class_2791 chunk) {
    var result = 1.0;
    for (var stage : stages) {
      result = stage.mapper.apply(result, calculateDifficulty(world, chunk, stage.id));
    }
    return class_3532.method_15350(result, 0.0, 1.0);
  }

  /**
   * Weighted average of the entries' values, each value clamped to [0, 1] before weighting.
   * Returns 1.0 (the same neutral value as an unknown category) when the weights sum to zero.
   */
  private static double weightedAverage(
      class_3218 world, class_2791 chunk, List<DifficultyEntry> entries) {
    var weightedSum = 0.0;
    // long so that many large int weights cannot overflow the total.
    var totalWeight = 0L;
    for (var entry : entries) {
      var value = class_3532.method_15350(entry.calculate(world, chunk), 0.0, 1.0);
      var weight = entry.getWeight();
      weightedSum += value * weight;
      totalWeight += weight;
    }
    // Dividing by a zero total would give NaN/Infinity; NaN would then propagate through the
    // stage mappers and is not removed by a min/max clamp.
    if (totalWeight == 0) {
      return 1.0;
    }
    return weightedSum / totalWeight;
  }

  /**
   * Returns the ids of all registered categories.
   *
   * @return a read-only live view of the registered category ids
   */
  public static Set<class_2960> getCategories() {
    return Collections.unmodifiableSet(categories.keySet());
  }

  /**
   * Returns the difficulty of a single category for {@code chunk} (not cached).
   *
   * @param world the world the chunk belongs to
   * @param chunk the chunk to evaluate
   * @param id category id
   * @return weighted average of the category's entries, or 1.0 if the category is unknown
   */
  public static double calculateDifficulty(class_3218 world, class_2791 chunk, class_2960 id) {
    var entries = categories.get(id);
    if (entries == null) {
      return 1.0;
    }
    return weightedAverage(world, chunk, entries);
  }
}
