package cz.yorick.codec;

import com.google.common.collect.ImmutableMap;
import com.mojang.datafixers.util.Pair;
import com.mojang.serialization.Codec;
import com.mojang.serialization.DataResult;
import com.mojang.serialization.MapCodec;
import cz.yorick.SimpleResourcesCommon;
import cz.yorick.api.codec.annotations.OptionalField;
import cz.yorick.api.codec.*;
import cz.yorick.api.codec.annotations.FieldId;
import cz.yorick.api.codec.annotations.Ignore;
import cz.yorick.api.codec.annotations.IncludeParent;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.*;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import net.minecraft.class_1299;
import net.minecraft.class_1792;
import net.minecraft.class_1799;
import net.minecraft.class_2248;
import net.minecraft.class_2960;
import net.minecraft.class_5699;
import net.minecraft.class_7923;

/**
 * Creates {@link Codec}s and {@link MapCodec}s for mutable classes by reflecting over their instance fields.
 * <p>
 * Every non-static, non-synthetic field which is not marked with {@link Ignore} becomes one key of the
 * serialized map. The key is the field name unless a {@link FieldId} annotation provides another id.
 * Fields marked with {@link OptionalField} may be absent when decoding, all other fields are required.
 * Fields of the superclass are only added when the class is annotated with {@link IncludeParent}.
 *
 * @param <C> the class whose fields get serialized
 * @param <T> the type produced by the default factory and handled by the resulting codec
 */
public class ClassFieldsReflectionCodec<C, T extends C> {
    /**
     * Codecs used for a field type when neither a field id overwrite nor a
     * class codec supplied by the caller matches the field.
     */
    private static final ImmutableMap<Class<?>, Codec<?>> DEFAULT_CODECS = ImmutableMap.<Class<?>, Codec<?>>builder()
            //basic java classes
            .put(boolean.class, Codec.BOOL)
            .put(Boolean.class, Codec.BOOL)
            .put(byte.class, Codec.BYTE)
            .put(Byte.class, Codec.BYTE)
            .put(int.class, Codec.INT)
            .put(Integer.class, Codec.INT)
            .put(float.class, Codec.FLOAT)
            .put(Float.class, Codec.FLOAT)
            .put(double.class, Codec.DOUBLE)
            .put(Double.class, Codec.DOUBLE)
            .put(long.class, Codec.LONG)
            .put(Long.class, Codec.LONG)
            .put(String.class, Codec.STRING)
            //minecraft's registries
            .put(class_1792.class, class_7923.field_41178.method_39673())
            .put(class_1299.class, class_7923.field_41177.method_39673())
            .put(class_2248.class, class_7923.field_41175.method_39673())
            //extra minecraft classes
            .put(class_2960.class, class_2960.field_25139)
            .put(class_1799.class, class_1799.field_24671)
            .build();
    private final Supplier<T> defaultFactory;
    private final Map<Class<?>, Codec<?>> extraCodecs;
    private final Map<String, Codec<?>> codecOverwrites;
    private final LinkedHashMap<String, SerializableField> classFields;
    private final Function<T, DataResult<T>> postProcessor;

    /**
     * @param clazz the class to reflect over, must not be a record
     * @param defaultFactory creates the instance which the decoded values get written into
     * @param extraCodecs codecs looked up by the exact field type, checked before the default codecs
     * @param codecOverwrites codecs looked up by field id, checked before anything else
     * @param postProcessor applied to the populated instance as the last step of decoding
     * @throws IllegalArgumentException if the class is a record, two fields share an id or a field has no codec
     */
    ClassFieldsReflectionCodec(Class<?> clazz, Supplier<T> defaultFactory, Map<Class<?>, Codec<?>> extraCodecs, Map<String, Codec<?>> codecOverwrites, Function<T, DataResult<T>> postProcessor) {
        if(clazz.isRecord()) {
            throw new IllegalArgumentException("ClassFieldsCodec does not accept records since they are immutable, if the class has to be a record you need to write your own codec");
        }

        this.defaultFactory = defaultFactory;
        //the codec maps have to be assigned before the fields get collected, the lookup reads them
        this.extraCodecs = ImmutableMap.copyOf(extraCodecs);
        this.codecOverwrites = ImmutableMap.copyOf(codecOverwrites);
        this.classFields = getSerializableFields(clazz);
        this.postProcessor = postProcessor;
    }

    /**
     * Collects the serializable fields declared by the class and, if the class is annotated
     * with {@link IncludeParent}, also the fields collected the same way from its superclass.
     * The fields of the class itself come first, followed by the inherited ones.
     *
     * @throws IllegalArgumentException if an inherited field uses an id which is already taken
     */
    private LinkedHashMap<String, SerializableField> getSerializableFields(Class<?> clazz) {
        LinkedHashMap<String, SerializableField> allFields = getDeclaredSerializableFields(clazz);
        Class<?> parent = clazz.getSuperclass();
        //getSuperclass() is null for interfaces and Object, there is nothing to include in that case
        if(clazz.getAnnotation(IncludeParent.class) == null || parent == null) {
            return allFields;
        }

        for (Map.Entry<String, SerializableField> inherited : getSerializableFields(parent).entrySet()) {
            String id = inherited.getKey();
            SerializableField previous = allFields.get(id);
            if(previous != null) {
                Field newField = inherited.getValue().field();
                Field prevField = previous.field();
                throw new IllegalArgumentException("Duplicate field id '" + id + "' found!" +
                        " Field '" + newField.getName() + "' declared by class '" + newField.getDeclaringClass().getName() + "' has the same id as the previously specified" +
                        " field '" + prevField.getName() + "' declared by class '" + prevField.getDeclaringClass().getName() + "', either change one of the fields names or use the @FieldId or @Ignore annotation");
            }

            allFields.put(id, inherited.getValue());
        }

        return allFields;
    }

    /**
     * Collects the serializable fields declared directly by the class, keyed by their id.
     * Every collected field is made accessible so private fields can be read and written.
     *
     * @throws IllegalArgumentException if two fields of the class resolve to the same id
     */
    private LinkedHashMap<String, SerializableField> getDeclaredSerializableFields(Class<?> clazz) {
        LinkedHashMap<String, SerializableField> declaredFields = new LinkedHashMap<>();
        for (Field field : clazz.getDeclaredFields()) {
            if(!shouldSerialize(field)) {
                continue;
            }

            field.setAccessible(true);
            Pair<String, SerializableField> entry = getFieldEntry(field);
            SerializableField previous = declaredFields.putIfAbsent(entry.getFirst(), entry.getSecond());
            if(previous != null) {
                throw new IllegalArgumentException("Fields with matching ids found in class '" + clazz.getName() + "', " +
                        "field '" + previous.field().getName() + "' and '" + entry.getSecond().field().getName() + "' have the same id!");
            }
        }

        return declaredFields;
    }

    /**
     * A field is serialized unless it is static, generated by the compiler
     * (for example the outer instance reference of an inner class) or marked with {@link Ignore}.
     */
    private boolean shouldSerialize(Field field) {
        return !Modifier.isStatic(field.getModifiers())
                && !field.isSynthetic()
                && field.getAnnotation(Ignore.class) == null;
    }

    /**
     * Resolves the id and the codec of a field.
     *
     * @return the field id paired with the field, its codec and whether it is required
     * @throws IllegalArgumentException if no codec can be found for the field
     */
    private Pair<String, SerializableField> getFieldEntry(Field field) {
        String fieldId = getFieldId(field);
        boolean required = field.getAnnotation(OptionalField.class) == null;
        return Pair.of(fieldId, new SerializableField(field, resolveCodec(field, fieldId), required));
    }

    /**
     * Looks the codec up in this order: overwrite registered for the field id, codec registered
     * for the exact field type, default codec for the field type, generic enum codec.
     *
     * @throws IllegalArgumentException if none of the lookups yields a codec
     */
    private Codec<?> resolveCodec(Field field, String fieldId) {
        Codec<?> overwriteCodec = this.codecOverwrites.get(fieldId);
        if(overwriteCodec != null) {
            return overwriteCodec;
        }

        Class<?> fieldClass = field.getType();
        Codec<?> extraCodec = this.extraCodecs.get(fieldClass);
        if(extraCodec != null) {
            return extraCodec;
        }

        Codec<?> defaultCodec = DEFAULT_CODECS.get(fieldClass);
        if(defaultCodec != null) {
            return defaultCodec;
        }

        //try to create a generic enum codec
        if(fieldClass.isEnum()) {
            return EnumCodec.of(fieldClass.asSubclass(Enum.class));
        }

        throw new IllegalArgumentException("Could not get codec for field '" + field.getName() + "' no codec registered for class " + fieldClass.getName() + " or field id '" + fieldId + "'");
    }

    /**
     * @return the id from the {@link FieldId} annotation, or the field name if the field is not annotated
     * @throws IllegalArgumentException if the annotation specifies an empty id
     */
    private String getFieldId(Field field) {
        FieldId fieldId = field.getAnnotation(FieldId.class);
        if(fieldId == null) {
            return field.getName();
        }

        if(fieldId.id().isEmpty()) {
            throw new IllegalArgumentException("Field '" + field.getName() +"' is marked with @FieldId(id = \"\"), the name of the field cannot be empty!");
        }

        return fieldId.id();
    }

    /**
     * @return the codec of the field registered under the given id
     * @throws IllegalArgumentException if no field uses the id
     */
    private Codec<?> getFieldCodec(String fieldName) {
        SerializableField field = this.classFields.get(fieldName);
        if(field == null) {
            throw new IllegalArgumentException("Attempted to get a codec for an unknown field '" + fieldName + "' - this should be filtered out by DelegatedDispatchedMapCodec and never happen!");
        }

        return field.codec();
    }

    /**
     * Decoding step: creates an instance using the default factory, writes the decoded
     * values into its fields and hands the instance to the post processor.
     * Keys which are absent keep whatever value the default factory gave the field.
     *
     * @param values the decoded values keyed by field id
     * @return the result of the post processor, or an error if a required key is missing
     *         or a key does not belong to any field
     */
    private DataResult<T> createWithValues(Map<String, Object> values) {
        //walk the fields in declaration order so the same input always reports the same missing key
        for (Map.Entry<String, SerializableField> fieldEntry : this.classFields.entrySet()) {
            String fieldId = fieldEntry.getKey();
            if(fieldEntry.getValue().required() && !values.containsKey(fieldId)) {
                return DataResult.error(() -> "Missing a required key: '" + fieldId + "'");
            }
        }

        T instance = this.defaultFactory.get();
        for (Map.Entry<String, Object> entry : values.entrySet()) {
            SerializableField serializableField = this.classFields.get(entry.getKey());
            if(serializableField == null) {
                return DataResult.error(() -> "Key '" + entry.getKey() + "' does not represent a valid field!");
            }

            //a failed assignment is only logged by the field, the remaining values are still applied
            serializableField.set(instance, entry.getValue());
        }

        return this.postProcessor.apply(instance);
    }

    /**
     * Encoding step: reads the value of every serializable field of the instance.
     *
     * @return the values keyed by field id in field order, or an error if any value is null
     *         (which is also the case when the value could not be read)
     */
    private DataResult<Map<String, Object>> getValues(T instance) {
        LinkedHashMap<String, Object> values = new LinkedHashMap<>();
        for (Map.Entry<String, SerializableField> entry : this.classFields.entrySet()) {
            Object value = entry.getValue().get(instance);
            if(value == null) {
                return DataResult.error(() -> "Cannot serialize the field '" + entry.getKey() + "' because its value is null");
            }

            values.put(entry.getKey(), value);
        }

        return DataResult.success(values);
    }

    /**
     * Creates a codec for the class, this is the {@link #ofMap} map codec turned into a regular codec.
     *
     * @param clazz the class to reflect over, must not be a record
     * @param defaultFactory creates the instance which the decoded values get written into
     * @param extraCodecs codecs looked up by the exact field type
     * @param codecOverwrites codecs looked up by field id, these take priority over the field type
     * @param postProcessor applied to the populated instance as the last step of decoding
     * @throws IllegalArgumentException if the class is a record, two fields share an id or a field has no codec
     */
    public static<C, T extends C> Codec<T> of(Class<C> clazz, Supplier<T> defaultFactory, Map<Class<?>, Codec<?>> extraCodecs, Map<String, Codec<?>> codecOverwrites, Function<T, DataResult<T>> postProcessor) {
        return ofMap(clazz, defaultFactory, extraCodecs, codecOverwrites, postProcessor).codec();
    }

    /**
     * Creates a map codec for the class.
     *
     * @param clazz the class to reflect over, must not be a record
     * @param defaultFactory creates the instance which the decoded values get written into
     * @param extraCodecs codecs looked up by the exact field type
     * @param codecOverwrites codecs looked up by field id, these take priority over the field type
     * @param postProcessor applied to the populated instance as the last step of decoding
     * @throws IllegalArgumentException if the class is a record, two fields share an id or a field has no codec
     */
    public static<C, T extends C> MapCodec<T> ofMap(Class<C> clazz, Supplier<T> defaultFactory, Map<Class<?>, Codec<?>> extraCodecs, Map<String, Codec<?>> codecOverwrites, Function<T, DataResult<T>> postProcessor) {
        ClassFieldsReflectionCodec<C, T> fieldsCodec = new ClassFieldsReflectionCodec<>(clazz, defaultFactory, extraCodecs, codecOverwrites, postProcessor);
        //the second argument is the codec for the keys - presumably a non-empty string codec, confirm against the mappings
        MapCodec<Map<String, Object>> objects = new DelegatedDispatchedMapCodec<>(fieldsCodec.classFields.keySet(), class_5699.field_41759, fieldsCodec::getFieldCodec);
        return objects.flatXmap(fieldsCodec::createWithValues, fieldsCodec::getValues);
    }

    /**
     * A reflected field together with the codec used for its value
     * and whether its key has to be present when decoding.
     */
    private record SerializableField(Field field, Codec<?> codec, boolean required) {
        /**
         * @return the value of the field, or null if the field could not be accessed (the failure is logged)
         */
        private Object get(Object instance) {
            try {
                return this.field.get(instance);
            } catch (IllegalAccessException e) {
                SimpleResourcesCommon.LOGGER.error("Could not retrieve the value of config field '" + this.field.getName() + "'", e);
                return null;
            }
        }

        /**
         * Writes the value into the field. Any failure is logged and otherwise
         * ignored, the field keeps its previous value in that case.
         */
        private void set(Object instance, Object value) {
            try {
                this.field.set(instance, value);
            } catch (Exception e) {
                SimpleResourcesCommon.LOGGER.error("Could not assign value to the config field '" + this.field.getName() + "'", e);
            }
        }
    }

    /**
     * Collects the codecs and the optional post processor before creating the codec.
     * Without a post processor the decoded instance is returned as a success unchanged.
     */
    public static class Builder<C, T extends C> implements ClassFieldsCodec.Builder<C, T> {
        private final Class<C> clazz;
        private final Supplier<T> defaultFactory;
        private final Map<Class<?>, Codec<?>> extraCodecs = new HashMap<>();
        private final Map<String, Codec<?>> codecOverwrites = new HashMap<>();
        private Function<T, DataResult<T>> postProcessor = null;

        /**
         * @param clazz the class to reflect over
         * @param defaultFactory creates the instance which the decoded values get written into
         */
        public Builder(Class<C> clazz, Supplier<T> defaultFactory) {
            this.clazz = clazz;
            this.defaultFactory = defaultFactory;
        }

        /**
         * Registers a codec for all fields of the given type.
         *
         * @throws IllegalArgumentException if a codec is already registered for the class
         */
        @Override
        public Builder<C, T> withCodec(Codec<?> codec, Class<?> clazz) throws IllegalArgumentException {
            if(this.extraCodecs.containsKey(clazz)) {
                throw new IllegalArgumentException("Attempted to register multiple codecs for the class '" + clazz.getName() +"'");
            }
            this.extraCodecs.put(clazz, codec);
            return this;
        }

        /**
         * Registers a codec for the fields with the given ids, it takes priority over codecs registered by type.
         *
         * @throws IllegalArgumentException if no ids are given or one of them already has a codec,
         *         ids preceding the duplicate stay registered
         */
        @Override
        public Builder<C, T> withCodec(Codec<?> codec, String... fieldIds) throws IllegalArgumentException {
            if(fieldIds.length == 0) {
                throw new IllegalArgumentException("Tried to register a codec for field ids, but did not specify any!");
            }

            for (String fieldId : fieldIds) {
                if(this.codecOverwrites.containsKey(fieldId)) {
                    throw new IllegalArgumentException("Attempted to register a duplicate codec for field id '" + fieldId +"'");
                }
                this.codecOverwrites.put(fieldId, codec);
            }
            return this;
        }

        /**
         * Sets the function applied to the populated instance as the last step of decoding.
         *
         * @throws IllegalStateException if a post processor is already registered
         */
        @Override
        public Builder<C, T> postProcessor(Function<T, DataResult<T>> postProcessor) {
            if(this.postProcessor != null) {
                throw new IllegalStateException("Attempted to register a post processor while a post processor is already registered");
            }

            this.postProcessor = postProcessor;
            return this;
        }

        /**
         * @return a codec created from the registered values
         */
        @Override
        public Codec<T> build() {
            return ClassFieldsReflectionCodec.of(this.clazz, this.defaultFactory, this.extraCodecs, this.codecOverwrites, this.postProcessor != null ? this.postProcessor : DataResult::success);
        }

        /**
         * @return a map codec created from the registered values
         */
        @Override
        public MapCodec<T> buildMap() {
            return ClassFieldsReflectionCodec.ofMap(this.clazz, this.defaultFactory, this.extraCodecs, this.codecOverwrites, this.postProcessor != null ? this.postProcessor : DataResult::success);
        }
    }
}
