package com.bawnorton.trimica.client.palette;

import com.bawnorton.trimica.Trimica;
import com.bawnorton.trimica.client.mixin.accessor.*;
import com.bawnorton.trimica.client.colour.ColourGroup;
import com.bawnorton.trimica.client.colour.ColourHSB;
import com.bawnorton.trimica.client.colour.OkLabHelper;
import com.bawnorton.trimica.trim.TrimMaterialRuntimeRegistry;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;
import java.util.*;
import net.minecraft.class_1011;
import net.minecraft.class_10394;
import net.minecraft.class_10439;
import net.minecraft.class_10442;
import net.minecraft.class_10447;
import net.minecraft.class_10539;
import net.minecraft.class_1058;
import net.minecraft.class_10809;
import net.minecraft.class_1792;
import net.minecraft.class_2378;
import net.minecraft.class_2960;
import net.minecraft.class_310;
import net.minecraft.class_3300;
import net.minecraft.class_5321;
import net.minecraft.class_638;
import net.minecraft.class_777;
import net.minecraft.class_7923;
import net.minecraft.class_7924;
import net.minecraft.class_8054;
import net.minecraft.class_9848;

/**
 * Produces {@link TrimPalette}s for trim materials and caches the results.
 *
 * <p>Materials for which {@code guessMaterialProvider} returns no item are resolved from their
 * built-in {@code textures/trims/color_palettes/<suffix>.png} texture. Otherwise, when trim
 * generation is enabled, colours are extracted from the provider item's model, reduced to the
 * dominant colours, sorted, and stretched to {@link TrimPalette#PALETTE_SIZE} entries.
 *
 * <p>The caches are plain static {@link HashMap}s with no synchronization; {@link #clear()} drops
 * every cached palette.
 */
public final class TrimPaletteGenerator {
	/** Palettes generated from item models, keyed by material. */
	private static final Map<class_8054, TrimPalette> TRIM_PALETTES = new HashMap<>();
	/**
	 * Palettes loaded from built-in textures. The texture path depends on both the material and the
	 * equipment asset (through {@code getSuffix(material, assetKey)}), so both form the key.
	 */
	private static final Map<BuiltInPaletteKey, TrimPalette> BUILT_IN_PALETTES = new HashMap<>();

	/** Cache key for {@link #BUILT_IN_PALETTES}; record equality tolerates {@code null} components. */
	private record BuiltInPaletteKey(class_8054 material, class_5321<class_10394> assetKey) {}

	/**
	 * Returns the palette for {@code material} when rendered on the equipment asset {@code assetKey}.
	 *
	 * @param material the trim material
	 * @param assetKey the equipment asset key; used to pick the built-in texture suffix
	 * @return the cached or freshly generated palette; {@link TrimPalette#DISABLED} when the material
	 *     has a provider item but trim generation is disabled
	 */
	public @NotNull TrimPalette generatePalette(class_8054 material, class_5321<class_10394> assetKey) {
		class_1792 materialProvider = Trimica.getMaterialRegistry().guessMaterialProvider(material);
		if (materialProvider == null) {
			return generatePaletteFromBuiltIn(material, assetKey);
		}
		class_2960 providerId = class_7923.field_41178.method_10221(materialProvider);
		if (!TrimMaterialRuntimeRegistry.enableTrimEverything) {
			Trimica.LOGGER.warn("Trim palette generation is disabled, cannot generate palette for material: {}", providerId);
			return TrimPalette.DISABLED;
		}
		class_10442 modelResolver = class_310.method_1551().method_65386();
		class_10439 model = ((ItemModelResolverAccessor) modelResolver).trimica$modelGetter().apply(providerId);
		return generatePaletteFromModel(material, assetKey, model);
	}

	/**
	 * Returns the cached built-in palette for this material/asset pair, loading it on first use.
	 * Caches {@link TrimPalette#MISSING} when no colours could be read.
	 */
	private @NotNull TrimPalette generatePaletteFromBuiltIn(class_8054 material, class_5321<class_10394> assetKey) {
		return BUILT_IN_PALETTES.computeIfAbsent(new BuiltInPaletteKey(material, assetKey), key -> {
			List<Integer> colours = getColoursFromBuiltIn(material, assetKey);
			if (colours.isEmpty()) return TrimPalette.MISSING;

			return new TrimPalette(colours, true);
		});
	}

	/**
	 * Reads the built-in palette texture for {@code material} from the client resource manager.
	 *
	 * @return the texture's pixels, or an empty list when there is no client level, the material is
	 *     not in the trim material registry, or the texture cannot be loaded or has the wrong size
	 */
	private List<Integer> getColoursFromBuiltIn(class_8054 material, class_5321<class_10394> assetKey) {
		class_310 minecraft = class_310.method_1551();
		class_3300 resourceManager = minecraft.method_1478();
		class_638 level = minecraft.field_1687;
		if (level == null) return List.of();

		class_2378<class_8054> lookup = level.method_30349().method_46759(class_7924.field_42083).orElse(null);
		if (lookup == null) return List.of();

		class_2960 materialId = lookup.method_10221(material);
		if (materialId == null) {
			// Sometimes the material instance changes between client load and render, so the
			// identity-based lookup fails; fall back to an equals() scan over the registry entries.
			materialId = lookup.method_29722()
					.stream()
					.filter(e -> e.getValue().equals(material))
					.findFirst()
					.map(Map.Entry::getKey)
					.map(class_5321::method_29177)
					.orElse(null);
			if (materialId == null) {
				return List.of();
			}
		}

		String suffix = Trimica.getMaterialRegistry().getSuffix(material, assetKey);
		class_2960 textureId = materialId.method_45136("textures/trims/color_palettes/%s.png".formatted(suffix));
		try (class_10539 contents = class_10539.method_65871(resourceManager, textureId)) {
			class_1011 image = contents.comp_3447();
			return extractColoursFromBuiltIn(image);
		} catch (IOException e) {
			Trimica.LOGGER.error("Failed to load trim palette texture {}", textureId, e);
			return List.of();
		}
	}

	/**
	 * Returns the pixels of a {@code PALETTE_SIZE x 1} palette image, left to right, or an empty list
	 * if the image has any other dimensions.
	 */
	private List<Integer> extractColoursFromBuiltIn(class_1011 builtInImage) {
		int width = builtInImage.method_4307();
		int height = builtInImage.method_4323();
		if (width != TrimPalette.PALETTE_SIZE || height != 1) {
			return List.of();
		}
		List<Integer> colours = new ArrayList<>(TrimPalette.PALETTE_SIZE);
		for (int x = 0; x < width; x++) {
			colours.add(builtInImage.method_61940(x, 0));
		}
		return colours;
	}

	/**
	 * Returns the cached model-derived palette for {@code material}, generating it on first use.
	 * Caches {@link TrimPalette#DEFAULT} when no colours could be extracted from the model.
	 *
	 * @param assetKey only used to describe the material in the warning log
	 */
	private TrimPalette generatePaletteFromModel(class_8054 material, class_5321<class_10394> assetKey, class_10439 model) {
		return TRIM_PALETTES.computeIfAbsent(material, key -> {
			List<Integer> colours = getColoursFromModel(model);
			if (colours.isEmpty()) {
				Trimica.LOGGER.warn("Trim palette colour could not be determined for {}", Trimica.getMaterialRegistry().getSuffix(material, assetKey));
				return TrimPalette.DEFAULT;
			}
			colours = getDominantColours(colours);
			colours = sortPalette(colours);
			colours = stretchPalette(colours);
			return new TrimPalette(colours);
		});
	}

	/**
	 * Collects non-transparent colours from an item model, descending into wrapper models
	 * (select, composite, conditional, range-select) to find a sprite-backed model.
	 *
	 * @return the extracted colours, or an empty list for {@code null}, empty, or unsupported models
	 */
	private List<Integer> getColoursFromModel(class_10439 model) {
		return switch (model) {
			case BlockModelWrapperAccessor blockModelWrapperAccessor ->
					getColoursFromQuads(blockModelWrapperAccessor.trimica$quads());
			case SelectItemModelAccessor selectItemModelAccessor ->
					getColoursFromModel(selectItemModelAccessor.trimica$models().get(null, null));
			case SpecialModelWrapperAccessor specialModelWrapperAccessor -> {
				class_10809 properties = specialModelWrapperAccessor.trimica$properties();
				// Transparent pixels are left as 0 by extractColours; drop them like getColoursFromQuads
				// does so they are not counted as black.
				yield Arrays.stream(extractColours(properties.comp_3767()))
						.filter(colour -> colour != 0)
						.boxed()
						.toList();
			}
			case CompositeModelAccessor compositeModelAccessor -> {
				var models = compositeModelAccessor.trimica$models();
				if (models.isEmpty()) yield Collections.emptyList();
				yield getColoursFromModel(models.getFirst());
			}
			case ConditionalItemModelAccessor conditionalItemModelAccessor ->
					getColoursFromModel(conditionalItemModelAccessor.trimica$onFalse());
			case RangeSelectItemModelAccessor rangeSelectItemModelAccessor -> {
				class_10439 fallback = rangeSelectItemModelAccessor.trimica$fallback();
				if (!(fallback instanceof class_10447)) yield getColoursFromModel(fallback);

				class_10439[] models = rangeSelectItemModelAccessor.trimica$models();
				if (models.length > 0) {
					yield getColoursFromModel(models[0]);
				}
				yield Collections.emptyList();
			}
			case null -> Collections.emptyList();
			default -> {
				Trimica.LOGGER.warn("Cannot extract colours from unknown item model type: {}", model.getClass().getName());
				yield Collections.emptyList();
			}
		};
	}

	/**
	 * Groups similar colours and returns the representatives of at most
	 * {@link TrimPalette#PALETTE_SIZE} groups, in the groups' natural (sorted) order.
	 */
	private List<Integer> getDominantColours(List<Integer> colours) {
		List<ColourGroup> groups = new ArrayList<>();
		for (ColourHSB colour : ColourHSB.fromARGB(colours)) {
			boolean foundGroup = false;
			for (ColourGroup group : groups) {
				if (group.isSimilar(colour)) {
					group.addMember(colour);
					foundGroup = true;
					break;
				}
			}
			if (!foundGroup) {
				groups.add(new ColourGroup(colour));
			}
		}
		Collections.sort(groups);
		return groups.stream()
				.limit(TrimPalette.PALETTE_SIZE)
				.map(ColourGroup::getRepresentative)
				.map(ColourHSB::rgb)
				.toList();
	}

	/** Sorts colours by {@link ColourHSB}'s natural ordering. */
	private List<Integer> sortPalette(List<Integer> colours) {
		return ColourHSB.fromARGB(colours)
				.stream()
				.sorted()
				.map(ColourHSB::rgb)
				.toList();
	}

	/**
	 * Expands a palette shorter than {@link TrimPalette#PALETTE_SIZE} to that size in OkLab space;
	 * palettes already at or above the target size are returned unchanged.
	 */
	private List<Integer> stretchPalette(List<Integer> palette) {
		int size = palette.size();
		int targetSize = TrimPalette.PALETTE_SIZE;
		if (size >= targetSize) {
			return palette;
		}

		List<double[]> oklabPalette = OkLabHelper.rgbToOklab(palette);
		List<double[]> stretchedOKLab = OkLabHelper.strechOkLab(targetSize, size, oklabPalette);
		return OkLabHelper.okLabToRgb(stretchedOKLab);
	}

	/**
	 * Collects the colours of every quad's sprite, dropping zero entries (transparent pixels, and
	 * also any opaque pure-black pixel since it packs to 0).
	 */
	private @NotNull List<Integer> getColoursFromQuads(List<class_777> quads) {
		List<Integer> colours = new ArrayList<>(quads.size() * 16 * 16);
		for (class_777 bakedQuad : quads) {
			for (int colour : extractColours(bakedQuad.comp_3724())) {
				colours.add(colour);
			}
		}
		return colours.stream().filter(i -> i != 0).toList();
	}

	/**
	 * Reads a sprite's original image into a row-major array of {@code 0x00RRGGBB} values (the alpha
	 * byte is discarded). Fully transparent pixels are left as {@code 0}.
	 */
	// The image belongs to the sprite contents, so it is intentionally not closed here.
	@SuppressWarnings("resource")
	private int[] extractColours(class_1058 sprite) {
		class_1011 spriteImage = ((SpriteContentsAccessor) sprite.method_45851()).trimica$originalImage();
		int width = spriteImage.method_4307();
		int height = spriteImage.method_4323();

		int[] colourData = new int[width * height];

		for (int x = 0; x < width; x++) {
			for (int y = 0; y < height; y++) {
				int argb = spriteImage.method_61940(x, y);
				int alpha = class_9848.method_61320(argb);
				if (alpha == 0) {
					continue;
				}

				int red = class_9848.method_61327(argb);
				int green = class_9848.method_61329(argb);
				int blue = class_9848.method_61331(argb);
				colourData[x + y * width] = red << 16 | green << 8 | blue;
			}
		}

		return colourData;
	}

	/** Drops every cached palette, both model-derived and built-in. */
	public void clear() {
		TRIM_PALETTES.clear();
		BUILT_IN_PALETTES.clear();
	}
}