package com.zurrtum.create.client.flywheel.backend.glsl;

import com.zurrtum.create.client.flywheel.backend.compile.FlwPrograms;
import com.zurrtum.create.client.flywheel.lib.util.StringUtil;
import org.jetbrains.annotations.VisibleForTesting;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.*;
import net.minecraft.class_2960;
import net.minecraft.class_3298;
import net.minecraft.class_3300;

/**
 * The main object for loading and parsing source files.
 *
 * <p>On construction, every resource under the {@code flywheel} directory whose path ends in
 * {@code .glsl}, {@code .vert}, {@code .frag} or {@code .comp} is eagerly read and parsed, along with
 * everything those files {@code #include}. Results are keyed by location with the
 * {@value #SHADER_DIR} path prefix stripped, so {@code ns:flywheel/foo.glsl} is stored as
 * {@code ns:foo.glsl}.
 */
public class ShaderSources {
    public static final String SHADER_DIR = "flywheel/";

    /**
     * Load results (successes and failures) keyed by location without the {@value #SHADER_DIR} prefix.
     */
    @VisibleForTesting
    protected final Map<class_2960, LoadResult> cache;

    /**
     * Eagerly loads and parses all shader sources found by {@code manager} under the
     * {@code flywheel} resource directory, logging how many results were produced and how long it took.
     *
     * @param manager the resource manager used both to list root shader files and to resolve includes
     */
    public ShaderSources(class_3300 manager) {
        var sourceFinder = new SourceFinder(manager);

        long loadStart = System.nanoTime();
        manager.method_14488("flywheel", ShaderSources::isShader).forEach(sourceFinder::rootLoad);

        long loadEnd = System.nanoTime();

        // NOTE: this count includes failed loads (e.g. unresolved #includes) as well as successes.
        FlwPrograms.LOGGER.info("Loaded {} shader sources in {}", sourceFinder.results.size(), StringUtil.formatTime(loadEnd - loadStart));

        this.cache = sourceFinder.results;
    }

    /**
     * Removes the leading {@value #SHADER_DIR} from the path of {@code loc}, keeping its namespace.
     * Assumes the path actually starts with that prefix (the constructor only feeds in locations
     * listed under the {@code flywheel} directory — presumably always {@code flywheel/...}).
     */
    private static class_2960 locationWithoutFlywheelPrefix(class_2960 loc) {
        return class_2960.method_60655(loc.method_12836(), loc.method_12832().substring(SHADER_DIR.length()));
    }

    /**
     * Looks up the load result for {@code location} (given without the {@value #SHADER_DIR} prefix).
     *
     * <p>A location that was not loaded during construction is recorded, and cached, as a
     * {@link LoadError.ResourceError} failure.
     *
     * @param location the source location to look up
     * @return the cached result, never {@code null}
     */
    public LoadResult find(class_2960 location) {
        return cache.computeIfAbsent(location, loc -> new LoadResult.Failure(new LoadError.ResourceError(loc)));
    }

    /**
     * Returns the parsed source file at {@code location} via {@link LoadResult#unwrap()}, which
     * presumably throws when the load failed — TODO confirm against {@code LoadResult}.
     *
     * @param location the source location (without the {@value #SHADER_DIR} prefix)
     * @return the parsed source file
     */
    public SourceFile get(class_2960 location) {
        return find(location).unwrap();
    }

    /** Returns whether {@code loc}'s path has one of the recognised shader file extensions. */
    private static boolean isShader(class_2960 loc) {
        var path = loc.method_12832();
        return path.endsWith(".glsl") || path.endsWith(".vert") || path.endsWith(".frag") || path.endsWith(".comp");
    }

    /**
     * Loads source files and, recursively, their includes, memoising every result and detecting
     * include cycles with an explicit stack of the locations currently being loaded.
     */
    private static class SourceFinder {
        /** Locations whose parse is in progress, outermost first; used for cycle detection. */
        private final Deque<class_2960> findStack = new ArrayDeque<>();
        /** Every result produced so far, keyed by location without the shader-dir prefix. */
        private final Map<class_2960, LoadResult> results = new HashMap<>();
        private final class_3300 manager;

        public SourceFinder(class_3300 manager) {
            this.manager = manager;
        }

        /**
         * Loads a root shader file discovered by directory listing, unless it was already loaded
         * because another source included it.
         *
         * @param loc      the full location, including the {@value #SHADER_DIR} prefix
         * @param resource the already-resolved resource for {@code loc}
         */
        public void rootLoad(class_2960 loc, class_3298 resource) {
            var strippedLoc = locationWithoutFlywheelPrefix(loc);

            if (results.containsKey(strippedLoc)) {
                // Some other source already #included this one.
                return;
            }

            // Track the root on the find stack exactly like an #included file. Without this, an
            // include cycle leading back to this root (A -> B -> A) would not be recognised at A:
            // A would be re-loaded and re-parsed from inside its own parse, and the cycle reported
            // one level too deep.
            findStack.addLast(strippedLoc);
            try {
                results.put(strippedLoc, readResource(strippedLoc, resource));
            } finally {
                findStack.removeLast();
            }
        }

        /**
         * Resolves an included location, reporting a {@link LoadError.CircularDependency} if it is
         * already being loaded further up the current include chain.
         *
         * @param location the included location (without the {@value #SHADER_DIR} prefix)
         * @return the (possibly cached) load result
         */
        public LoadResult recursiveLoad(class_2960 location) {
            if (findStack.contains(location)) {
                // Snapshot the include chain with the offending location appended to show the full cycle.
                var cycle = new ArrayList<>(findStack);
                cycle.add(location);
                return new LoadResult.Failure(new LoadError.CircularDependency(location, List.copyOf(cycle)));
            }

            findStack.addLast(location);
            try {
                return _find(location);
            } finally {
                // Always unwind, even if parsing throws, so the stack never holds stale entries.
                findStack.removeLast();
            }
        }

        /** Returns the memoised result for {@code location}, loading and caching it on a miss. */
        private LoadResult _find(class_2960 location) {
            // Can't use computeIfAbsent because mutual recursion causes ConcurrentModificationExceptions
            var out = results.get(location);
            if (out == null) {
                out = load(location);
                results.put(location, out);
            }
            return out;
        }

        /**
         * Looks up the resource for {@code loc} with {@value #SHADER_DIR} re-applied to its path
         * and parses it, or yields a {@link LoadError.ResourceError} failure if no resource exists.
         */
        private LoadResult load(class_2960 loc) {
            return manager.method_14486(loc.method_45138(SHADER_DIR)).map(resource -> readResource(loc, resource))
                .orElseGet(() -> new LoadResult.Failure(new LoadError.ResourceError(loc)));
        }

        /**
         * Reads {@code resource} fully as UTF-8 and parses it as the source at {@code loc}, resolving
         * includes through {@link #recursiveLoad}. An {@link IOException} while reading becomes a
         * {@link LoadError.IOError} failure.
         */
        private LoadResult readResource(class_2960 loc, class_3298 resource) {
            try (InputStream stream = resource.method_14482()) {
                String sourceString = new String(stream.readAllBytes(), StandardCharsets.UTF_8);
                return SourceFile.parse(this::recursiveLoad, loc, sourceString);
            } catch (IOException e) {
                return new LoadResult.Failure(new LoadError.IOError(loc, e));
            }
        }
    }
}
