package smile.base.cart;

import java.io.Serializable;
import java.math.BigInteger;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.Objects;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.measure.NominalScale;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.feature.importance.SHAP;
import smile.math.MathEx;
import smile.sort.QuickSort;

/* loaded from: input_file:smile/base/cart/CART.class */
public abstract class CART implements SHAP<Tuple>, Serializable {
    /** Serialization version identifier of this class. */
    private static final long serialVersionUID = 2;
    /** Class logger; assigned in the static initializer at the bottom of the class. */
    private static final Logger logger;
    /** Model formula. May be null, in which case {@code predictors} returns tuples unchanged. */
    protected Formula formula;
    /** Schema of the predictors; the training constructor takes it from the training data frame. */
    protected StructType schema;
    /** The response variable field. */
    protected StructField response;
    /** Root node of the tree. */
    protected Node root;
    /** Maximum depth; {@code split} refuses a split whose depth has reached it. Initialized to 20. */
    protected int maxDepth;
    /**
     * Maximum number of leaves (per the constructor's validation message). Initialized to 6.
     * Not enforced anywhere in this class — presumably consulted by subclasses; confirm there.
     */
    protected int maxNodes;
    /** Minimum number of samples a child must hold for a split to be accepted. Initialized to 5. */
    protected int nodeSize;
    /** Number of candidate features examined per split; -1 until set by the training constructor. */
    protected int mtry;
    /** Per-feature importance: the scores of all accepted splits on that feature, summed. */
    protected double[] importance;
    /** Training data; transient and released by {@code clear}. */
    protected transient DataFrame x;
    /** Per-row sample counts; only rows with a positive count take part in training. */
    protected transient int[] samples;
    /** Row indices with a positive sample count; reordered in place by {@code shuffle} as nodes split. */
    protected transient int[] index;
    /** Per-column row orderings; an entry is null for columns that have no ordering (nominal ones). */
    protected transient int[][] order;
    /** Scratch space for {@code shuffle}, sized like {@code index}. */
    private transient int[] buffer;
    /** Decompiler-exposed copy of the flag javac generates for {@code assert}; true when assertions are off. */
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:smile/base/cart/CART$Path.class */
    /**
     * Bookkeeping for one root-to-node path during the SHAP recursion.
     *
     * <p>The four arrays run in parallel; only the first {@link #length} positions are
     * live. {@link #extend} returns an enlarged copy, whereas {@link #unwind} edits the
     * receiver in place (the arrays keep their allocated size, only {@code length} shrinks).
     * NOTE(review): the weight recurrences look like the EXTEND / UNWIND procedures of the
     * Tree SHAP algorithm — confirm against the reference before changing any arithmetic.
     */
    public static class Path {
        /** Number of live positions. */
        int length;
        /** Feature index stored at each position ({@code recurse} seeds position 0 with -1). */
        final int[] d;
        /** Per-position weight given as the first argument of {@link #extend}. */
        final double[] z;
        /** Per-position weight given as the second argument of {@link #extend}. */
        final double[] o;
        /** Running path weights maintained by {@link #extend} and {@link #unwind}. */
        final double[] w;

        /**
         * Wraps the given arrays without copying them.
         *
         * @param iArr  feature indices; its length becomes the path length
         * @param dArr  the {@code z} weights
         * @param dArr2 the {@code o} weights
         * @param dArr3 the {@code w} weights
         */
        Path(int[] iArr, double[] dArr, double[] dArr2, double[] dArr3) {
            this.length = iArr.length;
            this.d = iArr;
            this.z = dArr;
            this.o = dArr2;
            this.w = dArr3;
        }

        /**
         * Returns a copy of this path with one more position appended; the receiver is untouched.
         *
         * @param d  the {@code z} weight of the new position
         * @param d2 the {@code o} weight of the new position
         * @param i  the feature index of the new position
         * @return the extended path
         */
        Path extend(double d, double d2, int i) {
            final int depth = length;
            final int grown = depth + 1;
            // Copying to grown size also drops anything left behind the live range by unwind().
            final Path next = new Path(
                    Arrays.copyOf(this.d, grown),
                    Arrays.copyOf(z, grown),
                    Arrays.copyOf(o, grown),
                    Arrays.copyOf(w, grown));

            next.d[depth] = i;
            next.z[depth] = d;
            next.o[depth] = d2;
            next.w[depth] = (depth == 0) ? 1.0d : 0.0d;

            // Walk downwards so that w[k] is still the old value when it feeds w[k + 1].
            for (int k = depth - 1; k >= 0; k--) {
                next.w[k + 1] += d2 * next.w[k] * (k + 1) / grown;
                next.w[k] = d * next.w[k] * (depth - k) / grown;
            }
            return next;
        }

        /**
         * Removes position {@code i} in place, reversing the weight update that
         * {@link #extend} applied for it, and shifts later positions down by one.
         *
         * @param i the position to remove
         */
        void unwind(int i) {
            final double one = o[i];
            final double zero = z[i];
            final int last = --length;

            double carry = w[last];
            if (one == 0.0d) {
                for (int k = last - 1; k >= 0; k--) {
                    w[k] = w[k] * (last + 1) / (zero * (last - k));
                }
            } else {
                for (int k = last - 1; k >= 0; k--) {
                    final double previous = w[k];
                    w[k] = carry * (last + 1) / ((k + 1) * one);
                    carry = previous - w[k] * zero * (last - k) / (last + 1);
                }
            }

            // Close the gap in d/z/o; w was already rewritten above.
            for (int k = i; k < last; k++) {
                d[k] = d[k + 1];
                z[k] = z[k + 1];
                o[k] = o[k + 1];
            }
        }

        /**
         * Computes the sum of the weights that {@link #unwind} would leave behind for
         * position {@code i}, without modifying the path.
         *
         * @param i the position to (virtually) remove
         * @return the total of the unwound weights
         */
        double unwoundSum(int i) {
            final double one = o[i];
            final double zero = z[i];
            final int last = length - 1;

            double total = 0.0d;
            double carry = w[last];
            if (one == 0.0d) {
                for (int k = last - 1; k >= 0; k--) {
                    total += w[k] / (zero * (last - k));
                }
            } else {
                for (int k = last - 1; k >= 0; k--) {
                    final double term = carry / ((k + 1) * one);
                    total += term;
                    carry = w[k] - term * zero * (last - k);
                }
            }
            return total * (last + 1);
        }
    }

    /**
     * Creates an empty tree that carries only the default hyper-parameters:
     * depth 20, 6 leaves, node size 5 and an unset ({@code -1}) mtry.
     */
    private CART() {
        this.mtry = -1;
        this.nodeSize = 5;
        this.maxNodes = 6;
        this.maxDepth = 20;
    }

    /**
     * Wraps an already-built tree. Hyper-parameters keep their default values.
     *
     * @param formula     the model formula (may be null)
     * @param structType  the schema of the predictors
     * @param structField the response field
     * @param node        the root node of the tree
     * @param dArr        the per-feature importance
     */
    public CART(Formula formula, StructType structType, StructField structField, Node node, double[] dArr) {
        // The no-arg constructor installs the default maxDepth / maxNodes / nodeSize / mtry.
        this();
        this.root = node;
        this.importance = dArr;
        this.formula = formula;
        this.schema = structType;
        this.response = structField;
    }

    /**
     * Training constructor: records the hyper-parameters and prepares the row index,
     * the scratch buffer and the per-column orderings used while growing the tree.
     *
     * @param dataFrame   the training predictors
     * @param structField the response field
     * @param i           the maximum depth; must be at least 1
     * @param i2          the maximum number of leaves; must be at least 2
     * @param i3          the minimum size of leaf nodes; must be at least 1
     * @param i4          the number of features tried per split; a value outside
     *                    {@code [1, ncol]} is replaced by {@code ncol} with a warning
     * @param iArr        per-row sample counts, or null to use every row exactly once;
     *                    stored as is (not copied)
     * @param iArr2       per-column row orderings, or null to compute them with
     *                    {@link #order(DataFrame)}; null entries are left null
     * @throws IllegalArgumentException if {@code i}, {@code i2} or {@code i3} is out of range
     */
    public CART(DataFrame dataFrame, StructField structField, int i, int i2, int i3, int i4, int[] iArr, int[][] iArr2) {
        final int nrow = dataFrame.nrow();
        final int ncol = dataFrame.ncol();

        this.x = dataFrame;
        this.response = structField;
        this.schema = dataFrame.schema();
        this.importance = new double[ncol];
        this.maxDepth = i;
        this.maxNodes = i2;
        this.nodeSize = i3;
        this.mtry = i4;

        if (i4 < 1 || i4 > ncol) {
            logger.warn("Invalid mtry. Use all features.");
            this.mtry = ncol;
        }
        if (i < 1) {
            throw new IllegalArgumentException("Invalid maximum depth: " + i);
        }
        if (i2 < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + i2);
        }
        if (i3 < 1) {
            throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + i3);
        }

        if (iArr == null) {
            // Every row is used once.
            this.samples = new int[nrow];
            Arrays.fill(this.samples, 1);
            this.index = IntStream.range(0, nrow).toArray();
        } else {
            this.samples = iArr;
            // Only rows that were actually drawn take part in training.
            this.index = IntStream.range(0, iArr.length).filter(row -> iArr[row] > 0).toArray();
        }
        this.buffer = new int[this.index.length];

        if (iArr2 == null) {
            this.order = order(dataFrame);
        } else {
            // Outer array only: rows are filled per column, and stay null where the input is null.
            this.order = new int[iArr2.length][];
            for (int col = 0; col < iArr2.length; col++) {
                if (iArr2[col] != null) {
                    this.order[col] = Arrays.stream(iArr2[col])
                            .filter(row -> this.samples[row] > 0)
                            .toArray();
                }
            }
        }
    }

    /**
     * Returns the number of nodes in the tree, counting internal nodes and leaves.
     *
     * @return the node count
     */
    public int size() {
        return size(root);
    }

    /**
     * Counts the nodes of the subtree rooted at {@code node}.
     *
     * @param node the subtree root; anything that is not a leaf is treated as an internal node
     * @return the node count of the subtree
     */
    private int size(Node node) {
        int count = 1;
        if (!(node instanceof LeafNode)) {
            final InternalNode branch = (InternalNode) node;
            count += size(branch.trueChild) + size(branch.falseChild);
        }
        return count;
    }

    /* JADX WARN: Type inference failed for: r0v5, types: [int[], int[][]] */
    public static int[][] order(DataFrame dataFrame) {
        dataFrame.size();
        int ncol = dataFrame.ncol();
        ?? r0 = new int[ncol];
        StructType schema = dataFrame.schema();
        for (int i = 0; i < ncol; i++) {
            if (!(schema.field(i).measure() instanceof NominalScale)) {
                r0[i] = QuickSort.sort(dataFrame.column(i).toDoubleArray());
            }
        }
        return r0;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Extracts the predictors from a tuple.
     *
     * @param tuple the input tuple
     * @return {@code formula.x(tuple)} when a formula is set, otherwise the tuple itself
     */
    public Tuple predictors(Tuple tuple) {
        if (this.formula != null) {
            return this.formula.x(tuple);
        }
        return tuple;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Drops the references to all training-time working state (data, sample counts,
     * row index, column orderings and the scratch buffer). The tree itself is kept.
     */
    public void clear() {
        this.buffer = null;
        this.samples = null;
        this.index = null;
        this.order = null;
        this.x = null;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Applies a split: replaces the split's leaf by an internal node with two new
     * leaves, reorders the row index accordingly, and schedules further splits of
     * the children.
     *
     * @param split         the split to apply
     * @param priorityQueue the queue that receives the children's best splits; when null,
     *                      the children are split recursively right away instead
     * @return true if the split was applied; false if it was rejected (depth limit
     *         reached, a child smaller than {@code nodeSize}, or both children equal
     *         with no further split available)
     * @throws IllegalStateException if the split has no valid feature, or its parent
     *                               does not reference its leaf
     */
    public boolean split(Split split, PriorityQueue<Split> priorityQueue) {
        if (split.feature < 0) {
            throw new IllegalStateException("Split a node with invalid feature.");
        }
        if (split.depth >= this.maxDepth) {
            logger.debug("Reach maximum depth");
            return false;
        }
        if (split.trueCount < this.nodeSize || split.falseCount < this.nodeSize) {
            logger.debug("Node size is too small after splitting");
            return false;
        }

        // Partition the rows of [lo, hi) by the split predicate.
        final int[] trueRows = Arrays.stream(this.index, split.lo, split.hi)
                .filter(row -> split.predicate().test(row))
                .toArray();
        final boolean[] isTrue = new boolean[this.samples.length];
        for (int row : trueRows) {
            isTrue[row] = true;
        }
        final int[] falseRows = Arrays.stream(this.index, split.lo, split.hi)
                .filter(row -> !isTrue[row])
                .toArray();
        final int mid = split.lo + trueRows.length;

        // The checks below are guarded by $assertionsDisabled, i.e. they behave like `assert`.
        final LeafNode trueChild = newNode(trueRows);
        if (!$assertionsDisabled && trueChild.size != split.trueCount) {
            throw new AssertionError(String.format("trueChild.size != split.trueCount: %d != %d", Integer.valueOf(trueChild.size), Integer.valueOf(split.trueCount)));
        }
        if (!$assertionsDisabled && trueChild.size < this.nodeSize) {
            throw new AssertionError(String.format("trueChild size is too small: %d < %d", Integer.valueOf(trueChild.size), Integer.valueOf(this.nodeSize)));
        }
        final LeafNode falseChild = newNode(falseRows);
        if (!$assertionsDisabled && falseChild.size != split.falseCount) {
            throw new AssertionError(String.format("falseChild.size != split.falseCount: %d != %d", Integer.valueOf(falseChild.size), Integer.valueOf(split.falseCount)));
        }
        if (!$assertionsDisabled && falseChild.size < this.nodeSize) {
            throw new AssertionError(String.format("falseChild size is too small: %d < %d", Integer.valueOf(falseChild.size), Integer.valueOf(this.nodeSize)));
        }

        final InternalNode node = split.toNode(trueChild, falseChild);
        shuffle(split.lo, mid, split.hi, isTrue);

        // The true child gets its own copy of the unsplittable flags; the false child reuses the parent's.
        final Optional<Split> trueSplit = findBestSplit(trueChild, split.lo, mid, split.unsplittable.clone());
        final Optional<Split> falseSplit = findBestSplit(falseChild, mid, split.hi, split.unsplittable);

        // Splitting into two equal leaves that cannot be refined gains nothing.
        if (trueChild.equals(falseChild) && trueSplit.isEmpty() && falseSplit.isEmpty()) {
            return false;
        }

        // Hook the new internal node into the tree in place of the leaf.
        if (split.parent == null) {
            this.root = node;
        } else if (split.parent.trueChild == split.leaf) {
            split.parent.trueChild = node;
        } else if (split.parent.falseChild == split.leaf) {
            split.parent.falseChild = node;
        } else {
            throw new IllegalStateException("split.parent and leaf don't match");
        }

        this.importance[node.feature] += node.score;

        final int childDepth = split.depth + 1;
        trueSplit.ifPresent(child -> {
            child.parent = node;
            child.depth = childDepth;
        });
        falseSplit.ifPresent(child -> {
            child.parent = node;
            child.depth = childDepth;
        });

        if (priorityQueue == null) {
            // No queue: grow depth-first, true branch before false branch.
            trueSplit.ifPresent(child -> split(child, null));
            falseSplit.ifPresent(child -> split(child, null));
        } else {
            trueSplit.ifPresent(priorityQueue::add);
            falseSplit.ifPresent(priorityQueue::add);
        }
        return true;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Searches the features still marked splittable for the best split of a leaf.
     *
     * <p>At most {@code mtry} features are examined; when {@code mtry} is smaller than
     * the number of features the candidates are permuted first, otherwise they are
     * examined in parallel. A feature that yields no split is flagged in {@code zArr}.
     *
     * @param leafNode the leaf to split
     * @param i        the lower bound (inclusive) of the leaf's range in the row index
     * @param i2       the upper bound (exclusive) of the leaf's range in the row index
     * @param zArr     per-feature "unsplittable" flags; updated in place and attached
     *                 to the returned split
     * @return the best split, or empty if the leaf holds fewer than {@code 2 * nodeSize}
     *         samples, has zero impurity, or no feature produces a split
     */
    public Optional<Split> findBestSplit(LeafNode leafNode, int i, int i2, boolean[] zArr) {
        if (leafNode.size() < 2 * this.nodeSize) {
            return Optional.empty();
        }
        final double parentImpurity = impurity(leafNode);
        if (parentImpurity == 0.0d) {
            return Optional.empty();
        }

        final int p = this.schema.length();
        final int[] candidates = IntStream.range(0, p).filter(col -> !zArr[col]).toArray();
        final boolean subsample = this.mtry < p;
        if (subsample) {
            MathEx.permutate(candidates);
        }

        IntStream features = Arrays.stream(candidates).limit(this.mtry);
        if (!subsample) {
            features = features.parallel();
        }

        final Optional<Split> best = features
                .mapToObj(col -> {
                    final Optional<Split> candidate = findBestSplit(leafNode, col, parentImpurity, i, i2);
                    if (!candidate.isPresent()) {
                        zArr[col] = true;
                    }
                    return candidate;
                })
                .filter(Optional::isPresent)
                .map(Optional::get)
                .max(Split.comparator);

        best.ifPresent(winner -> winner.unsplittable = zArr);
        return best;
    }

    /**
     * Returns the impurity of a leaf node. A value of exactly 0.0 makes
     * {@code findBestSplit(LeafNode, int, int, boolean[])} give up on the leaf.
     *
     * @param leafNode the leaf node
     * @return the impurity
     */
    protected abstract double impurity(LeafNode leafNode);

    /**
     * Creates a leaf node for a set of rows. {@code split} calls this with the row
     * indices that fall on one side of a split.
     *
     * @param iArr the row indices belonging to the node
     * @return the new leaf node
     */
    protected abstract LeafNode newNode(int[] iArr);

    /**
     * Finds the best split of a leaf on a single feature. An empty result makes the
     * multi-feature search flag the feature as unsplittable.
     *
     * @param leafNode the leaf to split
     * @param i        the feature (column) index
     * @param d        the impurity of the leaf, as returned by {@code impurity}
     * @param i2       the lower bound of the leaf's range — presumably inclusive, in the row index; confirm in subclasses
     * @param i3       the upper bound of the leaf's range — presumably exclusive; confirm in subclasses
     * @return the best split on this feature, or empty if there is none
     */
    protected abstract Optional<Split> findBestSplit(LeafNode leafNode, int i, double d, int i2, int i3);

    /**
     * Returns the per-feature importance, i.e. the accumulated scores of the splits
     * made on each feature. The internal array is returned directly, not a copy.
     *
     * @return the importance array
     */
    public double[] importance() {
        return importance;
    }

    /**
     * Returns the root node of the tree.
     *
     * @return the root node
     */
    public Node root() {
        return root;
    }

    /**
     * Renders the tree in Graphviz dot format.
     *
     * <p>Nodes are visited breadth-first and numbered like a binary heap: the root is 1
     * and the children of node {@code id} are {@code 2 * id} (true branch) and
     * {@code 2 * id + 1} (false branch). Only the two edges leaving the root carry
     * "True"/"False" labels.
     * NOTE(review): ids are {@code int}, so they overflow for trees deeper than 30 levels;
     * left as is because {@code Node.dot} takes an {@code int} id.
     *
     * @return the dot representation of the tree
     */
    public String dot() {
        final StringBuilder dot = new StringBuilder();
        dot.append("digraph CART {\n node [shape=box, style=\"filled, rounded\", color=\"black\", fontname=helvetica];\n edge [fontname=helvetica];\n");

        String trueLabel = " [labeldistance=2.5, labelangle=45, headlabel=\"True\"];\n";
        String falseLabel = " [labeldistance=2.5, labelangle=-45, headlabel=\"False\"];\n";

        final LinkedList<AbstractMap.SimpleEntry<Integer, Node>> queue = new LinkedList<>();
        queue.add(new AbstractMap.SimpleEntry<>(1, this.root));

        while (!queue.isEmpty()) {
            final AbstractMap.SimpleEntry<Integer, Node> entry = queue.poll();
            final int id = entry.getKey();
            final Node node = entry.getValue();

            dot.append(node.dot(this.schema, this.response, id));

            if (node instanceof InternalNode) {
                final InternalNode branch = (InternalNode) node;
                final int trueId = 2 * id;
                final int falseId = 2 * id + 1;

                queue.add(new AbstractMap.SimpleEntry<>(trueId, branch.trueChild));
                queue.add(new AbstractMap.SimpleEntry<>(falseId, branch.falseChild));

                dot.append(' ').append(id).append(" -> ").append(trueId).append(trueLabel);
                dot.append(' ').append(id).append(" -> ").append(falseId).append(falseLabel);

                // Edge labels are drawn at the top of the tree only.
                if (id == 1) {
                    trueLabel = "\n";
                    falseLabel = "\n";
                }
            }
        }

        dot.append("}");
        return dot.toString();
    }

    /**
     * Partitions the range {@code [i, i3)} of every non-null column ordering, and then
     * of the row index, so that flagged rows come first.
     *
     * @param i    the start of the range (inclusive)
     * @param i2   the position where the unflagged rows begin after partitioning
     * @param i3   the end of the range (exclusive)
     * @param zArr per-row flags; true marks rows that go to the front part
     */
    private void shuffle(int i, int i2, int i3, boolean[] zArr) {
        final int[][] columns = this.order;
        for (int col = 0; col < columns.length; col++) {
            final int[] ordering = columns[col];
            if (ordering == null) {
                continue;
            }
            shuffle(ordering, i, i2, i3, zArr);
        }
        shuffle(this.index, i, i2, i3, zArr);
    }

    /**
     * Stable in-place partition of {@code iArr[i .. i3)}: rows flagged in {@code zArr}
     * are compacted to the front, the others are parked in the scratch buffer and then
     * copied back starting at {@code i2}. Relative order is kept within both groups.
     *
     * @param iArr the array of row indices to partition
     * @param i    the start of the range (inclusive)
     * @param i2   the position where the unflagged rows are written back
     * @param i3   the end of the range (exclusive)
     * @param zArr per-row flags; true marks rows that go to the front part
     */
    private void shuffle(int[] iArr, int i, int i2, int i3, boolean[] zArr) {
        int parked = 0;
        int write = i;
        for (int read = i; read < i3; read++) {
            final int row = iArr[read];
            if (zArr[row]) {
                iArr[write++] = row;
            } else {
                this.buffer[parked++] = row;
            }
        }
        // Behaves like `assert i2 + parked == i3`.
        if (!$assertionsDisabled && i2 + parked != i3) {
            throw new AssertionError();
        }
        System.arraycopy(this.buffer, 0, iArr, i2, parked);
    }

    /**
     * Returns a multi-line text rendering of the tree: a three-line header (sample
     * count of the root, column legend, terminal-node note) followed by the node lines.
     *
     * <p>{@code Node.toString} appends the node lines to the list; the header lines are
     * appended afterwards in reverse and the whole list is then reversed, so the
     * header ends up on top.
     *
     * @return the text rendering of the tree
     */
    @Override
    public String toString() {
        final ArrayList<String> lines = new ArrayList<>();
        this.root.toString(this.schema, this.response, null, 0, BigInteger.ONE, lines);
        lines.add("* denotes terminal node");
        lines.add("node), split, n, loss, yval, (yprob)");
        lines.add("n=" + this.root.size());
        Collections.reverse(lines);
        return String.join("\n", lines);
    }

    /**
     * Computes SHAP values over all rows of a data frame, processing the rows as a
     * parallel stream.
     *
     * <p>When a formula is set it is first bound to the data frame's schema. A null
     * formula is tolerated, in line with {@code predictors}, which passes tuples
     * through unchanged in that case.
     *
     * @param dataFrame the data to explain
     * @return the SHAP values produced by the stream-based {@code shap} overload
     */
    public double[] shap(DataFrame dataFrame) {
        if (this.formula != null) {
            this.formula.bind(dataFrame.schema());
        }
        return shap((Stream) dataFrame.stream().parallel());
    }

    /**
     * Computes the SHAP values of a single tuple.
     *
     * <p>The result has {@code schema.length() * k} entries, where {@code k} is the
     * length of the leaf's {@code count()} array if the leaves are {@code DecisionNode}s
     * and 1 otherwise; the values of feature {@code f} start at offset {@code f * k}.
     *
     * @param tuple the instance to explain
     * @return the SHAP values
     */
    @Override // smile.feature.importance.SHAP
    public double[] shap(Tuple tuple) {
        // Descend along the true branches to any leaf, just to learn the leaf type.
        Node leaf = this.root;
        while (leaf instanceof InternalNode) {
            leaf = ((InternalNode) leaf).trueChild;
        }

        final int k = (leaf instanceof DecisionNode) ? ((DecisionNode) leaf).count().length : 1;
        final double[] phi = new double[this.schema.length() * k];
        recurse(
                phi,
                predictors(tuple),
                this.root,
                new Path(new int[0], new double[0], new double[0], new double[0]),
                1.0d,
                1.0d,
                -1);
        return phi;
    }

    /**
     * Recursive step of the SHAP computation: extends the path with the edge that led
     * to {@code node}, then either descends into both children (internal node) or adds
     * the leaf's contribution for every path position except the seed at index 0.
     *
     * @param dArr  the SHAP accumulator, updated in place
     * @param tuple the predictors of the instance being explained
     * @param node  the current node
     * @param path  the path up to, but excluding, the current node
     * @param d     the {@code z} weight for the new path position
     * @param d2    the {@code o} weight for the new path position
     * @param i     the feature index for the new path position (-1 at the root)
     */
    private void recurse(double[] dArr, Tuple tuple, Node node, Path path, double d, double d2, int i) {
        final int depth = path.length;
        final Path extended = path.extend(d, d2, i);

        if (node instanceof InternalNode) {
            final InternalNode branch = (InternalNode) node;
            final int feature = branch.feature();

            // "followed" is the child the tuple goes to, "other" its sibling.
            final Node followed;
            final Node other;
            if (branch.branch(tuple)) {
                followed = branch.trueChild();
                other = branch.falseChild();
            } else {
                followed = branch.falseChild();
                other = branch.trueChild();
            }

            final int followedSize = followed.size();
            final int otherSize = other.size();
            final int total = node.size();

            // Has this feature been split on before along the path?
            int seen;
            for (seen = 0; seen <= depth; seen++) {
                if (extended.d[seen] == feature) {
                    break;
                }
            }

            double zero = 1.0d;
            double one = 1.0d;
            if (seen <= depth) {
                // Yes: take over its weights and remove the earlier occurrence.
                zero = extended.z[seen];
                one = extended.o[seen];
                extended.unwind(seen);
            }

            recurse(dArr, tuple, followed, extended, zero * followedSize / total, one, feature);
            recurse(dArr, tuple, other, extended, zero * otherSize / total, 0.0d, feature);
            return;
        }

        if (node instanceof DecisionNode) {
            final DecisionNode leaf = (DecisionNode) node;
            final int k = leaf.count().length;
            final double[] posteriori = new double[k];
            leaf.posteriori(posteriori);
            for (int pos = 1; pos <= depth; pos++) {
                final double weight = extended.unwoundSum(pos) * (extended.o[pos] - extended.z[pos]);
                final int base = extended.d[pos] * k;
                for (int c = 0; c < k; c++) {
                    dArr[base + c] += weight * posteriori[c];
                }
            }
        } else {
            final double output = ((RegressionNode) node).output();
            for (int pos = 1; pos <= depth; pos++) {
                final double sum = extended.unwoundSum(pos);
                final int feature = extended.d[pos];
                dArr[feature] += sum * (extended.o[pos] - extended.z[pos]) * output;
            }
        }
    }

    // Static initializer: sets up the assertion flag and the class logger.
    static {
        // True when assertions are disabled for this class; guards the manual AssertionError checks.
        $assertionsDisabled = !CART.class.desiredAssertionStatus();
        logger = LoggerFactory.getLogger(CART.class);
    }
}
