package com.bawnorton.trimica.client.texture;

import com.bawnorton.trimica.Trimica;
import com.bawnorton.trimica.client.TrimicaClient;
import com.bawnorton.trimica.client.model.TrimModelId;
import com.bawnorton.trimica.item.component.AdditionalTrims;
import com.bawnorton.trimica.item.component.MaterialAdditions;
import com.bawnorton.trimica.trim.TrimmedType;
import com.bawnorton.trimica.util.Lazy;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import net.minecraft.class_10186;
import net.minecraft.class_2378;
import net.minecraft.class_2960;
import net.minecraft.class_5455;
import net.minecraft.class_638;
import net.minecraft.class_7924;
import net.minecraft.class_8053;
import net.minecraft.class_8054;
import net.minecraft.class_8056;
import net.minecraft.class_9473;

public final class RuntimeTrimAtlases {
	private final Map<class_8054, Map<class_10186.class_10190, Lazy<RuntimeTrimAtlas>>> equipmentAtlases = new HashMap<>();
	private final Map<class_8054, Lazy<RuntimeTrimAtlas>> itemAtlases = new HashMap<>();
	private final Map<class_8054, Lazy<RuntimeTrimAtlas>> shieldAtlases = new HashMap<>();

	private final List<Consumer<RuntimeTrimAtlas>> modelAtlasModifiedListeners = new ArrayList<>();

	public void init(class_5455 registryAccess) {
		equipmentAtlases.values().forEach(Map::clear);
		equipmentAtlases.clear();
		itemAtlases.clear();
		shieldAtlases.clear();

		class_2378<class_8054> materials = registryAccess.method_46759(class_7924.field_42083).orElseThrow();

		for (class_8054 material : materials) {
			for (class_10186.class_10190 layerType : class_10186.class_10190.values()) {
				createEquipmentAtlas(registryAccess, material, layerType);
			}
			createItemAtlas(registryAccess, material);
			createShieldAtlas(registryAccess, material);
		}
	}

	private void resetFrames() {
		for (Map<class_10186.class_10190, Lazy<RuntimeTrimAtlas>> layerBasedAtlases : equipmentAtlases.values()) {
			for (Lazy<RuntimeTrimAtlas> atlas : layerBasedAtlases.values()) {
				if (atlas.isPresent()) {
					atlas.get().resetFrames();
				}
			}
		}
	}

	private Lazy<RuntimeTrimAtlas> createEquipmentAtlas(class_5455 registryAccess, class_8054 material, class_10186.class_10190 layerType) {
		class_2378<class_8054> materials = registryAccess.method_46759(class_7924.field_42083).orElseThrow();
		class_2378<class_8056> patterns = registryAccess.method_46759(class_7924.field_42082).orElseThrow();
		String materialId = Trimica.getMaterialRegistry().getSuffix(material);
		return new Lazy<>(() -> new RuntimeTrimAtlas(
				Trimica.rl("%s/%s.png".formatted(materialId, layerType.method_15434())),
				new TrimArmourSpriteFactory(layerType),
				(p) -> new class_8053(materials.method_47983(material), patterns.method_47983(p)),
				atlas -> {
					resetFrames();
					for (Consumer<RuntimeTrimAtlas> listener : modelAtlasModifiedListeners) {
						listener.accept(atlas);
					}
				}
		));
	}

	private Lazy<RuntimeTrimAtlas> createItemAtlas(class_5455 registryAccess, class_8054 material) {
		class_2378<class_8054> materials = registryAccess.method_46759(class_7924.field_42083).orElseThrow();
		class_2378<class_8056> patterns = registryAccess.method_46759(class_7924.field_42082).orElseThrow();
		String materialId = Trimica.getMaterialRegistry().getSuffix(material);
		return new Lazy<>(() -> new RuntimeTrimAtlas(
				Trimica.rl("%s/item.png".formatted(materialId)),
				new TrimItemSpriteFactory(),
				(p) -> new class_8053(materials.method_47983(material), patterns.method_47983(p)),
				atlas -> {
					TrimicaClient.getItemModelFactory().clearModels();
					for (Consumer<RuntimeTrimAtlas> listener : modelAtlasModifiedListeners) {
						listener.accept(atlas);
					}
				}
		));
	}

	private Lazy<RuntimeTrimAtlas> createShieldAtlas(class_5455 registryAccess, class_8054 material) {
		class_2378<class_8054> materials = registryAccess.method_46759(class_7924.field_42083).orElseThrow();
		class_2378<class_8056> patterns = registryAccess.method_46759(class_7924.field_42082).orElseThrow();
		String materialId = Trimica.getMaterialRegistry().getSuffix(material);
		return new Lazy<>(() -> new RuntimeTrimAtlas(
				Trimica.rl("%s/shield.png".formatted(materialId)),
				new TrimShieldSpriteFactory(),
				(p) -> new class_8053(materials.method_47983(material), patterns.method_47983(p)),
				atlas -> {
					for (Consumer<RuntimeTrimAtlas> listener : modelAtlasModifiedListeners) {
						listener.accept(atlas);
					}
				}
		));
	}

	public RuntimeTrimAtlas getEquipmentAtlas(class_638 level, class_8054 material, class_10186.class_10190 layerType) {
		return equipmentAtlases.computeIfAbsent(material, k -> new HashMap<>())
				.computeIfAbsent(layerType, k -> createEquipmentAtlas(level.method_30349(), material, layerType))
				.get();
	}

	public RuntimeTrimAtlas getItemAtlas(class_638 level, class_8054 material) {
		return itemAtlases.computeIfAbsent(material, k -> createItemAtlas(level.method_30349(), k)).get();
	}

	public RuntimeTrimAtlas getShieldAtlas(class_638 level, class_8054 material) {
		return shieldAtlases.computeIfAbsent(material, k -> createShieldAtlas(level.method_30349(), k)).get();
	}

	public List<DynamicTrimTextureAtlasSprite> getShieldSprites(class_638 level, class_9473 getter) {
		List<DynamicTrimTextureAtlasSprite> sprites = new ArrayList<>();
		List<class_8053> trims = AdditionalTrims.getAllTrims(getter);
		for (class_8053 trim : trims) {

			TrimModelId trimModelId = TrimModelId.fromTrim(TrimmedType.SHIELD, trim, null);
			class_2960 overlayLocation = trimModelId.asSingle();
			if (MaterialAdditions.enableMaterialAdditions) {
				MaterialAdditions addition = getter.method_58695(MaterialAdditions.TYPE, MaterialAdditions.NONE);
				overlayLocation = addition.apply(overlayLocation);
			}
			sprites.add(getShieldAtlas(level, trim.comp_3179().comp_349()).getSprite(getter, trim.comp_3180().comp_349(), overlayLocation));
		}
		return sprites;
	}

	public void addModelAtlasModifiedListener(Consumer<RuntimeTrimAtlas> listener) {
		modelAtlasModifiedListeners.add(listener);
	}

	public void clear() {
		equipmentAtlases.forEach((pattern, lazyMap) ->
				lazyMap.forEach((layer, lazy) -> lazy.ifPresent(RuntimeTrimAtlas::clear))
		);
		itemAtlases.forEach((pattern, lazy) -> lazy.ifPresent(RuntimeTrimAtlas::clear));
		shieldAtlases.forEach((pattern, lazy) -> lazy.ifPresent(RuntimeTrimAtlas::clear));
		equipmentAtlases.clear();
		itemAtlases.clear();
		shieldAtlases.clear();
	}

	public interface TrimFactory {
		class_8053 create(class_8056 pattern);
	}
}