package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/training/loss/YOLOv3Loss.class */
public final class YOLOv3Loss extends Loss {
    // Default anchor values: 18 floats read as consecutive (even index, odd index) pairs.
    // Output i uses the six entries [i * 6 .. i * 6 + 5], i.e. three pairs per output.
    // Presumably each pair is (width, height) in input-image pixels - TODO confirm.
    private static final float[] PRESETANCHORS = {116.0f, 90.0f, 156.0f, 198.0f, 373.0f, 326.0f, 30.0f, 61.0f, 62.0f, 45.0f, 59.0f, 119.0f, 10.0f, 13.0f, 16.0f, 30.0f, 33.0f, 23.0f};
    // Anchor values actually used by this loss instance (supplied by the builder).
    private float[] anchors;
    // Number of object classes; also the one-hot depth used when building targets.
    private int numClasses;
    // Number of attributes per predicted box: numClasses + 5.
    private int boxAttr;
    // Network input size; entries 0 and 1 are used to rescale the anchors.
    private Shape inputShape;
    // IoU threshold used when deriving the object / no-object masks.
    private float ignoreThreshold;
    // Manager taken from the predictions on every evaluate(...) call; null until the first call.
    private NDManager manager;
    // Small constant that keeps log() arguments and denominators away from zero.
    private static final float EPSILON = 1.0E-7f;

    /* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/training/loss/YOLOv3Loss$Builder.class */
    /**
     * Fluent builder for {@link YOLOv3Loss}.
     *
     * Every option has a default, so {@code YOLOv3Loss.builder().build()} yields a usable loss.
     */
    public static class Builder {
        private String name = "YOLOv3Loss";
        private float[] anchorsArray = YOLOv3Loss.PRESETANCHORS;
        private int numClasses = 20;
        // NOTE(review): 419 is an unusual default (YOLOv3 inputs are commonly 416) - confirm.
        private Shape inputShape = new Shape(419, 419);
        private float ignoreThreshold = 0.5f;

        /**
         * Sets the name of the loss.
         *
         * @param str the name handed to the loss constructor
         * @return this builder
         */
        public Builder setName(String str) {
            this.name = str;
            return this;
        }

        /**
         * Sets the anchor values used by the loss.
         *
         * @param fArr anchor values; must have the same length as the preset anchors
         * @return this builder
         * @throws IllegalArgumentException if {@code fArr} does not have the required length
         */
        public Builder setAnchorsArray(float[] fArr) {
            int expectedLength = YOLOv3Loss.PRESETANCHORS.length;
            if (fArr.length != expectedLength) {
                throw new IllegalArgumentException(
                        String.format(
                                "setAnchorsArray requires anchors of length %d, but was given anchors of length %d instead",
                                expectedLength, fArr.length));
            }
            // Defensive copy: later changes to the caller's array must not alter the loss.
            this.anchorsArray = fArr.clone();
            return this;
        }

        /**
         * Sets the number of object classes.
         *
         * @param i the number of classes
         * @return this builder
         */
        public Builder setNumClasses(int i) {
            this.numClasses = i;
            return this;
        }

        /**
         * Sets the network input shape.
         *
         * @param shape the input shape; its entries 0 and 1 are used to rescale the anchors
         * @return this builder
         */
        public Builder setInputShape(Shape shape) {
            this.inputShape = shape;
            return this;
        }

        /**
         * Sets the IoU threshold used when deriving the object / no-object masks.
         *
         * @param f the threshold
         * @return this builder
         */
        public Builder optIgnoreThreshold(float f) {
            this.ignoreThreshold = f;
            return this;
        }

        /**
         * Builds the loss from the current settings.
         *
         * @return a new {@link YOLOv3Loss}
         */
        public YOLOv3Loss build() {
            return new YOLOv3Loss(this);
        }
    }

    /**
     * Creates the loss from the values collected by a {@link Builder}.
     *
     * @param builder the builder holding name, anchors, class count, input shape and threshold
     */
    private YOLOv3Loss(Builder builder) {
        super(builder.name);
        inputShape = builder.inputShape;
        ignoreThreshold = builder.ignoreThreshold;
        anchors = builder.anchorsArray;
        numClasses = builder.numClasses;
        // Every predicted box carries five values in addition to its per-class scores.
        boxAttr = builder.numClasses + 5;
    }

    /**
     * Returns the built-in anchor values.
     *
     * @return a fresh copy of the preset anchors; modifying it does not affect the defaults
     */
    public static float[] getPresetAnchors() {
        float[] copy = new float[PRESETANCHORS.length];
        System.arraycopy(PRESETANCHORS, 0, copy, 0, copy.length);
        return copy;
    }

    /**
     * Limits every element of an array to the range {@code [f, f2]} using mask arithmetic.
     *
     * @param nDArray the values to clip
     * @param f the lower bound
     * @param f2 the upper bound
     * @return a new array whose elements below {@code f} are replaced by {@code f} and whose
     *     elements above {@code f2} are replaced by {@code f2}
     */
    public NDArray clipByTensor(NDArray nDArray, float f, float f2) {
        // Lower bound: keep entries >= f and substitute f for the remaining ones.
        NDArray keptAboveMin = nDArray.gte(f).mul(nDArray);
        NDArray raisedToMin = nDArray.lt(f).mul(f);
        NDArray lowerClipped = keptAboveMin.add(raisedToMin);

        // Upper bound: same scheme applied to the lower-clipped values.
        NDArray keptBelowMax = lowerClipped.lte(f2).mul(lowerClipped);
        NDArray loweredToMax = lowerClipped.gt(f2).mul(f2);
        return keptBelowMax.add(loweredToMax);
    }

    /**
     * Computes the element-wise squared difference of two arrays.
     *
     * No mean or other reduction is applied here.
     *
     * @param nDArray the first operand (minuend)
     * @param nDArray2 the second operand (subtrahend)
     * @return {@code (nDArray - nDArray2) ^ 2}, element-wise
     */
    public NDArray mseLoss(NDArray nDArray, NDArray nDArray2) {
        NDArray difference = nDArray.sub(nDArray2);
        return difference.pow(2);
    }

    /**
     * Computes the element-wise binary cross-entropy
     * {@code -(t * log(p) + (1 - t) * log(1 - p))}.
     *
     * The prediction is first clipped to {@code [EPSILON, 0.9999999]} so both logarithms stay
     * finite. No reduction is applied.
     *
     * @param nDArray the predicted values {@code p}
     * @param nDArray2 the target values {@code t}
     * @return the un-reduced binary cross-entropy
     */
    public NDArray bceLoss(NDArray nDArray, NDArray nDArray2) {
        float upperBound = 0.9999999f;
        NDArray clipped = clipByTensor(nDArray, EPSILON, upperBound);

        // t * log(p)
        NDArray positiveTerm = clipped.log().mul(nDArray2);

        // (1 - t) * log(1 - p)
        NDArray logOneMinusP = clipped.mul(-1).add(1).log();
        NDArray oneMinusTarget = nDArray2.mul(-1).add(1);
        NDArray negativeTerm = logOneMinusP.mul(oneMinusTarget);

        return positiveTerm.add(negativeTerm).mul(-1);
    }

    /**
     * Computes the total loss as the sum of the losses of the first three predictions.
     *
     * The manager of {@code nDList2} is stored on this instance and used by the helper methods.
     *
     * @param nDList the labels; must contain exactly one array
     * @param nDList2 the predictions; entries 0, 1 and 2 are evaluated
     * @return the summed loss of the three outputs
     */
    @Override
    public NDArray evaluate(NDList nDList, NDList nDList2) {
        manager = nDList2.getManager();
        final int outputCount = 3;
        NDArray[] perOutputLoss = new NDArray[outputCount];
        int index = 0;
        while (index < outputCount) {
            NDArray prediction = nDList2.get(index);
            perOutputLoss[index] = evaluateOneOutput(index, prediction, nDList.singletonOrThrow());
            index++;
        }
        return NDArrays.add(perOutputLoss);
    }

    /**
     * Computes the loss contribution of a single network output.
     *
     * The raw output is reshaped to {@code (batch, 3, boxAttr, d2, d3)}, so its axis 1 has to hold
     * {@code 3 * boxAttr} values. Per box, attributes 0 and 1 are sigmoid-activated position
     * values, 2 and 3 are raw size values (exponentiated and scaled by the anchors), 4 is the
     * sigmoid confidence and 5 onwards are the sigmoid class scores. The result is the sum of the
     * box term, the object confidence/class term and the no-object confidence term, divided by
     * the batch size.
     *
     * @param i index of the output; selects the anchor entries {@code [i * 6 .. i * 6 + 5]}
     * @param nDArray the raw network output
     * @param nDArray2 the labels, forwarded unchanged to {@link #getTarget(NDArray, int, int)}
     * @return the loss of this output
     */
    public NDArray evaluateOneOutput(int i, NDArray nDArray, NDArray nDArray2) {
        Shape rawShape = nDArray.getShape();
        int batch = (int) rawShape.get(0);
        // Axis 2 is paired with the even anchor entries and axis 3 with the odd ones;
        // presumably width and height respectively - TODO confirm.
        int gridW = (int) rawShape.get(2);
        int gridH = (int) rawShape.get(3);

        // (batch, 3 * boxAttr, gridW, gridH) -> (3, batch, gridW, gridH, boxAttr)
        NDArray prediction =
                nDArray.reshape(batch, 3, boxAttr, gridW, gridH).transpose(1, 0, 3, 4, 2);

        NDArray predX = Activation.sigmoid(prediction.get("...,0"));
        NDArray predY = Activation.sigmoid(prediction.get("...,1"));
        NDArray rawW = prediction.get("...,2");
        NDArray rawH = prediction.get("...,3");
        // Confidence and class scores are brought to batch-first order.
        NDArray conf = Activation.sigmoid(prediction.get("...,4")).transpose(1, 0, 2, 3);
        NDArray predClass =
                Activation.sigmoid(prediction.get("...,5:")).transpose(1, 0, 2, 3, 4);

        NDList targets = getTarget(nDArray2, gridH, gridW);
        NDArray boxLossScale = targets.get(0).transpose(1, 0, 2, 3);
        NDArray groundTruth = targets.get(1);

        NDArray iou =
                calculateIOU(predX, predY, groundTruth.get("...,0:4"), i).transpose(1, 0, 2, 3);

        // Cells whose IoU does not exceed the threshold are no-object candidates.
        NDArray noObjCandidates =
                NDArrays.where(
                        iou.lte(ignoreThreshold),
                        manager.ones(iou.getShape()),
                        manager.create(0.0f));
        // One-hot over the anchor axis marking the anchor with the highest IoU per cell.
        NDArray bestAnchor = iou.argMax(1).oneHot(3).transpose(0, 3, 1, 2);
        NDArray objMask =
                NDArrays.where(
                        iou.gte(ignoreThreshold / 2.0f),
                        bestAnchor,
                        manager.zeros(bestAnchor.getShape()));
        // A cell that is an object for some anchor is never counted as no-object.
        NDArray noObjMask =
                NDArrays.where(
                        objMask.eq(1.0f),
                        manager.zeros(noObjCandidates.getShape()),
                        noObjCandidates);

        // Anchor sizes of this output, normalised by the input shape and expanded to
        // (3, batch, gridW, gridH).
        int anchorBase = i * 6;
        float[] evenAnchors = {
            anchors[anchorBase], anchors[anchorBase + 2], anchors[anchorBase + 4]
        };
        float[] oddAnchors = {
            anchors[anchorBase + 1], anchors[anchorBase + 3], anchors[anchorBase + 5]
        };
        NDArray widthScale =
                manager.create(evenAnchors)
                        .div(inputShape.get(0))
                        .broadcast(gridH, gridW, batch, 3)
                        .transpose(3, 2, 1, 0);
        NDArray heightScale =
                manager.create(oddAnchors)
                        .div(inputShape.get(1))
                        .broadcast(gridH, gridW, batch, 3)
                        .transpose(3, 2, 1, 0);
        NDArray predW = rawW.exp().mul(widthScale);
        NDArray predH = rawH.exp().mul(heightScale);

        // Box term: scaled squared error on the four box attributes, object cells only.
        NDArray squaredError =
                NDArrays.add(
                        groundTruth.get("...,0").sub(predX).pow(2),
                        groundTruth.get("...,1").sub(predY).pow(2),
                        groundTruth.get("...,2").sub(predW).pow(2),
                        groundTruth.get("...,3").sub(predH).pow(2));
        NDArray boxLoss =
                objMask.mul(boxLossScale).mul(squaredError.transpose(1, 0, 2, 3)).sum();

        // Object term: -log(conf) plus the class cross-entropy summed over the class axis.
        NDArray classTruth = groundTruth.get("...,4:").transpose(1, 0, 2, 3, 4);
        NDArray classLoss = bceLoss(predClass, classTruth).sum(new int[] {4});
        NDArray objConfLoss = conf.add(EPSILON).log().mul(-1);
        NDArray confLoss = objMask.mul(objConfLoss.add(classLoss)).sum();

        // No-object term: -log(1 - conf), offset slightly so the logarithm stays finite.
        NDArray noObjConfLoss = conf.mul(-1).add(1.0000001f).log().mul(-1);
        NDArray noObjLoss = noObjMask.mul(noObjConfLoss).sum();

        return boxLoss.add(confLoss).add(noObjLoss).div(batch);
    }

    /**
     * Builds the per-cell training targets for one output from the labels.
     *
     * For every sample a box-loss scale grid of shape {@code (i2, i)} and a ground-truth grid of
     * shape {@code (i2, i, boxAttr - 1)} are filled; a sample without boxes contributes all-zero
     * grids. Label column 0 is one-hot encoded with {@code numClasses} entries and columns 1-4
     * are read as box values (presumably x, y, width, height with (x, y) the top-left corner,
     * since the cell is located at {@code x + width / 2}, {@code y + height / 2} - TODO confirm).
     *
     * @param nDArray the labels, indexed as (sample, box, column)
     * @param i size of the second grid axis
     * @param i2 size of the first grid axis
     * @return a list of two arrays: the box-loss scale broadcast to {@code (3, batch, i2, i)} and
     *     the ground truth broadcast to {@code (3, batch, i2, i, boxAttr - 1)}
     */
    public NDList getTarget(NDArray nDArray, int i, int i2) {
        // The field is only set by evaluate(...); fall back to the labels' manager so that a
        // direct call does not fail with a NullPointerException.
        NDManager mgr = manager != null ? manager : nDArray.getManager();
        int batchSize = (int) nDArray.size(0);
        NDList boxLossScales = new NDList();
        NDList groundTruths = new NDList();
        for (int b = 0; b < batchSize; b++) {
            NDArray sample = nDArray.get(b);
            NDArray scale = mgr.zeros(new Shape(i2, i), DataType.FLOAT32);
            NDArray truth = mgr.zeros(new Shape(i2, i, boxAttr - 1), DataType.FLOAT32);
            int boxCount = (int) sample.size(0);
            if (boxCount != 0) {
                fillSampleTarget(sample, boxCount, i, i2, scale, truth);
            }
            // Every sample contributes exactly one entry (all zeros when it has no boxes) so the
            // stacked arrays always match the batch size expected by the broadcasts below.
            boxLossScales.add(scale);
            groundTruths.add(truth);
        }
        return new NDList(
                NDArrays.stack(boxLossScales).broadcast(3, batchSize, i2, i),
                NDArrays.stack(groundTruths).broadcast(3, batchSize, i2, i, boxAttr - 1));
    }

    /** Writes every box of one sample into that sample's scale and ground-truth grids. */
    private void fillSampleTarget(
            NDArray sample, int boxCount, int secondAxis, int firstAxis, NDArray scale, NDArray truth) {
        // Box centres expressed in grid units along each axis.
        NDArray centerFirst =
                sample.get("...,1").add(sample.get("...,3").div(2)).mul(firstAxis);
        NDArray centerSecond =
                sample.get("...,2").add(sample.get("...,4").div(2)).mul(secondAxis);
        NDArray sizeFirst = sample.get("...,3");
        NDArray sizeSecond = sample.get("...,4");
        NDArray classOneHot = sample.get("...,0").oneHot(numClasses);

        for (int k = 0; k < boxCount; k++) {
            float first = centerFirst.getFloat(k);
            float second = centerSecond.getFloat(k);
            // Truncation selects the grid cell that contains the centre.
            int cellFirst = (int) first;
            int cellSecond = (int) second;
            float boxFirst = sizeFirst.getFloat(k);
            float boxSecond = sizeSecond.getFloat(k);

            String cell = cellFirst + "," + cellSecond;
            // Attributes 0/1: offset of the centre inside its cell; 2/3: box size; 4..: class.
            truth.set(new NDIndex(cell + ",0"), first - cellFirst);
            truth.set(new NDIndex(cell + ",1"), second - cellSecond);
            truth.set(new NDIndex(cell + ",2"), boxFirst);
            truth.set(new NDIndex(cell + ",3"), boxSecond);
            truth.set(new NDIndex(cell + ",4:"), classOneHot.get(k));
            // Smaller boxes get a larger weight in the box loss.
            scale.set(new NDIndex(cell), 2.0f - (boxFirst * boxSecond));
        }
    }

    /**
     * Computes, for each of the three anchors of one output, the IoU between the anchor-sized box
     * centred on the predicted position and the ground-truth box of the same cell.
     *
     * Each overlap extent is clamped at zero before the extents are multiplied. Without the
     * clamp, boxes that are disjoint along both axes produce two negative extents whose product
     * is a spurious positive intersection.
     *
     * @param nDArray predicted first-axis positions; indexed by anchor on axis 0, with the grid
     *     sizes taken from axes 2 and 3
     * @param nDArray2 predicted second-axis positions, laid out like {@code nDArray}
     * @param nDArray3 ground truth indexed by anchor on axis 0; the last axis holds position in
     *     columns 0 and 1 and size in columns 2 and 3
     * @param i index of the output; selects the anchor entries {@code [i * 6 .. i * 6 + 5]}
     * @return the per-anchor IoU arrays stacked along a new leading axis
     */
    public NDArray calculateIOU(NDArray nDArray, NDArray nDArray2, NDArray nDArray3, int i) {
        int gridFirst = (int) nDArray.getShape().get(2);
        int gridSecond = (int) nDArray.getShape().get(3);
        // Integer division: input size per grid cell along each axis.
        int strideFirst = ((int) inputShape.get(0)) / gridFirst;
        int strideSecond = ((int) inputShape.get(1)) / gridSecond;

        NDList perAnchor = new NDList();
        for (int a = 0; a < 3; a++) {
            NDArray predFirst = nDArray.get(a);
            NDArray predSecond = nDArray2.get(a);

            // Anchor size expressed in grid cells.
            int anchorIndex = (i * 6) + (2 * a);
            float anchorFirst = anchors[anchorIndex] / strideFirst;
            float anchorSecond = anchors[anchorIndex + 1] / strideSecond;
            float halfAnchorFirst = anchorFirst / 2.0f;
            float halfAnchorSecond = anchorSecond / 2.0f;

            NDArray predLow1 = predFirst.sub(halfAnchorFirst);
            NDArray predHigh1 = predFirst.add(halfAnchorFirst);
            NDArray predLow2 = predSecond.sub(halfAnchorSecond);
            NDArray predHigh2 = predSecond.add(halfAnchorSecond);

            NDArray truth = nDArray3.get(a);
            NDArray truthFirst = truth.get("...,0");
            NDArray truthSecond = truth.get("...,1");
            // Ground-truth size scaled to grid cells; computed once and reused below.
            NDArray truthSizeFirst = truth.get("...,2").mul(gridFirst);
            NDArray truthSizeSecond = truth.get("...,3").mul(gridSecond);
            NDArray truthHalfFirst = truthSizeFirst.div(2);
            NDArray truthHalfSecond = truthSizeSecond.div(2);

            NDArray truthLow1 = truthFirst.sub(truthHalfFirst);
            NDArray truthHigh1 = truthFirst.add(truthHalfFirst);
            NDArray truthLow2 = truthSecond.sub(truthHalfSecond);
            NDArray truthHigh2 = truthSecond.add(truthHalfSecond);

            // Clamp each extent at zero so disjoint boxes yield an empty intersection.
            NDArray overlapFirst =
                    NDArrays.maximum(
                            NDArrays.minimum(predHigh1, truthHigh1)
                                    .sub(NDArrays.maximum(predLow1, truthLow1)),
                            0f);
            NDArray overlapSecond =
                    NDArrays.maximum(
                            NDArrays.minimum(predHigh2, truthHigh2)
                                    .sub(NDArrays.maximum(predLow2, truthLow2)),
                            0f);
            NDArray intersection = overlapFirst.mul(overlapSecond);

            // union = truth area + anchor area - intersection; EPSILON avoids division by zero.
            NDArray union =
                    truthSizeFirst
                            .mul(truthSizeSecond)
                            .add(anchorFirst * anchorSecond)
                            .sub(intersection)
                            .add(EPSILON);
            perAnchor.add(intersection.div(union));
        }
        return NDArrays.stack(perAnchor);
    }

    /**
     * Creates a builder for a {@link YOLOv3Loss}.
     *
     * @return a new builder initialised with the default settings
     */
    public static Builder builder() {
        Builder fresh = new Builder();
        return fresh;
    }
}
