/*
 * Decompiled with CFR 0.152.
 */
package com.ventooth.swansong.uniforms.compiler.frontend;

import com.ventooth.swansong.mathparser.AbstractParser;
import com.ventooth.swansong.uniforms.Type;
import com.ventooth.swansong.uniforms.UniformFunction;
import com.ventooth.swansong.uniforms.UniformFunctionRegistry;
import com.ventooth.swansong.uniforms.VecUtil;
import com.ventooth.swansong.uniforms.compiler.ast.ConstNode;
import com.ventooth.swansong.uniforms.compiler.ast.TypedNode;
import com.ventooth.swansong.uniforms.compiler.ast.UntypedNode;
import com.ventooth.swansong.uniforms.compiler.ast.typed.TypedBoolNode;
import com.ventooth.swansong.uniforms.compiler.ast.typed.TypedBranchNode;
import com.ventooth.swansong.uniforms.compiler.ast.typed.TypedCastNode;
import com.ventooth.swansong.uniforms.compiler.ast.typed.TypedFunctionNode;
import com.ventooth.swansong.uniforms.compiler.ast.typed.TypedMathNode;
import com.ventooth.swansong.uniforms.compiler.ast.typed.TypedMultiMatchNode;
import com.ventooth.swansong.uniforms.compiler.ast.typed.TypedRelNode;
import com.ventooth.swansong.uniforms.compiler.ast.typed.TypedUnaryMinusNode;
import com.ventooth.swansong.uniforms.compiler.ast.typed.TypedUnaryNotNode;
import com.ventooth.swansong.uniforms.compiler.ast.untyped.UntypedBinaryNode;
import com.ventooth.swansong.uniforms.compiler.ast.untyped.UntypedFunctionNode;
import com.ventooth.swansong.uniforms.compiler.ast.untyped.UntypedSwizzleNode;
import com.ventooth.swansong.uniforms.compiler.ast.untyped.UntypedUnaryNode;
import com.ventooth.swansong.uniforms.compiler.ast.untyped.UntypedVarNode;
import com.ventooth.swansong.uniforms.compiler.transform.Transformation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import lombok.Generated;
import org.apache.commons.lang3.tuple.Pair;

/**
 * Lowers an untyped uniform-expression AST ({@link UntypedNode}) into a typed AST ({@link TypedNode}).
 *
 * <p>While resolving, this class:
 * <ul>
 *   <li>inserts {@link TypedCastNode}s so that both operands of a binary operator, both branches of
 *       {@code if(...)}, and all arguments of {@code in(...)} share the type chosen by {@code Type.coerce};</li>
 *   <li>maps arithmetic and negation on {@code Vec2}/{@code Vec3}/{@code Vec4} onto functions looked up
 *       in {@code VecUtil.REGISTRY};</li>
 *   <li>resolves variables (as zero-argument functions) and function calls against the supplied
 *       {@link UniformFunctionRegistry}, casting arguments to the resolved function's parameter types.</li>
 * </ul>
 *
 * <p>Instances keep a mutable, unsynchronized counter ({@code statefulIndexedCounter}), so a resolver should
 * not be shared between threads.
 */
public class TypeResolver
implements Transformation<UntypedNode, TypedNode> {
    private final Flags flags;
    private final UniformFunctionRegistry registry;
    /** Next index handed out to a stateful-indexed function call (see {@link #resolveFunction(UniformFunction, List)}). */
    private int statefulIndexedCounter = 0;

    /**
     * Recursively converts an untyped node (and its children) into a typed node.
     *
     * @param input the untyped node; constants are already typed and are returned as-is
     * @return the typed equivalent of {@code input}
     * @throws IllegalArgumentException if a builtin ({@code if}, {@code in}) or boolean operator is misused
     * @throws IllegalStateException if a variable/function cannot be resolved
     * @throws AssertionError if {@code input} is of an unhandled node class
     */
    @Override
    public TypedNode transform(UntypedNode input) {
        if (input instanceof ConstNode) {
            // Constants implement both node kinds; no resolution needed.
            return (ConstNode)input;
        }
        if (input instanceof UntypedBinaryNode) {
            UntypedBinaryNode bin = (UntypedBinaryNode)input;
            return this.resolveBinary(this.transform(bin.left), this.transform(bin.right), bin.operator);
        }
        if (input instanceof UntypedFunctionNode) {
            UntypedFunctionNode fn = (UntypedFunctionNode)input;
            return this.resolveFunction(fn.name, this.transform(fn.params));
        }
        if (input instanceof UntypedSwizzleNode) {
            UntypedSwizzleNode swiz = (UntypedSwizzleNode)input;
            return this.resolveSwizzle(this.transform(swiz.value), swiz.index);
        }
        if (input instanceof UntypedUnaryNode) {
            UntypedUnaryNode un = (UntypedUnaryNode)input;
            return this.resolveUnary(this.transform(un.param), un.op);
        }
        if (input instanceof UntypedVarNode) {
            // A bare variable is resolved exactly like a call with no arguments.
            UntypedVarNode var = (UntypedVarNode)input;
            return this.resolveFunction(var.name, Collections.<TypedNode>emptyList());
        }
        throw new AssertionError("Unhandled untyped node class: " + input.getClass().getName());
    }

    /** Dispatches a parsed binary operator to arithmetic, relational, or boolean resolution. */
    private TypedNode resolveBinary(TypedNode left, TypedNode right, AbstractParser.Operator operator) {
        switch (operator) {
            case Add:
                return this.resolveMath(left, right, TypedMathNode.Op.Add);
            case Sub:
                return this.resolveMath(left, right, TypedMathNode.Op.Sub);
            case Mul:
                return this.resolveMath(left, right, TypedMathNode.Op.Mul);
            case Div:
                return this.resolveMath(left, right, TypedMathNode.Op.Div);
            case Rem:
                return this.resolveMath(left, right, TypedMathNode.Op.Rem);
            case Eq:
                return this.resolveRel(left, right, TypedRelNode.Op.Eq);
            case Ne:
                return this.resolveRel(left, right, TypedRelNode.Op.Ne);
            case Ge:
                return this.resolveRel(left, right, TypedRelNode.Op.Ge);
            case Gt:
                return this.resolveRel(left, right, TypedRelNode.Op.Gt);
            case Le:
                return this.resolveRel(left, right, TypedRelNode.Op.Le);
            case Lt:
                return this.resolveRel(left, right, TypedRelNode.Op.Lt);
            case Or:
                return this.resolveBool(Arrays.asList(left, right), TypedBoolNode.Op.Or);
            case And:
                return this.resolveBool(Arrays.asList(left, right), TypedBoolNode.Op.And);
            default:
                throw new AssertionError("Unhandled binary operator: " + operator);
        }
    }

    /**
     * Resolves an arithmetic operation after coercing both operands to a common type.
     *
     * <p>Booleans are promoted to {@code Int}. When {@link Flags#castIntDivToFloat} is set, integer and
     * boolean division is carried out in {@code Float}. Vector operands are delegated to
     * {@code VecUtil.REGISTRY}; all other types (e.g. {@code Float}) become a plain {@link TypedMathNode}.
     */
    private TypedNode resolveMath(TypedNode left, TypedNode right, TypedMathNode.Op op) {
        Pair<TypedNode, TypedNode> p = this.coerce(left, right);
        left = p.getLeft();
        right = p.getRight();
        boolean divideAsFloat = this.flags.castIntDivToFloat && op == TypedMathNode.Op.Div;
        // After coercion both sides share one type, so inspecting the left one is sufficient.
        switch (left.outputType()) {
            case Bool: {
                Type target = divideAsFloat ? Type.Float : Type.Int;
                left = new TypedCastNode(target, left);
                right = new TypedCastNode(target, right);
                break;
            }
            case Int: {
                if (divideAsFloat) {
                    left = new TypedCastNode(Type.Float, left);
                    right = new TypedCastNode(Type.Float, right);
                }
                break;
            }
            case Vec2:
            case Vec3:
            case Vec4:
                return this.resolveVectorMath(left, right, op);
            default:
                break;
        }
        return new TypedMathNode(left, right, op);
    }

    /** Maps an arithmetic op on (already coerced) vector operands onto the matching VecUtil function. */
    private TypedNode resolveVectorMath(TypedNode left, TypedNode right, TypedMathNode.Op op) {
        UniformFunction theOp = VecUtil.REGISTRY.resolve(vectorOpName(op), Arrays.asList(left.outputType(), right.outputType()));
        if (theOp == null) {
            throw new AssertionError("Unknown binary operation between vectors " + op);
        }
        return new TypedFunctionNode(theOp, Arrays.asList(left, right));
    }

    /** Returns the VecUtil registry name implementing {@code op}. */
    private static String vectorOpName(TypedMathNode.Op op) {
        switch (op) {
            case Add:
                return "add";
            case Sub:
                return "sub";
            case Mul:
                return "mul";
            case Div:
                return "div";
            case Rem:
                return "rem";
            default:
                throw new AssertionError("Unhandled math operator: " + op);
        }
    }

    /** Resolves a comparison after coercing both operands to a common type. */
    private TypedNode resolveRel(TypedNode left, TypedNode right, TypedRelNode.Op op) {
        Pair<TypedNode, TypedNode> p = this.coerce(left, right);
        return new TypedRelNode(p.getLeft(), p.getRight(), op);
    }

    /**
     * Resolves a logical {@code and}/{@code or}; every operand must already be {@code Bool}.
     *
     * @throws IllegalArgumentException if any operand is not boolean
     */
    private TypedNode resolveBool(List<TypedNode> nodes, TypedBoolNode.Op op) {
        for (TypedNode node : nodes) {
            if (node.outputType() != Type.Bool) {
                throw new IllegalArgumentException("Operator " + op + " requires Bool operands, got " + node.outputType());
            }
        }
        return new TypedBoolNode(nodes, op);
    }

    /** Resolves a unary {@code !} or {@code -}. */
    private TypedNode resolveUnary(TypedNode param, UntypedUnaryNode.Op op) {
        switch (op) {
            case Not:
                return new TypedUnaryNotNode(param);
            case Minus:
                return this.resolveNegation(param);
            default:
                throw new AssertionError("Unhandled unary operator: " + op);
        }
    }

    /** Negates scalars directly and vectors via the VecUtil {@code neg} function. */
    private TypedNode resolveNegation(TypedNode param) {
        switch (param.outputType()) {
            case Bool:
            case Int:
            case Float:
                return new TypedUnaryMinusNode(param);
            case Vec2:
            case Vec3:
            case Vec4: {
                UniformFunction theOp = VecUtil.REGISTRY.resolve("neg", Collections.singletonList(param.outputType()));
                if (theOp == null) {
                    throw new AssertionError("Could not find vector negation function!");
                }
                return new TypedFunctionNode(theOp, Collections.singletonList(param));
            }
            default:
                throw new AssertionError("Cannot negate value of type " + param.outputType());
        }
    }

    /**
     * Resolves a call (or a variable, when {@code params} is empty) by name.
     *
     * <p>{@code if(cond, ifTrue, ifFalse)} and {@code in(...)} are builtins handled here; every other name is
     * looked up in {@link #registry} by name and argument types.
     *
     * @throws IllegalArgumentException if a builtin receives the wrong number of arguments
     * @throws IllegalStateException if no registered function matches {@code name} and the argument types
     */
    private TypedNode resolveFunction(String name, List<TypedNode> params) {
        int size = params.size();
        switch (name) {
            case "if": {
                if (size != 3) {
                    throw new IllegalArgumentException("if() expects 3 arguments (condition, ifTrue, ifFalse), got " + size);
                }
                return this.resolveBranch(params.get(0), params.get(1), params.get(2));
            }
            case "in": {
                if (size < 2) {
                    throw new IllegalArgumentException("in() expects at least 2 arguments, got " + size);
                }
                return this.resolveMultiMatch(params);
            }
            default:
                break;
        }
        List<Type> types = new ArrayList<Type>(size);
        for (TypedNode param : params) {
            types.add(param.outputType());
        }
        UniformFunction function = this.registry.resolve(name, types);
        if (function == null) {
            throw new IllegalStateException("Unknown uniform variable/function \"" + name + "\" with parameters: " + types);
        }
        return this.resolveFunction(function, params);
    }

    /**
     * Builds a call to an already-resolved function, casting each argument to the declared parameter type.
     *
     * <p>For a {@code statefulIndexed()} function whose first argument is an {@code Int}, that argument is
     * dropped and replaced by a fresh constant from {@code statefulIndexedCounter} (presumably so each call
     * site gets its own state slot — NOTE(review): confirm against UniformFunction).
     */
    private TypedNode resolveFunction(UniformFunction fn, List<TypedNode> params) {
        int size = params.size();
        List<Type> types = fn.params();
        List<TypedNode> newParams = new ArrayList<TypedNode>(size);
        int i = 0;
        if (fn.statefulIndexed() && !params.isEmpty() && params.get(0).outputType() == Type.Int) {
            newParams.add(ConstNode.Int.of(this.statefulIndexedCounter++));
            i = 1;
        }
        for (; i < size; ++i) {
            TypedNode param = params.get(i);
            Type type = types.get(i);
            newParams.add(param.outputType() == type ? param : new TypedCastNode(type, param));
        }
        return new TypedFunctionNode(fn, Collections.unmodifiableList(newParams));
    }

    /**
     * Resolves {@code if(cond, ifTrue, ifFalse)}, coercing both branches to a common type so the branch node
     * has a single well-defined output type.
     */
    private TypedNode resolveBranch(TypedNode cond, TypedNode ifTrue, TypedNode ifFalse) {
        Pair<TypedNode, TypedNode> p = this.coerce(ifTrue, ifFalse);
        // Previously the coerced pair was computed but the raw branches were used; use the coerced ones.
        return new TypedBranchNode(cond, p.getLeft(), p.getRight());
    }

    /** Resolves {@code in(...)}, coercing all arguments to one common type. */
    private TypedNode resolveMultiMatch(List<TypedNode> elems) {
        List<TypedNode> resElems = new ArrayList<TypedNode>(elems);
        this.coerce(resElems);
        return new TypedMultiMatchNode(Collections.unmodifiableList(resElems));
    }

    /**
     * Resolves a swizzle (component access) via the VecUtil {@code swiz} function.
     *
     * @throws IllegalStateException if no swizzle exists for the value's type
     */
    private TypedNode resolveSwizzle(TypedNode value, int index) {
        UniformFunction fn = VecUtil.REGISTRY.resolve("swiz", Arrays.asList(value.outputType(), Type.Int));
        if (fn == null) {
            throw new IllegalStateException("Could not generate swizzle for value of type " + value.outputType());
        }
        return new TypedFunctionNode(fn, Arrays.asList(value, ConstNode.Int.of(index)));
    }

    /** Coerces all elements in place to the type obtained by folding {@code Type.coerce} over them. */
    private void coerce(List<TypedNode> elems) {
        if (elems.isEmpty()) {
            return;
        }
        int size = elems.size();
        Type outType = elems.get(0).outputType();
        for (int i = 1; i < size; ++i) {
            outType = Type.coerce(outType, elems.get(i).outputType());
        }
        for (int i = 0; i < size; ++i) {
            TypedNode elem = elems.get(i);
            if (elem.outputType() != outType) {
                elems.set(i, new TypedCastNode(outType, elem));
            }
        }
    }

    /** Returns {@code (left, right)} with casts inserted so both have the type {@code Type.coerce} picks. */
    private Pair<TypedNode, TypedNode> coerce(TypedNode left, TypedNode right) {
        Type lt = left.outputType();
        Type rt = right.outputType();
        Type outType = Type.coerce(lt, rt);
        TypedNode newLeft = outType == lt ? left : new TypedCastNode(outType, left);
        TypedNode newRight = outType == rt ? right : new TypedCastNode(outType, right);
        return Pair.of(newLeft, newRight);
    }

    @Generated
    public TypeResolver(Flags flags, UniformFunctionRegistry registry) {
        this.flags = flags;
        this.registry = registry;
    }

    /** Options controlling type resolution. */
    public static final class Flags {
        /** When true, {@code Int}/{@code Bool} division is performed in {@code Float} instead of integer math. */
        public final boolean castIntDivToFloat;

        @Generated
        public Flags(boolean castIntDivToFloat) {
            this.castIntDivToFloat = castIntDivToFloat;
        }
    }
}

