package me.pepperbell.continuity.client.util;

import java.util.EnumMap;
import java.util.List;
import java.util.Set;
import java.util.concurrent.locks.StampedLock;

import org.jetbrains.annotations.Unmodifiable;

import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap;
import net.fabricmc.fabric.api.client.rendering.v1.InvalidateRenderStateCallback;
import net.fabricmc.fabric.api.renderer.v1.Renderer;
import net.fabricmc.fabric.api.renderer.v1.mesh.MutableMesh;
import net.fabricmc.fabric.api.renderer.v1.mesh.MutableQuadView;
import net.fabricmc.fabric.api.renderer.v1.mesh.QuadEmitter;
import net.fabricmc.fabric.api.renderer.v1.mesh.QuadTransform;
import net.minecraft.class_1058;
import net.minecraft.class_1087;
import net.minecraft.class_2338;
import net.minecraft.class_2350;
import net.minecraft.class_2680;
import net.minecraft.class_310;
import net.minecraft.class_5819;
import net.minecraft.class_773;
import net.minecraft.class_9891;

/**
 * Computes and caches, per block state and face, the set of sprites that the state's block model emits on
 * that face.
 *
 * <p>One {@link SpriteCache} is kept for every {@code class_2350} value. All caches are cleared whenever
 * {@link InvalidateRenderStateCallback} fires, and can also be cleared manually through {@link #clearCache()}.
 *
 * <p>Lookups are thread-safe: each cache guards its map and its scratch state with a {@link StampedLock}.
 */
public final class SpriteCalculator {
	/** Block model lookup, obtained once from the client singleton when this class is initialized. */
	private static final class_773 MODELS = class_310.method_1551().method_1554().method_4743();

	private static final EnumMap<class_2350, SpriteCache> SPRITE_CACHES = new EnumMap<>(class_2350.class);

	static {
		for (class_2350 direction : class_2350.values()) {
			SPRITE_CACHES.put(direction, new SpriteCache(direction));
		}

		// Cached sprites become stale when render state is invalidated (e.g. on resource reload).
		InvalidateRenderStateCallback.EVENT.register(SpriteCalculator::clearCache);
	}

	private SpriteCalculator() {
	}

	/**
	 * Returns the sprites that {@code state}'s model emits on {@code face}, computing and caching them on first
	 * request.
	 *
	 * @param state the block state whose model is inspected
	 * @param face the face whose quads are collected
	 * @return an unmodifiable, non-empty set of sprites; if the model emits no quads on {@code face}, this is a
	 *         singleton of the model's fallback sprite ({@code method_68511})
	 */
	@Unmodifiable
	public static Set<class_1058> getSprites(class_2680 state, class_2350 face) {
		return SPRITE_CACHES.get(face).getSprites(state);
	}

	/** Discards every cached sprite set for every face. */
	public static void clearCache() {
		for (SpriteCache cache : SPRITE_CACHES.values()) {
			cache.clear();
		}
	}

	/** Per-face cache mapping block states to the sprites their model emits on that face. */
	private static final class SpriteCache {
		private final class_2350 face;
		private final Reference2ObjectOpenHashMap<class_2680, Set<class_1058>> spritesMap = new Reference2ObjectOpenHashMap<>();
		// Scratch state used by calculateSprites; only touched while the write lock is held.
		private final MutableMesh mutableMesh = Renderer.get().mutableMesh();
		private final CollectingQuadTransform quadTransform;
		private final class_5819 random = class_5819.method_43053();
		/** Guards {@link #spritesMap} and the scratch state above. */
		private final StampedLock lock = new StampedLock();

		public SpriteCache(class_2350 face) {
			this.face = face;
			quadTransform = new CollectingQuadTransform(face);
		}

		/**
		 * Returns the cached sprites for {@code state}, calculating them under the write lock if absent.
		 *
		 * <p>It first tries a lock-free optimistic read. On a miss or a failed validation, it falls back to a
		 * read lock, then to a write lock that re-checks the map before calculating.
		 */
		@Unmodifiable
		public Set<class_1058> getSprites(class_2680 state) {
			Set<class_1058> sprites;

			long optimisticReadStamp = lock.tryOptimisticRead();
			if (optimisticReadStamp != 0L) {
				try {
					// This map read could happen at the same time as a map write, so catch any exceptions.
					// This is safe due to the map implementation used, which is guaranteed to not mutate the map during
					// a read.
					sprites = spritesMap.get(state);
					if (sprites != null && lock.validate(optimisticReadStamp)) {
						return sprites;
					}
				} catch (Exception ignored) {
					// A racing write may make the unlocked read fail; fall through to the locked path below.
				}
			}

			long readStamp = lock.readLock();
			try {
				sprites = spritesMap.get(state);
			} finally {
				lock.unlockRead(readStamp);
			}

			if (sprites == null) {
				long writeStamp = lock.writeLock();
				try {
					// Re-check: another thread may have filled the entry between releasing the read lock and
					// acquiring the write lock.
					sprites = spritesMap.get(state);
					if (sprites == null) {
						sprites = calculateSprites(state);
						spritesMap.put(state, sprites);
					}
				} finally {
					lock.unlockWrite(writeStamp);
				}
			}

			return sprites;
		}

		/**
		 * Emits {@code state}'s model quads through {@link #quadTransform} and returns the collected sprites.
		 * Must be called with the write lock held, since it uses the shared scratch mesh, transform and random.
		 *
		 * @return the distinct sprites found on this cache's face, or a singleton of the model's fallback sprite
		 *         ({@code method_68511}) if none were found
		 */
		@Unmodifiable
		private Set<class_1058> calculateSprites(class_2680 state) {
			class_1087 model = MODELS.method_3335(state);
			QuadEmitter emitter = mutableMesh.emitter();
			quadTransform.clear();
			emitter.pushTransform(quadTransform);
			// Reset the random with a constant (method_43052 is presumably setSeed - confirm) so repeated
			// calculations for the same state are deterministic.
			random.method_43052(42);
			try {
				// NOTE(review): the cull predicate always answers false, presumably meaning no face is culled.
				model.emitQuads(emitter, class_9891.field_52611, class_2338.field_10980, state, random, cullFace -> false);
			} catch (Exception ignored) {
				// Best effort: a model that fails to emit still yields whatever sprites were collected before the
				// failure, or the fallback sprite below.
			} finally {
				// Always pop so the shared emitter's transform stack stays balanced even if emission throws.
				emitter.popTransform();
			}
			Set<class_1058> sprites = quadTransform.result();
			return !sprites.isEmpty() ? sprites : Set.of(model.method_68511());
		}

		/** Empties the cache and the collector's scratch list under the write lock. */
		public void clear() {
			long writeStamp = lock.writeLock();
			try {
				spritesMap.clear();
				quadTransform.clear();
			} finally {
				lock.unlockWrite(writeStamp);
			}
		}

		/** Quad transform that records the sprite of every quad whose light face matches {@link #face}. */
		private static final class CollectingQuadTransform implements QuadTransform {
			private final class_2350 face;
			private final List<class_1058> sprites = new ObjectArrayList<>();

			private CollectingQuadTransform(class_2350 face) {
				this.face = face;
			}

			@Override
			public boolean transform(MutableQuadView quad) {
				if (quad.lightFace() == face) {
					sprites.add(RenderUtil.getSpriteFinder().find(quad));
				}
				// Returning false presumably drops the quad (per Fabric's QuadTransform contract), so nothing is
				// accumulated in the scratch mesh.
				return false;
			}

			/** Forgets all sprites collected so far. */
			public void clear() {
				sprites.clear();
			}

			/** Returns an unmodifiable, de-duplicated snapshot of the collected sprites. */
			@Unmodifiable
			public Set<class_1058> result() {
				return Set.copyOf(sprites);
			}
		}
	}
}
