package com.zurrtum.create.client.flywheel.impl.registry;

import com.zurrtum.create.client.flywheel.api.registry.IdRegistry;
import it.unimi.dsi.fastutil.objects.*;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.UnmodifiableView;

import java.util.Collection;
import java.util.Iterator;
import java.util.Objects;
import java.util.Set;
import net.minecraft.class_2960;

/**
 * Bidirectional registry that maps {@code class_2960} IDs to objects and objects back to IDs.
 *
 * <p>Each ID maps to at most one object, and each object (keyed by reference identity in the
 * reverse map) is registered under at most one ID. Every instance is recorded in a static list so
 * that {@link #freezeAll()} can reject further registrations on all registries at once.
 *
 * <p>Thread-safety: single lookups go through synchronized maps, and {@link #register} is
 * {@code synchronized} so that its duplicate checks and both map insertions happen as one atomic
 * step. Iterating the views returned by {@link #getAllIds()}, {@link #getAll()} or
 * {@link #iterator()} is not guarded by this class.
 */
public class IdRegistryImpl<T> implements IdRegistry<T> {
    /** Every registry constructed so far; walked by {@link #freezeAll()}. */
    private static final ObjectList<IdRegistryImpl<?>> ALL = new ObjectArrayList<>();

    /** Forward lookup: ID to registered object. */
    private final Object2ReferenceMap<class_2960, T> map = Object2ReferenceMaps.synchronize(new Object2ReferenceOpenHashMap<>());
    /** Reverse lookup: registered object (compared by reference) to its ID. */
    private final Reference2ObjectMap<T, class_2960> reverseMap = Reference2ObjectMaps.synchronize(new Reference2ObjectOpenHashMap<>());
    /** Read-only live view of the registered IDs. */
    private final ObjectSet<class_2960> keysView = ObjectSets.unmodifiable(map.keySet());
    /** Read-only live view of the registered objects. */
    private final ReferenceCollection<T> valuesView = ReferenceCollections.unmodifiable(map.values());
    /** Set once by {@link #freeze()}; volatile so {@link #isFrozen()} sees it without locking. */
    private volatile boolean frozen;

    /** Creates an empty, unfrozen registry and records it for {@link #freezeAll()}. */
    public IdRegistryImpl() {
        ALL.add(this);
    }

    /**
     * Registers {@code object} under {@code id}.
     *
     * <p>All checks run before either map is modified, so a rejected registration leaves the
     * registry unchanged. The method is synchronized so that concurrent registrations cannot both
     * pass the duplicate checks.
     *
     * @param id     the ID to register under; must not be {@code null}
     * @param object the object to register; must not be {@code null}
     * @throws NullPointerException     if {@code id} or {@code object} is {@code null}
     * @throws IllegalStateException    if this registry has been frozen
     * @throws IllegalArgumentException if {@code id} is already taken, or {@code object} is
     *                                  already registered under another ID
     */
    @Override
    public synchronized void register(class_2960 id, T object) {
        Objects.requireNonNull(id, "id");
        // A null value would make map.put's "previous value" indistinguishable from "absent".
        Objects.requireNonNull(object, "object");
        if (frozen) {
            throw new IllegalStateException("Cannot register to frozen registry!");
        }
        // Validate against both maps first: mutating before checking would leave the two maps
        // out of sync when an exception is thrown.
        if (map.containsKey(id)) {
            throw new IllegalArgumentException("Cannot override registration for ID '" + id + "'!");
        }
        class_2960 oldId = reverseMap.get(object);
        if (oldId != null) {
            throw new IllegalArgumentException("Cannot override ID '" + id + "' with registration for ID '" + oldId + "'!");
        }
        map.put(id, object);
        reverseMap.put(object, id);
    }

    /**
     * Registers {@code object} under {@code id} and returns it, preserving its static type.
     *
     * @throws NullPointerException     if {@code id} or {@code object} is {@code null}
     * @throws IllegalStateException    if this registry has been frozen
     * @throws IllegalArgumentException if the ID or the object is already registered
     */
    @Override
    public <S extends T> S registerAndGet(class_2960 id, S object) {
        register(id, object);
        return object;
    }

    /** Returns the object registered under {@code id}, or {@code null} if there is none. */
    @Override
    @Nullable
    public T get(class_2960 id) {
        return map.get(id);
    }

    /** Returns the ID {@code object} (by reference) is registered under, or {@code null}. */
    @Override
    @Nullable
    public class_2960 getId(T object) {
        return reverseMap.get(object);
    }

    /**
     * Returns the object registered under {@code id}.
     *
     * @throws IllegalArgumentException if nothing is registered under {@code id}
     */
    @Override
    public T getOrThrow(class_2960 id) {
        T object = get(id);
        if (object == null) {
            throw new IllegalArgumentException("Could not find object for ID '" + id + "'!");
        }
        return object;
    }

    /**
     * Returns the ID {@code object} is registered under.
     *
     * @throws IllegalArgumentException if {@code object} is not registered
     */
    @Override
    public class_2960 getIdOrThrow(T object) {
        class_2960 id = getId(object);
        if (id == null) {
            throw new IllegalArgumentException("Could not find ID for object!");
        }
        return id;
    }

    /** Returns a read-only live view of all registered IDs. */
    @Override
    @UnmodifiableView
    public Set<class_2960> getAllIds() {
        return keysView;
    }

    /** Returns a read-only live view of all registered objects. */
    @Override
    @UnmodifiableView
    public Collection<T> getAll() {
        return valuesView;
    }

    /** Returns {@code true} once {@link #freezeAll()} has frozen this registry. */
    @Override
    public boolean isFrozen() {
        return frozen;
    }

    /** Iterates over the registered objects (see {@link #getAll()}). */
    @Override
    public Iterator<T> iterator() {
        return getAll().iterator();
    }

    /**
     * Marks this registry as frozen. Synchronized with {@link #register} so that no registration
     * can complete after this method returns.
     */
    private synchronized void freeze() {
        frozen = true;
    }

    /** Freezes every registry constructed so far; later {@link #register} calls on them throw. */
    public static void freezeAll() {
        for (IdRegistryImpl<?> registry : ALL) {
            registry.freeze();
        }
    }
}
