package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Mask;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.YoloV5Translator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/modality/cv/translator/YoloSegmentationTranslator.class */
/**
 * Translator for YOLO-style instance-segmentation models that emit a prediction tensor plus a
 * mask-prototype tensor, producing {@link DetectedObjects} whose bounding boxes are {@link Mask}s.
 */
public class YoloSegmentationTranslator extends YoloV5Translator {

    private static final int[] AXIS_0 = {0};
    private static final int[] AXIS_1 = {1};

    /** Minimum best-class score a candidate needs to survive the first filtering step. */
    private final float threshold;

    /** Overlap threshold forwarded to {@link Rectangle#nms}. */
    private final float nmsThreshold;

    /** Builder for {@link YoloSegmentationTranslator}. */
    public static class Builder extends YoloV5Translator.Builder {

        Builder() {}

        /** {@inheritDoc} */
        @Override
        public Builder self() {
            return this;
        }

        /**
         * Validates the configured options and builds the translator.
         *
         * @return a new {@link YoloSegmentationTranslator}
         */
        @Override
        public YoloSegmentationTranslator build() {
            validate();
            return new YoloSegmentationTranslator(this);
        }
    }

    /**
     * Creates a translator from the given builder, copying its score and NMS thresholds.
     *
     * @param builder the configured builder
     */
    public YoloSegmentationTranslator(Builder builder) {
        super(builder);
        this.threshold = builder.threshold;
        this.nmsThreshold = builder.nmsThreshold;
    }

    /**
     * Decodes the raw model outputs into detected objects with one mask per kept instance.
     *
     * <p>NOTE(review): assumes {@code list.get(0)} is laid out as {@code (4 + numClasses +
     * numMaskCoefficients, numCandidates)} with centre-x, centre-y, width, height in the first four
     * rows, and that {@code list.get(1)} has shape {@code (numMaskCoefficients, maskH, maskW)} —
     * confirm against the exported model.
     *
     * @param ctx the translator context (not used here)
     * @param list the model outputs
     * @return the detections; each bounding box is a {@link Mask} whose coordinates are divided by
     *     the translator's configured input width/height
     */
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        NDArray pred = list.get(0);
        NDArray protos = list.get(1);
        int maskIndex = classes.size() + 4;

        // Score filter is computed before transposing, so axis 0 runs over the class rows.
        NDArray keep = pred.get("4:" + maskIndex).max(AXIS_0).gt(threshold);
        NDArray rows = pred.transpose();
        NDList split =
                xywh2xyxy(rows.get("..., :4"))
                        .concat(rows.get("..., 4:"), -1)
                        .get(keep)
                        .split(new long[] {4, maskIndex}, 1);
        NDArray boxes = split.get(0);
        NDArray classScores = split.get(1);
        NDArray maskCoefficients = split.get(2);

        int numBoxes = Math.toIntExact(boxes.getShape().get(0));
        if (numBoxes == 0) {
            // Nothing passed the score threshold: skip NMS and the mask matmul on empty arrays.
            return new DetectedObjects(new ArrayList<>(), new ArrayList<>(), new ArrayList<>());
        }

        float[] boxBuf = boxes.toFloatArray();
        float[] confidences = classScores.max(AXIS_1).toFloatArray();
        long[] classIds = classScores.argMax(1).toLongArray();

        List<Integer> kept = runNms(boxBuf, confidences, numBoxes);
        long[] keptIdx = kept.stream().mapToLong(Integer::longValue).toArray();
        NDArray selected = boxes.getManager().create(keptIdx);

        int maskH = Math.toIntExact(protos.getShape().get(1));
        int maskW = Math.toIntExact(protos.getShape().get(2));
        long numProtos = protos.getShape().get(0);
        // Binary masks: linear combination of prototypes, thresholded at 0, flattened row-major
        // as (kept.size(), maskH, maskW).
        float[] maskBuf =
                maskCoefficients
                        .get(selected)
                        .matMul(protos.reshape(numProtos, (long) maskH * maskW))
                        .reshape(kept.size(), maskH, maskW)
                        .gt(0f)
                        .toType(DataType.FLOAT32, true)
                        .toFloatArray();
        float[] keptBoxes = boxes.get(selected).toFloatArray();

        List<String> classNames = new ArrayList<>(keptIdx.length);
        List<Double> probabilities = new ArrayList<>(keptIdx.length);
        List<Mask> masks = new ArrayList<>(keptIdx.length);
        for (int i = 0; i < keptIdx.length; i++) {
            float x = keptBoxes[i * 4] / width;
            float y = keptBoxes[i * 4 + 1] / height;
            float w = keptBoxes[i * 4 + 2] / width - x;
            // Vertical extent is normalized by height so it matches y.
            float h = keptBoxes[i * 4 + 3] / height - y;
            // kept.get(i) indexes the pre-NMS candidate arrays (confidences, classIds).
            int candidate = kept.get(i);
            classNames.add(classes.get((int) classIds[candidate]));
            probabilities.add((double) confidences[candidate]);
            masks.add(new Mask(x, y, w, h, sliceMask(maskBuf, i, maskH, maskW), true));
        }
        // Copy lets the element type widen from Mask to the constructor's expected list type.
        return new DetectedObjects(classNames, probabilities, new ArrayList<>(masks));
    }

    /**
     * Runs non-maximum suppression over the thresholded candidates.
     *
     * @param boxBuf flat {@code (x1, y1, x2, y2)} quadruples, one per candidate
     * @param confidences best-class score per candidate
     * @param numBoxes number of candidates
     * @return indices (into the candidate arrays) of the boxes kept by {@link Rectangle#nms}
     */
    private List<Integer> runNms(float[] boxBuf, float[] confidences, int numBoxes) {
        List<Rectangle> rects = new ArrayList<>(numBoxes);
        List<Double> scores = new ArrayList<>(numBoxes);
        for (int i = 0; i < numBoxes; i++) {
            float x1 = boxBuf[i * 4];
            float y1 = boxBuf[i * 4 + 1];
            rects.add(new Rectangle(x1, y1, boxBuf[i * 4 + 2] - x1, boxBuf[i * 4 + 3] - y1));
            scores.add((double) confidences[i]);
        }
        return Rectangle.nms(rects, scores, nmsThreshold);
    }

    /**
     * Copies the {@code index}-th {@code rows x cols} mask out of a flat row-major buffer.
     *
     * @param flat buffer holding consecutive {@code rows x cols} masks
     * @param index which mask to extract
     * @param rows mask height
     * @param cols mask width
     * @return a fresh {@code rows x cols} array
     */
    private static float[][] sliceMask(float[] flat, int index, int rows, int cols) {
        float[][] mask = new float[rows][cols];
        int base = index * rows * cols;
        for (int r = 0; r < rows; r++) {
            System.arraycopy(flat, base + r * cols, mask[r], 0, cols);
        }
        return mask;
    }

    /** Converts boxes from (centerX, centerY, w, h) to (x1, y1, x2, y2) along the last axis. */
    private static NDArray xywh2xyxy(NDArray xywh) {
        NDArray center = xywh.get("..., :2");
        NDArray halfSize = xywh.get("..., 2:").div(2);
        return center.sub(halfSize).concat(center.add(halfSize), -1);
    }

    /**
     * Creates a builder with default settings.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder configured from the given options map (pre- and post-processing keys).
     *
     * @param map the configuration options
     * @return a configured {@link Builder}
     */
    public static Builder builder(Map<String, ?> map) {
        Builder builder = new Builder();
        builder.configPreProcess(map);
        builder.configPostProcess(map);
        return builder;
    }
}
