package dev.hipposgrumm.armor_trims.util.color;

import com.mojang.datafixers.util.Pair;
import dev.hipposgrumm.armor_trims.Armortrims;
import dev.hipposgrumm.armor_trims.util.ArmortrimsInternalUtils;
import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.apache.commons.lang3.tuple.Triple;
import org.jetbrains.annotations.Nullable;

import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
import net.minecraft.class_1011;
import net.minecraft.class_1047;
import net.minecraft.class_1058;
import net.minecraft.class_1061;
import net.minecraft.class_1079;
import net.minecraft.class_1080;
import net.minecraft.class_1087;
import net.minecraft.class_1792;
import net.minecraft.class_2960;
import net.minecraft.class_310;
import net.minecraft.class_3298;
import net.minecraft.class_5251;
import net.minecraft.class_5253;
import net.minecraft.class_793;

public class ColorPalette implements class_1061 {
    /** Fallback palette; its colors are computed from the tint index alone (see {@code DefaultColorPalette}). */
    public static final ColorPalette DEFAULT = new DefaultColorPalette(); // Default color palette, for when all else fails.
    /** Tint values recognised by a palette; index 0 is the brightest shade and the upper clamp used by {@link #get(int)}. */
    public static final int[] PALETTE_COLORS = {224,192,160,128,96,64,32,0}; // Shades in palette, brightest to darkest.

    /** Playback sequence indexed by {@link #tick}; the same frame may appear several times in a row. */
    final List<ColorFrame> frames;
    /** Distinct frames: the ones read from the texture plus any interpolated frames appended by {@code arrangeFrames}. */
    final List<ColorFrame> uniqueFrames;
    /** Identifier of this palette; also drives {@link #hashCode()} and {@link #toString()}. */
    protected final class_2960 name;
    /** Animation metadata used when arranging and iterating frames. */
    protected class_1079 meta;
    /** Index into {@link #frames} of the frame currently returned by {@link #get(int)}. */
    protected int tick = 0;
    /** Set once by {@link #discard()}; never reset within this class. */
    private boolean discarded = false;

    /**
     * Default palette constructor: named {@code <MODID>:default}, with empty frame lists
     * and the {@code class_1079.field_21768} metadata constant.
     */
    private ColorPalette() {
        this(new class_2960(Armortrims.MODID,"default"));
        this.meta = class_1079.field_21768;
    }

    /** Single Color */
    private ColorPalette(class_2960 name) {
        this.name = name;
        this.frames = Collections.emptyList();
        this.uniqueFrames = Collections.emptyList();
    }

    ColorPalette(class_2960 name, class_1011 texture, @Nullable class_1079 meta) {
        this.name = name;
        this.frames = new ArrayList<>();
        this.uniqueFrames = new ArrayList<>();
        if (meta == null) meta = class_1079.field_21768;
        this.meta = meta;

        // Read entire horizontal before increasing vertical.
        for (int frameY=0;frameY<texture.method_4323();frameY+=1) for (int frameX=0;frameX<texture.method_4307();frameX+=PALETTE_COLORS.length)  {
            NavigableMap<Integer,Integer> frame = new TreeMap<>(Integer::compareTo);
            for (int x=0;x<PALETTE_COLORS.length;x++) {
                int color = texture.method_4315(x+frameX, frameY);
                color = (color & 0xFF00FF00)      // Keep position of alpha and green.
                        | ((color >> 16) & 0xFF)  // Flip red.
                        | ((color & 0xFF) << 16); // Flip blue.
                frame.put(PALETTE_COLORS[x], color);
            }
            uniqueFrames.add(new ColorFrame(frame));
        }
        arrangeFrames(uniqueFrames);
    }

    /**
     * Expands the given unique frames into the playback list {@code frames}.
     * <p>
     * A single unique frame is added as-is. Otherwise every (index, time) pair reported by
     * {@code forEachFrame} contributes {@code time} slots: without interpolation the frame
     * is repeated; with interpolation ({@code meta.method_4685()}) the frame occupies the
     * first slot and the remaining slots start as {@code null} placeholders that are then
     * filled by {@code setInterpolatedFrames}, including the stretch that wraps from the
     * last key frame back to the first. Interpolated frames created along the way are
     * appended to {@code uniqueFrames}.
     *
     * @param uniqueFrames frames read from the texture; may be appended to.
     */
    protected void arrangeFrames(List<ColorFrame> uniqueFrames) {
        // One frame needs no timing or interpolation.
        if (uniqueFrames.size() == 1) {
            frames.add(uniqueFrames.get(0));
            return;
        }

        // Lay the frames out on the timeline according to the metadata.
        forEachFrame((index, time) -> {
            ColorFrame keyFrame = uniqueFrames.get(index);
            if (!meta.method_4685()) {
                // Plain animation: hold the frame for its whole duration.
                for (int slot = 0; slot < time; slot++) frames.add(keyFrame);
                return;
            }
            if (time <= 0) return;
            // Interpolated animation: key frame first, placeholders for the in-betweens.
            frames.add(keyFrame);
            for (int slot = 1; slot < time; slot++) frames.add(null);
        }, uniqueFrames);

        // Nothing to resolve when the timeline is empty or not interpolated.
        if (frames.isEmpty() || !this.meta.method_4685()) return;

        Map<Triple<ColorFrame, ColorFrame, Double>, ColorFrame> blended = new HashMap<>();
        int anchor = 0; // Position of the most recent key frame.
        int span = 0;   // Slots seen since that key frame (inclusive).
        for (int pos = 0; pos < frames.size(); pos++) {
            ColorFrame current = frames.get(pos);
            if (current != null) {
                // Fill the placeholders between the previous key frame and this one.
                setInterpolatedFrames(current, anchor, pos, span, blended);
                anchor = pos;
                span = 0;
            }
            span++;
        }

        // Trailing placeholders blend from the last key frame back to the first frame.
        setInterpolatedFrames(frames.get(0), anchor, 0, span, blended);

        // Interpolated frames count as unique frames too.
        uniqueFrames.addAll(blended.values());
    }

    /**
     * Fills the slots strictly between the key frame at {@code firstIndex} and
     * {@code nextFrame} with interpolated frames.
     * <p>
     * Slots {@code firstIndex+1 .. firstIndex+interpolateLength-1} of {@code frames} are
     * overwritten. When {@code secondIndex < firstIndex} (the wrap-around from the last key
     * frame back to the start) the pair is normalised: the two frames are swapped and the
     * progress is inverted, so that a stretch played in either direction maps onto the same
     * cache keys. The cached frame is always computed from exactly the values stored in
     * its key; previously the un-swapped frames were blended with the inverted progress,
     * which played the wrap-around stretch backwards.
     * <p>
     * NOTE(review): assumes {@code ColorFrame.interpolate(a, b, t)} blends from {@code a}
     * towards {@code b} as {@code t} grows from 0 to 1 - confirm against ColorFrame.
     *
     * @param nextFrame         key frame that ends the stretch.
     * @param firstIndex        index in {@code frames} of the key frame that starts the stretch.
     * @param secondIndex       index in {@code frames} of {@code nextFrame}.
     * @param interpolateLength number of slots from the starting key frame up to (not including) {@code nextFrame}.
     * @param cache             interpolated frames created so far, keyed by (first, second, progress).
     */
    private void setInterpolatedFrames(ColorFrame nextFrame, int firstIndex, int secondIndex, int interpolateLength, Map<Triple<ColorFrame, ColorFrame, Double>, ColorFrame> cache) {
        ColorFrame lastFrame = frames.get(firstIndex);
        // Normalise direction so that both directions of the same pair share cache entries.
        final boolean inverse = secondIndex < firstIndex;
        final ColorFrame first = inverse ? nextFrame : lastFrame;
        final ColorFrame second = inverse ? lastFrame : nextFrame;

        for (int j=1;j<interpolateLength;j++) {
            // Progress measured from `first` towards `second`.
            final double progress = (double) (inverse ? (interpolateLength-j) : j) / (double) interpolateLength;
            if (first == second) {
                // Blending a frame with itself yields the frame itself.
                frames.set(firstIndex+j, first);
            } else {
                frames.set(firstIndex+j, cache.computeIfAbsent(
                        new ImmutableTriple<>(first, second, progress),
                        // Compute from the same normalised values that form the key.
                        key -> ColorFrame.interpolate(first, second, progress)
                ));
            }
        }
    }

    /** @return the identifier this palette was created with. */
    public class_2960 name() {
        return name;
    }

    /** @return the animation metadata currently assigned to this palette. */
    public class_1079 meta() {
        return meta;
    }

    /**
     * Looks up the color for a tint value in the frame selected by {@code tick}.
     *
     * @param tintIndex tint value; anything above {@code PALETTE_COLORS[0]} is clamped to it.
     * @return the color of the current frame, or the {@link #DEFAULT} palette's color when
     *         this palette has no frames.
     */
    public int get(int tintIndex) {
        final int shade = Math.min(tintIndex, PALETTE_COLORS[0]);
        return frames.isEmpty()
                ? DEFAULT.get(shade)
                : frames.get(tick).getColor(shade);
    }

    /**
     * @return a text color built from this palette's color for {@code PALETTE_COLORS[1]}
     *         (the second-brightest shade) at the current tick.
     */
    public class_5251 textColor() {
        return class_5251.method_27717(get(PALETTE_COLORS[1]));
    }

    /**
     * Apply a list of palettes to a texture.
     * <p>
     * For every palette, each of its unique frames is used to recolor {@code base}: the tint
     * of a base pixel is taken from bits 16-23 of the pixel, looked up in the frame, and the
     * resulting alpha is the frame color's alpha scaled by the base pixel's alpha. Fully
     * transparent results are not written.
     * <p>
     * When {@code merged} is set, all frames of a palette are written into one sheet laid
     * out row-major on a {@code sizeX} by {@code sizeY} grid of {@code base}-sized cells:
     * frame {@code i} goes to column {@code i % sizeX}, row {@code i / sizeX}. The returned
     * metadata then lists the palette's frames with the cell size as frame size. Otherwise
     * one image per unique frame is returned together with {@code class_1079.field_21768}.
     *
     * @param colors - List of color palettes to apply to the texture.
     * @param base - The texture to apply the palettes to.
     * @param merged - Whether to merge all the textures into a single texture (for item textures).
     * @return A map of textures for each color palette.
     */
    // NOTE: I've decided that animated textures for trims will not be implemented.
    public static Map<ColorPalette, Pair<class_1011[],class_1079>> apply(List<ColorPalette> colors, class_1011 base, boolean merged) {
        // Per palette: target images, grid size (merged only, else null) and metadata.
        Map<ColorPalette, Triple<class_1011[],Pair<Integer,Integer>,class_1079>> textures = new HashMap<>();

        // Size of texture
        int width = base.method_4307();
        int height = base.method_4323();

        // Initialize textures before setting them.
        for (ColorPalette color : colors) {
            class_1011[] images;
            Pair<Integer,Integer> sizes = null;
            class_1079 meta = class_1079.field_21768;
            if (merged) { // If merged, create a larger texture for all the frames.
                images = new class_1011[1];
                // Create a size that can accommodate all the textures.
                int size = color.uniqueFrames.size();
                int sizeX=1, sizeY=1;
                // Expand grid until satisfied.
                while (sizeX*sizeY < size) {
                    if (sizeY > sizeX) sizeX++; // Increase width if the height has already increased.
                    else sizeY++;       // Increase height before increasing width.
                }
                sizes = new Pair<>(sizeX, sizeY);
                images[0] = new class_1011(width*sizeX,height*sizeY,true);
                List<class_1080> animationFrames = new ArrayList<>();
                color.forEachFrame((index, time) -> {
                    animationFrames.add(new class_1080(index,time));
                }, false);
                meta = new class_1079(animationFrames, width, height, color.meta.method_4684(), color.meta.method_4685());
            } else {
                images = new class_1011[color.uniqueFrames.size()];
                for (int i=0;i<images.length;i++) images[i] = new class_1011(width, height, true);
            }

            textures.put(color,new ImmutableTriple<>(images,sizes,meta));
        }

        // Modify each texture based on each color palette.
        for (int x=0;x<width;x++) for (int y=0;y<height;y++) {
            // Get pixel from texture.
            int pixel = base.method_4315(x,y);
            int alpha = pixel>>>24;
            // If alpha is 0 (completely transparent), skip.
            if (alpha == 0) continue;

            // Get tint from pixel color (bits 16-23).
            int tint = (pixel >> 16) & 0xff;
            // Alpha of the base pixel as a 0..1 factor; the same for every palette and frame.
            float alphaMod = alpha/255f;
            for (ColorPalette color:colors) {
                Triple<class_1011[],Pair<Integer,Integer>,class_1079> framesData = textures.get(color);
                // Number of columns of the merged sheet (1 when not merged).
                int sizeX = 1;
                Pair<Integer, Integer> sizes = framesData.getMiddle();
                if (sizes != null) sizeX = sizes.getFirst();

                class_1011[] images = framesData.getLeft();
                int imageCount = images.length;
                // If merged use the amount of frames from here instead.
                if (merged) imageCount = color.uniqueFrames.size();
                for (int i=0;i<imageCount;i++) {
                    int col = color.uniqueFrames.get(i).getColor(tint);
                    // Determine alpha of pixel.
                    int finalAlpha = (int)((col>>>24)*alphaMod);
                    // If alpha is 0 (completely transparent), skip.
                    if (finalAlpha == 0) continue;

                    col = (col & 0x0000FF00)        // Keep the second-lowest byte in place and drop alpha; alpha is re-added below.
                            | ((col >> 16) & 0xFF)  // Bits 16-23 moved to the lowest byte.
                            | ((col & 0xFF) << 16); // Lowest byte moved to bits 16-23.
                    col = col | (finalAlpha<<24);   // Combine the color with the final alpha.

                    if (merged) {
                        // Row-major cell of frame i. The row must be derived from the column
                        // count (sizeX); dividing by sizeY made frames overwrite each other
                        // whenever the grid was not square.
                        int cellX = (i % sizeX) * width;
                        int cellY = (i / sizeX) * height;
                        images[0].method_4305(x + cellX, y + cellY, col);
                    } else {
                        images[i].method_4305(x, y, col);
                    }
                }
            }
        }

        // Drop the grid size; callers only get images and metadata.
        return textures.entrySet()
                .stream()
                .collect(Collectors.toMap(
                        Map.Entry::getKey,
                        entry -> new Pair<>(entry.getValue().getLeft(),entry.getValue().getRight()),
                        (x, y) -> y,
                        HashMap::new
                ));
    }

    /**
     * Reports the playback list as runs of (unique frame index, run length).
     * <p>
     * Consecutive identical frames are collapsed into one run. When
     * {@code countInterpolatedFrames} is {@code false}, frames for which
     * {@code interpolated()} is true do not start a run of their own and are counted into
     * the run that precedes them.
     *
     * @param function                receives the index of the run's frame in {@code uniqueFrames} and the run length.
     * @param countInterpolatedFrames whether interpolated frames are reported as separate runs.
     */
    public void forEachFrame(/*? if >=1.18 {*/class_1079.class_5792/*?} else {*//*BiConsumer<Integer, Integer>*//*?}*/ function, boolean countInterpolatedFrames) {
        ColorFrame runFrame = null; // Frame of the run currently being measured.
        int runLength = 1;
        for (ColorFrame current : frames) {
            boolean extendsRun = current == runFrame || (current.interpolated() && !countInterpolatedFrames);
            if (extendsRun) {
                runLength++;
                continue;
            }
            // A new run starts here; flush the one that just ended, if any.
            if (runFrame != null) {
                function.accept(uniqueFrames.indexOf(runFrame), runLength);
            }
            runFrame = current;
            runLength = 1;
        }
        // Flush the final run.
        if (runFrame != null) {
            function.accept(uniqueFrames.indexOf(runFrame), runLength);
        }
    }

    /**
     * Feeds the (index, time) pairs declared by {@code meta} to {@code function}.
     * <p>
     * If {@code meta} reports no frames at all, falls back to one call per element of
     * {@code frames}, using its position as index and {@code meta.method_4684()} as time;
     * nothing is reported when that time is not positive.
     * The version directives in the body select the metadata API for the targeted game version.
     *
     * @param function receives each frame index together with its time.
     * @param frames   frames to enumerate in the fallback case; only its size is used.
     */
    protected void forEachFrame(/*? if >=1.18 {*/class_1079.class_5792/*?} else {*//*BiConsumer<Integer, Integer>*//*?}*/ function, List<ColorFrame> frames) {
        // Tracks whether the metadata produced at least one frame.
        AtomicBoolean ran = new AtomicBoolean(false);
        //? if >=1.18 {
        meta.method_33460((index, time) -> {
            function.accept(index, time);
            ran.set(true);
        });
        //?} else {
        /*if (meta.getFrameCount() > 0) {
            for (int i = 0; i < meta.getFrameCount(); i++) {
                function.accept(meta.getFrameIndex(i), meta.getFrameTime(i));
            }
            ran.set(true);
        }
        *///?}
        // Fallback: the metadata lists no frames, so enumerate the given frames in order.
        if (!ran.get()) {
            for (int i=0;i<frames.size();i++) {
                int time = meta.method_4684();
                if (time > 0) function.accept(i, time);
            }
        }
    }

    /** Marks this palette as discarded; the flag is never cleared within this class. */
    void discard() {
        this.discarded = true;
    }

    /** @return {@code true} once {@link #discard()} has been called. */
    public boolean discarded() {
        return discarded;
    }

    /**
     * Advances to the next frame of the playback list, wrapping around at the end.
     * <p>
     * A palette without frames stays at tick 0. {@code arrangeFrames} can legitimately
     * leave the list empty, and taking the remainder by a size of zero would throw an
     * {@link ArithmeticException}.
     */
    @Override
    public void method_4622() {
        if (frames.isEmpty()) {
            tick = 0;
            return;
        }
        tick = (tick + 1) % frames.size();
    }

    /**
     * Hash derived from the palette name only. {@code equals} is not overridden in this
     * class, so equality stays identity-based.
     */
    @Override
    public int hashCode() {
        return name.hashCode();
    }

    /** @return the string form of the palette name. */
    @Override
    public String toString() {
        return name.toString();
    }

    public static class SingleColorPalette extends ColorPalette {
        /** One generated color per processed texture frame; returned for the brightest tint. */
        private final List<Integer> color = new ArrayList<>();
        /** Darker variant of each entry in {@code color}; returned for {@code PALETTE_COLORS[1]}. */
        private final List<Integer> textColor = new ArrayList<>();

        /**
         * Creates a palette whose colors are derived from the texture of {@code item}.
         * If generation fails the color lists stay empty (see {@link #isValid()}) and
         * {@code meta} may be {@code null}.
         *
         * @param name identifier of the palette; also used to locate the item model.
         * @param item item whose texture supplies the colors.
         */
        SingleColorPalette(class_2960 name, class_1792 item) {
            super(name);
            this.meta = processTextureColor(name,item);
        }

        /**
         * Generates the palette colors from the item's texture.
         * <p>
         * The texture is taken from the particle icon of the item's baked model or, when
         * that is missing, from {@code layer0} of {@code models/item/<path>.json}. For every
         * frame of the texture the pixels with alpha >= 25 are averaged; the average,
         * brightened by a fixed amount and capped at 255 per channel, is appended to
         * {@code color}, and a copy darkened by {@code PALETTE_COLORS[1]/PALETTE_COLORS[0]}
         * to {@code textColor}. Frames without any such pixel are skipped, since there is
         * nothing to average and dividing by a pixel count of zero would abort the whole
         * generation.
         * <p>
         * Failures are logged and not propagated; colors gathered before the failure are kept.
         * The model reader and the decoded image are closed before returning.
         *
         * @param itemId identifier of the item, used for the model lookup and in log messages.
         * @param item   item whose model and texture are inspected.
         * @return the animation metadata of the texture ({@code class_1079.field_21768} if the
         *         texture has none), or {@code null} if the texture could not be loaded.
         */
        private class_1079 processTextureColor(class_2960 itemId, class_1792 item) {
            class_1079 meta = null;
            try {
                class_1087 bakedModel = class_310.method_1551()
                    .method_1480()
                    .method_4012()
                    .method_3308(item.method_7854());
                class_1058 particleIcon = null;
                if (bakedModel != null) particleIcon = bakedModel.method_4711();
                class_2960 textureLoc = null;
                if (particleIcon != null) textureLoc = particleIcon.method_4598();
                // No usable sprite: fall back to reading the item model file directly.
                if (textureLoc == null || textureLoc.equals(class_1047.method_4539())) {
                    /*? if >=1.19 {*//*Optional<Resource>*//*?} else {*/class_3298/*?}*/ resource = class_310.method_1551().method_1478().method_14486(new class_2960(itemId.method_12836(),"models/item/"+itemId.method_12832()+".json"));
                    /*? if >=1.19 {*//*if (resource.isPresent())*//*?}*/ {
                        // Close the reader (and the stream it wraps) once the model has been parsed.
                        try (Reader reader = new InputStreamReader(resource./*? if >=1.19 {*//*get().open*//*?} else {*/method_14482/*?}*/(), StandardCharsets.UTF_8)) {
                            class_793 model = class_793.method_3437(reader);
                            textureLoc = model.method_24077("layer0").method_24147();
                        }
                    }
                }
                if (textureLoc == null || textureLoc.equals(class_1047.method_4539())) throw new NullPointerException("Could not find a texture in model of "+itemId+", it might be a non-vanilla model."); // This should compatible with most non-vanilla models anyway unless they're doing something really strange. This is just a failsafe.

                /*? if >=1.19 {*//*Optional<Resource>*//*?} else {*/class_3298/*?}*/ texture = class_310.method_1551().method_1478().method_14486(new class_2960(textureLoc.method_12836(),"textures/"+textureLoc.method_12832()+".png"));
                // The decoded image is only needed while sampling, so release it afterwards.
                try (class_1011 main = class_1011.method_4309(texture./*? if >=1.19 {*//*get().open*//*?} else {*/method_14482/*?}*/())) {
                    meta = texture./*? if >=1.19 {*//*get().metadata().getSection*//*?} else {*/method_14481/*?}*/(class_1079.field_5337)/*? if >=1.19 {*//*.orElse(null)*//*?}*/;
                    if (meta == null) meta = class_1079.field_21768;

                    int width = meta.method_4687(main.method_4307());
                    int height = meta.method_4686(main.method_4323());
                    // NOTE(review): animSize and totalAnimSize are computed but not read afterwards.
                    AtomicInteger animSize = new AtomicInteger();
                    AtomicInteger totalAnimSize = new AtomicInteger();
                    //? if >=1.18 {
                    class_1079 lambdaSafeMeta = meta;
                    meta.method_33460((index, time) -> {
                        animSize.addAndGet((lambdaSafeMeta.method_4685()?time:1));
                        totalAnimSize.incrementAndGet();
                    });
                    //?} else {
                    /*for (int i=0;i<meta.getFrameCount();i++) {
                        animSize.addAndGet((meta.isInterpolatedFrames()?meta.getFrameTime(i):1));
                        totalAnimSize.incrementAndGet();
                    }
                    *///?}
                    if (totalAnimSize.get() == 0) totalAnimSize.set((main.method_4307() / width) * (main.method_4323() / height));
                    final int brightenAmount = 30;

                    // Get colors from all frames and assign them to a list, ordered by index.
                    for (int frameY=0;frameY<main.method_4323();frameY+=height) for (int frameX=0;frameX<main.method_4307();frameX+=width)  { // Read entire horizontal before increasing vertical.
                        int size = 0;
                        long[] colorVals = new long[]{0, 0, 0, 0};
                        for (int x=0;x<width;x++) for (int y=0;y<height;y++) {
                            int pixel = main.method_4315(frameX+x, frameY+y);
                            int alpha = class_5253.class_5254.method_27762(pixel);
                            // Only sufficiently opaque pixels contribute to the average.
                            if ((alpha >= 25)) {
                                colorVals[0] += alpha;
                                colorVals[1] += class_5253.class_5254.method_27767(pixel);
                                colorVals[2] += class_5253.class_5254.method_27766(pixel);
                                colorVals[3] += class_5253.class_5254.method_27765(pixel);
                                size++;
                            }
                        }
                        // A frame without any counted pixel has no average color; skip it
                        // instead of dividing by zero.
                        if (size == 0) continue;

                        colorVals[0] /= size;
                        colorVals[1] = (long) (((float) colorVals[1]/size)+brightenAmount);
                        colorVals[2] = (long) (((float) colorVals[2]/size)+brightenAmount);
                        colorVals[3] = (long) (((float) colorVals[3]/size)+brightenAmount);
                        color.add(class_5253.class_5254.method_27764(
                                (int) colorVals[0],
                                Math.min((int)colorVals[1],255),
                                Math.min((int)colorVals[2],255),
                                Math.min((int)colorVals[3],255)
                        ));
                        // Text color: the same average scaled down to the second-brightest shade.
                        float tint = PALETTE_COLORS[1]/(float)PALETTE_COLORS[0];
                        textColor.add(class_5253.class_5254.method_27764(
                                (int) colorVals[0],
                                Math.min((int)(colorVals[1]*tint),255),
                                Math.min((int)(colorVals[2]*tint),255),
                                Math.min((int)(colorVals[3]*tint),255)
                        ));
                    }
                }

                Armortrims.LOGGER.debug("Generated palette from {}", itemId);
            } catch (IOException e) {
                if (item/*? if >=1.18.2 {*/.method_7854()/*?}*/.method_31573(Armortrims.TRIM_MATERIALS_TAG)) Armortrims.LOGGER.error("Couldn't generate a palette for {}. Please define a palette for it instead.", itemId);
                else Armortrims.LOGGER.debug("Couldn't generate a palette for {}!", itemId);
            } catch (NullPointerException e) {
                Armortrims.LOGGER.error("Unable to generate palette for {}.", itemId);
            } catch (ClassCastException e) {
                Armortrims.LOGGER.error("Unable to load resultLocation for {}.", itemId);
            } catch (RuntimeException e) {
                Armortrims.LOGGER.error("Cannot find sprite of {} ({}).", itemId, item);
            }
            return meta;
        }

        /**
         * Returns the generated color for a tint value at the current tick.
         * <p>
         * Tints at or above {@code PALETTE_COLORS[0]} yield the full color and
         * {@code PALETTE_COLORS[1]} yields the precomputed text color. Any other tint scales
         * the three low bytes of the full color individually by
         * {@code tintIndex / PALETTE_COLORS[0]} and keeps the top byte. Scaling the packed
         * int as a whole, as was done before, lets the channels bleed into each other.
         * When no colors could be generated, the {@link #DEFAULT} palette answers instead
         * of an index error being raised.
         *
         * @param tintIndex tint value to look up; negative values are treated as 0 when scaling.
         * @return the packed color for the tint.
         */
        @Override
        public int get(int tintIndex) {
            if (color.isEmpty()) {
                // Generation failed; behave like a frame-less base palette.
                return DEFAULT.get(Math.min(tintIndex, PALETTE_COLORS[0]));
            }
            if (tintIndex>=PALETTE_COLORS[0]) {
                // If color is first or brighter.
                return color.get(tick);
            }
            if (tintIndex==PALETTE_COLORS[1]) {
                // If color is same as textcolor.
                return textColor.get(tick);
            }
            // Anything else (shouldn't occur under normal circumstances).
            int full = color.get(tick);
            float scale = Math.max(tintIndex, 0) / (float) PALETTE_COLORS[0];
            int high = (int) (((full >> 16) & 0xFF) * scale);
            int mid = (int) (((full >> 8) & 0xFF) * scale);
            int low = (int) ((full & 0xFF) * scale);
            return (full & 0xFF000000) | (high << 16) | (mid << 8) | low;
        }

        /**
         * Advances to the next generated color, wrapping around at the end.
         * <p>
         * If no colors were generated the tick stays at 0; taking the remainder by a size
         * of zero would throw an {@link ArithmeticException}.
         */
        @Override
        public void method_4622() {
            if (color.isEmpty()) {
                tick = 0;
                return;
            }
            tick = (tick + 1) % color.size();
        }

        /** @return {@code true} if at least one color was generated from the item's texture. */
        public boolean isValid() {
            return !color.isEmpty();
        }
    }

    /**
     * Palette behind {@link #DEFAULT}: it has no frames and derives every color directly
     * from the requested tint value.
     */
    private static final class DefaultColorPalette extends ColorPalette {
        /**
         * @param tintIndex tint value placed into each of the three low bytes.
         * @return the tint value repeated in the three low bytes, with the top byte set to {@code 0xff}.
         */
        @Override
        public int get(int tintIndex) {
            final int topByte = 0xff000000;
            final int upper = tintIndex << 16;
            final int middle = tintIndex << 8;
            return topByte | upper | middle | tintIndex;
        }

        /** Nothing to advance: the colors do not depend on the tick. */
        @Override
        public void method_4622() {
            // Intentionally empty.
        }
    }
}
