package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.function.Function;

/* loaded from: input_file:ai/djl/nn/LambdaBlock.class */
/**
 * A block whose forward pass simply applies a user-supplied {@link Function} from {@link NDList}
 * to {@link NDList}. This class declares no parameters of its own.
 */
public class LambdaBlock extends AbstractBlock {

    /** Name used when the caller does not supply one. */
    public static final String DEFAULT_NAME = "anonymous";

    /** Encoding version passed to the superclass and expected by {@link #loadParameters}. */
    private static final byte VERSION = 2;

    private final Function<NDList, NDList> lambda;
    private final String name;

    /**
     * Creates a {@code LambdaBlock} named {@link #DEFAULT_NAME}.
     *
     * @param lambda the function applied to the inputs in the forward pass
     */
    public LambdaBlock(Function<NDList, NDList> lambda) {
        this(lambda, DEFAULT_NAME);
    }

    /**
     * Creates a named {@code LambdaBlock}.
     *
     * @param lambda the function applied to the inputs in the forward pass
     * @param name the name returned by {@link #getName()}
     */
    public LambdaBlock(Function<NDList, NDList> lambda, String name) {
        super(VERSION);
        this.lambda = lambda;
        this.name = name;
    }

    /**
     * Returns the name this block was created with.
     *
     * @return the block name
     */
    public String getName() {
        return name;
    }

    /**
     * Creates a {@code LambdaBlock} that applies {@code lambda} to the single array of its input.
     * The block is named after the simple class name of {@code lambda}.
     *
     * @param lambda the array-to-array function to wrap
     * @return a new {@code LambdaBlock}
     */
    public static LambdaBlock singleton(Function<NDArray, NDArray> lambda) {
        return singleton(lambda, lambda.getClass().getSimpleName());
    }

    /**
     * Creates a named {@code LambdaBlock} that applies {@code lambda} to the single array of its
     * input and wraps the result in a one-element {@link NDList}.
     *
     * <p>The input list must contain exactly one array; {@link NDList#singletonOrThrow()} enforces
     * this at forward time.
     *
     * @param lambda the array-to-array function to wrap
     * @param name the block name
     * @return a new {@code LambdaBlock}
     */
    public static LambdaBlock singleton(Function<NDArray, NDArray> lambda, String name) {
        return new LambdaBlock(list -> new NDList(lambda.apply(list.singletonOrThrow())), name);
    }

    /** {@inheritDoc} Delegates directly to the wrapped function; the other arguments are unused. */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        return lambda.apply(inputs);
    }

    /**
     * Infers output shapes by running the wrapped function on zero-filled arrays of the given
     * shapes. The arrays use the default data type of {@code NDManager.zeros}.
     *
     * <p>Side effect: records the data types of the outputs in {@code outputDataTypes}.
     *
     * @param inputShapes the shapes of the inputs
     * @return the shapes of the outputs produced by the wrapped function
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        // A temporary manager owns the probe arrays so they are released as soon as shapes are read.
        try (NDManager manager = NDManager.newBaseManager()) {
            NDList probe = new NDList(inputShapes.length);
            for (Shape shape : inputShapes) {
                probe.add(manager.zeros(shape));
            }
            NDList output = lambda.apply(probe);
            int count = output.size();
            Shape[] outputShapes = new Shape[count];
            DataType[] dataTypes = new DataType[count];
            for (int i = 0; i < count; i++) {
                NDArray array = output.get(i);
                outputShapes[i] = array.getShape();
                dataTypes[i] = array.getDataType();
            }
            outputDataTypes = dataTypes;
            return outputShapes;
        }
    }

    /**
     * Reads this block's serialized state.
     *
     * <p>The stream starts with a version byte. For {@link #VERSION}, the recorded input shapes
     * follow and are read. Version 1 is accepted with nothing further read.
     *
     * @param manager unused by this block
     * @param is the stream to read from
     * @throws IOException if reading from the stream fails
     * @throws MalformedModelException if the version byte is neither {@link #VERSION} nor 1
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version == VERSION) {
            readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
    }
}
