package com.petrolpark.core.codec;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Stream;

import com.mojang.datafixers.util.Pair;
import com.mojang.serialization.DataResult;
import com.mojang.serialization.DynamicOps;
import com.mojang.serialization.Lifecycle;
import com.mojang.serialization.ListBuilder;
import com.mojang.serialization.codecs.ListCodec;

import net.minecraft.util.Unit;

/**
 * Copy of {@link ListCodec} that accepts a context object when encoding and decoding, and passes it on to the
 * element codec for every element.
 * <p>
 * Encoding fails outright if the input list has fewer than {@code minSize} or more than {@code maxSize} elements.
 * <p>
 * Decoding walks every entry of the serialized list and returns a pair of (decoded elements, remainder), where the
 * remainder is a list (built with {@link DynamicOps#createList}) of the raw entries that were not cleanly decoded:
 * <ul>
 *     <li>once {@code maxSize} elements have been collected, further entries are not decoded and go straight into
 *     the remainder;</li>
 *     <li>an entry whose decode reports an error goes into the remainder, but any partial value it produced is still
 *     added to the decoded elements;</li>
 *     <li>if fewer than {@code minSize} elements were collected, the result is an error with no partial value;</li>
 *     <li>if the serialized list contained more than {@code maxSize} entries, the result is an error whose partial
 *     value is the truncated pair.</li>
 * </ul>
 *
 * @param <CONTEXT> type of the context object handed to {@code elementCodec}
 * @param <E> element type
 * @param elementCodec codec used to encode/decode each element
 * @param minSize minimum number of elements (inclusive)
 * @param maxSize maximum number of elements (inclusive)
 */
public record ContextualListCodec<CONTEXT, E>(ContextualCodec<CONTEXT, E> elementCodec, int minSize, int maxSize) implements ContextualCodec<CONTEXT, List<E>> {

    /**
     * Validates the components so that a misconfigured codec fails at construction time instead of producing a
     * codec that can never succeed.
     *
     * @throws NullPointerException if {@code elementCodec} is null
     * @throws IllegalArgumentException if {@code minSize} is negative or {@code maxSize} is less than {@code minSize}
     */
    public ContextualListCodec {
        Objects.requireNonNull(elementCodec, "elementCodec");
        if (minSize < 0) {
            throw new IllegalArgumentException("minSize must be non-negative: " + minSize);
        }
        if (maxSize < minSize) {
            throw new IllegalArgumentException("maxSize (" + maxSize + ") must not be less than minSize (" + minSize + ")");
        }
    }

    /**
     * Creates a codec with no effective size restriction, i.e. the range {@code [0, Integer.MAX_VALUE]}.
     *
     * @param elementCodec codec used to encode/decode each element
     */
    public ContextualListCodec(final ContextualCodec<CONTEXT, E> elementCodec) {
        this(elementCodec, 0, Integer.MAX_VALUE);
    }

    private <R> DataResult<R> createTooShortError(final int size) {
        return DataResult.error(() -> "List is too short: " + size + ", expected range [" + minSize + "-" + maxSize + "]");
    }

    private <R> DataResult<R> createTooLongError(final int size) {
        return DataResult.error(() -> "List is too long: " + size + ", expected range [" + minSize + "-" + maxSize + "]");
    }

    /**
     * Encodes every element of {@code input} with {@link #elementCodec}, passing {@code context} through.
     *
     * @return the encoded list merged into {@code prefix}, or an error if {@code input}'s size is outside
     *         {@code [minSize, maxSize]}
     */
    @Override
    public <T> DataResult<T> encode(final List<E> input, final CONTEXT context, final DynamicOps<T> ops, final T prefix) {
        final int size = input.size();
        if (size < minSize) {
            return createTooShortError(size);
        }
        if (size > maxSize) {
            return createTooLongError(size);
        }
        final ListBuilder<T> builder = ops.listBuilder();
        for (final E element : input) {
            builder.add(elementCodec.encodeStart(ops, context, element));
        }
        return builder.build(prefix);
    }

    /**
     * Decodes a serialized list, passing {@code context} to the element codec for each entry. See the class
     * documentation for how size limits and element errors are reported.
     */
    @Override
    public <T> DataResult<Pair<List<E>, T>> decode(final DynamicOps<T> ops, final CONTEXT context, final T input) {
        return ops.getList(input).setLifecycle(Lifecycle.stable()).flatMap(stream -> {
            final DecoderState<T> decoder = new DecoderState<>(ops, context);
            stream.accept(decoder::accept);
            return decoder.build();
        });
    }

    @Override
    public String toString() {
        return "ContextualListCodec[" + elementCodec + ']';
    }

    /**
     * Mutable per-call accumulator for {@link #decode}. It is deliberately a non-static inner class because it reads
     * the enclosing record's {@code elementCodec}, {@code minSize} and {@code maxSize}, and uses its error factories.
     * A fresh instance is created for each decode call.
     */
    private class DecoderState<T> {
        private static final DataResult<Unit> INITIAL_RESULT = DataResult.success(Unit.INSTANCE, Lifecycle.stable());

        private final DynamicOps<T> ops;
        private final CONTEXT context;
        /** Successfully (or partially) decoded elements, capped at {@code maxSize}. */
        private final List<E> elements = new ArrayList<>();
        /** Raw entries that were skipped or failed to decode cleanly. */
        private final Stream.Builder<T> failed = Stream.builder();
        /** Running combination of every element's decode outcome; becomes an error once any element errors. */
        private DataResult<Unit> result = INITIAL_RESULT;
        /** Number of raw entries seen, including those skipped for exceeding {@code maxSize}. */
        private int totalCount;

        private DecoderState(final DynamicOps<T> ops, final CONTEXT context) {
            this.ops = ops;
            this.context = context;
        }

        /** Consumes one raw entry of the serialized list. */
        public void accept(final T value) {
            totalCount++;
            if (elements.size() >= maxSize) {
                // Already at capacity: record the raw entry in the remainder without decoding it.
                failed.add(value);
                return;
            }
            final DataResult<Pair<E, T>> elementResult = elementCodec.decode(ops, context, value);
            // An erroring entry is reported in the remainder even if it also yields a partial value below.
            elementResult.error().ifPresent(error -> failed.add(value));
            elementResult.resultOrPartial().ifPresent(pair -> elements.add(pair.getFirst()));
            // Fold this element's outcome into the running result so element errors accumulate.
            result = result.apply2stable((accumulated, decoded) -> accumulated, elementResult);
        }

        /** Produces the final decode result once every entry has been consumed. */
        public DataResult<Pair<List<E>, T>> build() {
            // Too-short is measured on collected elements and yields no partial value.
            if (elements.size() < minSize) {
                return createTooShortError(elements.size());
            }
            final T errors = ops.createList(failed.build());
            final Pair<List<E>, T> pair = Pair.of(List.copyOf(elements), errors);
            // Too-long is measured on raw entries; the truncated pair is still offered as a partial value.
            if (totalCount > maxSize) {
                result = createTooLongError(totalCount);
            }
            return result.map(ignored -> pair).setPartial(pair);
        }
    }
}
