package io.github.beardedManZhao.algorithmStar.core.model;

import io.github.beardedManZhao.algorithmStar.exception.OperatorOperationException;
import io.github.beardedManZhao.algorithmStar.operands.matrix.ColorMatrix;
import io.github.beardedManZhao.algorithmStar.operands.matrix.block.DoubleMatrixSpace;
import io.github.beardedManZhao.algorithmStar.operands.matrix.block.IntegerMatrixSpace;
import io.github.beardedManZhao.algorithmStar.operands.table.Cell;
import io.github.beardedManZhao.algorithmStar.operands.vector.DoubleVector;
import io.github.beardedManZhao.algorithmStar.utils.ASClass;
import io.github.beardedManZhao.algorithmStar.utils.dataContainer.KeyValue;
import io.github.beardedManZhao.algorithmStar.utils.transformation.Transformation;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import org.jetbrains.annotations.NotNull;

/* loaded from: input_file:META-INF/jars/algorithmStar-1.44.jar:io/github/beardedManZhao/algorithmStar/core/model/SingleLayerCNNModel.class */
public final class SingleLayerCNNModel implements ASModel<Integer, IntegerMatrixSpace, ClassificationModel<IntegerMatrixSpace>> {
    /** {@link #setArg} key: colour channel index (int) handed to {@code getChannel} during feature extraction. */
    public static final int COLOR_CHANNEL = 2;
    /** {@link #setArg} key: activation function, given as an {@link ActivationFunction} or its name as a String. */
    public static final int Activation_Function = 3;
    /** {@link #setArg} key: kernel ({@link DoubleMatrixSpace}); see {@link #setKernel}. */
    public static final int KERNEL = 4;
    /** {@link #setArg} key: transformation applied after folding; converted with {@code ASClass.transform}. */
    public static final int TRANSFORMATION = 5;
    /** {@link #setArg} key: number of training iterations (int). */
    public static final int LEARN_COUNT = 6;
    /** {@link #setArg} key: learning rate (read as a double, stored as float). */
    public static final int LEARNING_RATE = 7;
    private static final long serialVersionUID = -8785883488441048941L;

    /** Kernel passed to {@code foldingAndSumRGB}; its size is cached in {@link #kw}/{@link #kh}. */
    private DoubleMatrixSpace kernel;
    /** Expected output vector per training image, set by {@link #setWeight}. */
    private DoubleVector[] target;
    private final Random random = new Random();
    /** Named initial weight vectors; one perceptron is built from each entry. */
    private final ArrayList<KeyValue<String, DoubleVector>> W1 = new ArrayList<>();
    LossFunction lossFunction = LossFunction.MSE;
    private int kw = 0;
    private int kh = 0;
    /** Image width/height accepted by the trained model; -1 until {@link #setWeight} sees an image. */
    private int ww = -1;
    private int wh = -1;
    private int colorChannel = 8;
    private int learnCount = 100;
    private float learningRate = 0.2f;
    private ActivationFunction activationFunction = ActivationFunction.RELU;
    /** Identity by default. */
    private Transformation<ColorMatrix, ColorMatrix> transformation = colorMatrix -> colorMatrix;

    /**
     * Callback invoked once per training iteration of
     * {@link SingleLayerCNNModel#function(TaskConsumer, IntegerMatrixSpace...)}.
     */
    @FunctionalInterface
    public interface TaskConsumer extends Serializable {
        /** A callback that does nothing. */
        TaskConsumer VOID = VoidTask.VOID_TASK;

        /**
         * Receives the state of one training iteration.
         *
         * @param d    loss computed for the sampled image in this iteration
         * @param dArr values returned by the layer's {@code backForward} for that loss
         * @param list the model's live weight list, before this iteration's update is applied
         */
        void accept(double d, double[] dArr, List<KeyValue<String, DoubleVector>> list);
    }

    /**
     * Runs the feature-extraction pipeline shared by training and inference.
     *
     * <p>The image is folded with the kernel, transformed, reduced to one channel and flattened.
     */
    private static DoubleVector toFeatureVector(IntegerMatrixSpace image,
                                                Transformation<ColorMatrix, ColorMatrix> transformation,
                                                int kernelW, int kernelH, DoubleMatrixSpace kernel, int channel) {
        return DoubleVector.parse(transformation.function(image.foldingAndSumRGB(kernelW, kernelH, kernel)).getChannel(channel).flatten());
    }

    /** Applies {@link #toFeatureVector} with this model's current settings. */
    private DoubleVector extractFeatures(IntegerMatrixSpace image) {
        return toFeatureVector(image, this.transformation, this.kw, this.kh, this.kernel, this.colorChannel);
    }

    /**
     * Wraps a trained layer into a classifier.
     *
     * <p>The classifier only accepts images of exactly {@code imageW * imageH}.
     *
     * @param trainedLayer     the trained perceptron layer; its perceptron names label the outputs
     * @param featureTransform transformation applied after folding
     * @param kernelW          kernel width
     * @param kernelH          kernel height
     * @param imageW           required image column count
     * @param imageH           required image row count
     * @param channel          colour channel used as the feature source
     * @param kernelSpace      the kernel
     * @return a classifier mapping images to {@code (perceptron names, layer outputs)}
     */
    @NotNull
    private static ClassificationModel<IntegerMatrixSpace> getModel(final ListNeuralNetworkLayer trainedLayer, final Transformation<ColorMatrix, ColorMatrix> featureTransform, final int kernelW, final int kernelH, final int imageW, final int imageH, final int channel, final DoubleMatrixSpace kernelSpace) {
        return new ClassificationModel<IntegerMatrixSpace>() {
            final ListNeuralNetworkLayer lnn;
            final Transformation<ColorMatrix, ColorMatrix> tf;
            final String[] names;
            private final int kw1;
            private final int kh1;
            private final int ww1;
            private final int wh1;
            private final int cc;
            private final DoubleMatrixSpace kernel1;

            {
                this.lnn = trainedLayer;
                this.tf = featureTransform;
                this.names = new String[this.lnn.size()];
                this.kw1 = kernelW;
                this.kh1 = kernelH;
                this.ww1 = imageW;
                this.wh1 = imageH;
                this.cc = channel;
                this.kernel1 = kernelSpace;
                int index = 0;
                Iterator<Perceptron> it = this.lnn.iterator();
                while (it.hasNext()) {
                    this.names[index++] = it.next().getName();
                }
            }

            /** Rejects images whose size differs from the one used in training. */
            private void checkShape(IntegerMatrixSpace image) {
                if (image.getColCount() != this.ww1 || image.getRowCount() != this.wh1) {
                    throw new OperatorOperationException("The image matrix you provided cannot be used for the current model. The current model supports w * h = [" + this.ww1 + " * " + this.wh1 + "]");
                }
            }

            /** Extracts features from one (already shape-checked) image and runs the trained layer on them. */
            private DoubleVector classify(IntegerMatrixSpace image) {
                return this.lnn.forward(toFeatureVector(image, this.tf, this.kw1, this.kh1, this.kernel1, this.cc));
            }

            /** A trained model exposes no tunable arguments; calls are ignored. */
            @Override
            @SuppressWarnings("rawtypes")
            public void setArg(Integer num, @NotNull Cell cell) {
            }

            @Override
            public KeyValue<String[], DoubleVector[]> function(IntegerMatrixSpace... integerMatrixSpaceArr) {
                DoubleVector[] results = new DoubleVector[integerMatrixSpaceArr.length];
                for (int i = 0; i < integerMatrixSpaceArr.length; i++) {
                    checkShape(integerMatrixSpaceArr[i]);
                    results[i] = classify(integerMatrixSpaceArr[i]);
                }
                return new KeyValue<>(this.names, results);
            }

            /**
             * Classifies each image on its own thread and waits for all of them.
             *
             * <p>NOTE(review): all workers share {@code lnn}; this assumes
             * {@code ListNeuralNetworkLayer.forward} is safe to call concurrently — TODO confirm.
             */
            @Override
            public KeyValue<String[], DoubleVector[]> functionConcurrency(IntegerMatrixSpace... integerMatrixSpaceArr) {
                // Validate everything before starting any worker so bad input fails fast.
                for (IntegerMatrixSpace image : integerMatrixSpaceArr) {
                    checkShape(image);
                }
                final int count = integerMatrixSpaceArr.length;
                final DoubleVector[] results = new DoubleVector[count];
                final RuntimeException[] failures = new RuntimeException[count];
                final CountDownLatch done = new CountDownLatch(count);
                for (int i = 0; i < count; i++) {
                    // Each worker owns a fixed slot, so no shared counter is needed.
                    final int index = i;
                    final IntegerMatrixSpace image = integerMatrixSpaceArr[i];
                    new Thread(() -> {
                        try {
                            results[index] = classify(image);
                        } catch (RuntimeException e) {
                            failures[index] = e;
                        } finally {
                            // Always count down, otherwise a failing worker would hang the caller.
                            done.countDown();
                        }
                    }).start();
                }
                try {
                    // Single await after all workers are started; its happens-before edge publishes the slots.
                    done.await();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new OperatorOperationException("Interrupted while waiting for concurrent classification to finish.");
                }
                for (RuntimeException failure : failures) {
                    if (failure != null) {
                        throw failure;
                    }
                }
                return new KeyValue<>(this.names, results);
            }
        };
    }

    public void setLossFunction(LossFunction lossFunction) {
        this.lossFunction = lossFunction;
    }

    /** Sets the kernel and caches its width (column count) and height (row count). */
    public void setKernel(DoubleMatrixSpace doubleMatrixSpace) {
        this.kernel = doubleMatrixSpace;
        this.kh = doubleMatrixSpace.getRowCount();
        this.kw = doubleMatrixSpace.getColCount();
    }

    public void setColorChannel(int i) {
        this.colorChannel = i;
    }

    public void setActivationFunction(ActivationFunction activationFunction) {
        this.activationFunction = activationFunction;
    }

    public void setTransformation(Transformation<ColorMatrix, ColorMatrix> transformation) {
        this.transformation = transformation;
    }

    public void setLearnCount(int i) {
        this.learnCount = i;
    }

    /**
     * Sets the initial weights (one feature vector per named image) and the target vectors.
     *
     * <p>All images must share the same width and height, which become the size accepted by the trained
     * model. The model's state is only replaced once all inputs have been validated.
     *
     * @param doubleVectorArr target vectors; must have as many entries as {@code list}
     * @param list            named images used as initial weights
     * @throws OperatorOperationException if the counts differ or the image sizes are inconsistent
     */
    public void setWeight(DoubleVector[] doubleVectorArr, List<KeyValue<String, IntegerMatrixSpace>> list) {
        if (doubleVectorArr.length != list.size()) {
            throw new OperatorOperationException("You should ensure that the quantity of target data is consistent with the weight data.");
        }
        final List<KeyValue<String, DoubleVector>> weights = new ArrayList<>(list.size());
        int width = -1;
        int height = -1;
        for (KeyValue<String, IntegerMatrixSpace> entry : list) {
            final IntegerMatrixSpace image = entry.getValue();
            if (width == -1) {
                width = image.getColCount();
                height = image.getRowCount();
            } else if (width != image.getColCount() || height != image.getRowCount()) {
                throw new OperatorOperationException("Please provide ya with consistent length and width.");
            }
            weights.add(new KeyValue<>(entry.getKey(), extractFeatures(image)));
        }
        // Commit only after everything validated, so a failure leaves the previous weights intact.
        this.W1.clear();
        this.W1.addAll(weights);
        this.ww = width;
        this.wh = height;
        this.target = doubleVectorArr;
    }

    public void setLearningRate(float f) {
        this.learningRate = f;
    }

    /**
     * Sets one configuration value by key.
     *
     * <p>The keys are the constants of this class. Unknown keys are ignored.
     *
     * @param num  one of {@link #COLOR_CHANNEL}, {@link #Activation_Function}, {@link #KERNEL},
     *             {@link #TRANSFORMATION}, {@link #LEARN_COUNT}, {@link #LEARNING_RATE}
     * @param cell the value for that key
     * @throws OperatorOperationException if the activation-function value is neither a String nor an
     *                                    {@link ActivationFunction}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void setArg2(Integer num, @NotNull Cell<?> cell) {
        switch (num) {
            case COLOR_CHANNEL:
                setColorChannel(cell.getIntValue());
                return;
            case Activation_Function: {
                final Object raw = cell.getValue();
                if (raw instanceof String) {
                    setActivationFunction(ActivationFunction.valueOf(cell.getStringValue()));
                } else if (raw instanceof ActivationFunction) {
                    setActivationFunction((ActivationFunction) raw);
                } else {
                    throw new OperatorOperationException("setActivationFunction((ActivationFunction or String) value.getValue()) error !!!");
                }
                return;
            }
            case KERNEL:
                setKernel((DoubleMatrixSpace) cell.getValue());
                return;
            case TRANSFORMATION:
                setTransformation((Transformation) ASClass.transform(cell.getValue()));
                return;
            case LEARN_COUNT:
                // Return here: falling through would also overwrite learningRate with the count value.
                setLearnCount(cell.getIntValue());
                return;
            case LEARNING_RATE:
                setLearningRate((float) cell.getDoubleValue());
                return;
            default:
                break;
        }
    }

    @Override
    public ClassificationModel<IntegerMatrixSpace> function(IntegerMatrixSpace... integerMatrixSpaceArr) {
        return function(TaskConsumer.VOID, integerMatrixSpaceArr);
    }

    /**
     * Trains the layer on the given images and returns a classifier.
     *
     * <p>The layer is built from the weights set by {@link #setWeight}. Each of {@code learnCount}
     * iterations works as follows:
     * <ol>
     *   <li>picks a random image {@code k};</li>
     *   <li>computes the loss against {@code target[k]};</li>
     *   <li>back-propagates that loss;</li>
     *   <li>subtracts {@code learningRate * gradient[p]} from every element of weight vector {@code p}.</li>
     * </ol>
     *
     * @param taskConsumer          observer called once per iteration, before the weight update
     * @param integerMatrixSpaceArr training images; image {@code k} is paired with {@code target[k]}
     * @return the trained classifier
     * @throws OperatorOperationException if {@code learnCount > 0} and there are no images, or there are
     *                                    more images than target vectors
     */
    public ClassificationModel<IntegerMatrixSpace> function(TaskConsumer taskConsumer, IntegerMatrixSpace... integerMatrixSpaceArr) {
        final int sampleCount = integerMatrixSpaceArr.length;
        final DoubleVector[] features = new DoubleVector[sampleCount];
        for (int i = 0; i < sampleCount; i++) {
            features[i] = extractFeatures(integerMatrixSpaceArr[i]);
        }
        final ListNeuralNetworkLayer layer = new ListNeuralNetworkLayer();
        for (KeyValue<String, DoubleVector> weight : this.W1) {
            layer.addPerceptron(Perceptron.parse(weight.getKey(), this.activationFunction, weight.getValue()));
        }
        if (this.learnCount > 0) {
            checkTrainingData(sampleCount);
        }
        for (int iteration = 0; iteration < this.learnCount; iteration++) {
            // nextInt's bound is exclusive, so pass the count (not the last index) to reach every image.
            final int sample = this.random.nextInt(sampleCount);
            final double loss = this.lossFunction.function(layer.forward(features[sample]).toArray(), this.target[sample].toArray());
            final double[] gradients = layer.backForward(DoubleVector.parse(loss)).toArray();
            taskConsumer.accept(loss, gradients, this.W1);
            for (int p = 0; p < gradients.length; p++) {
                final double step = this.learningRate * gradients[p];
                // NOTE(review): relies on toArray() exposing the backing array shared with the perceptron,
                // otherwise this update would not reach the layer — TODO confirm.
                final double[] weights = this.W1.get(p).getValue().toArray();
                for (int w = 0; w < weights.length; w++) {
                    weights[w] -= step;
                }
            }
        }
        return getModel(layer, this.transformation, this.kw, this.kh, this.ww, this.wh, this.colorChannel, this.kernel);
    }

    /** Ensures random sampling can pick an image and that every image has a target vector. */
    private void checkTrainingData(int sampleCount) {
        if (sampleCount == 0) {
            throw new OperatorOperationException("At least one training image is required when learnCount > 0.");
        }
        if (this.target == null || sampleCount > this.target.length) {
            throw new OperatorOperationException("Each training image needs a target vector: call setWeight with at least " + sampleCount + " target vectors first.");
        }
    }

    /** Training has no concurrent variant; this delegates to {@link #function(IntegerMatrixSpace...)}. */
    @Override
    public ClassificationModel<IntegerMatrixSpace> functionConcurrency(IntegerMatrixSpace... integerMatrixSpaceArr) {
        return function(integerMatrixSpaceArr);
    }

    /**
     * Interface entry point; delegates to {@link #setArg2}.
     *
     * <p>The parameter uses the raw {@code Cell} type so that it overrides {@code ASModel.setArg} by erasure.
     */
    @Override
    @SuppressWarnings("rawtypes")
    public void setArg(Integer num, @NotNull Cell cell) {
        setArg2(num, (Cell<?>) cell);
    }
}
