package ai.djl.nn.recurrent;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.recurrent.RNN;
import ai.djl.util.Pair;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.Iterator;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/nn/recurrent/RecurrentBlock.class */
public abstract class RecurrentBlock extends AbstractBlock {

    /** Current metadata encoding version; handed to {@link AbstractBlock} by the constructor. */
    private static final byte VERSION = 2;

    /** Older metadata encoding version still accepted by {@link #loadMetadata}; nothing is read for it. */
    private static final byte LEGACY_VERSION = 1;

    /** Input layout enforced by {@link #beforeInitialize}: (batch, time, channel). */
    private static final LayoutType[] EXPECTED_LAYOUT = {LayoutType.BATCH, LayoutType.TIME, LayoutType.CHANNEL};

    /** Hidden-state size per direction; the output channel size is this times {@link #getNumDirections()}. */
    protected long stateSize;

    /** Dropout rate copied from the builder (not read anywhere in this class). */
    protected float dropRate;

    /** Number of stacked layers; one set of parameters is registered per layer. */
    protected int numLayers;

    /**
     * Multiplier applied to {@link #stateSize} for the leading dimension of every weight and bias in
     * {@link #prepare}. Not assigned in this class — presumably set by subclass constructors (TODO confirm).
     */
    protected int gates;

    /** If true the batch size is read from input axis 0, otherwise from axis 1 (used for the state shape). */
    protected boolean batchFirst;

    /** Whether {@code BIAS} parameters are registered in addition to {@code WEIGHT} parameters. */
    protected boolean hasBiases;

    /** Whether a second ({@code r}) direction of parameters is registered. */
    protected boolean bidirectional;

    /** Whether {@link #getOutputShapes} also reports a state shape after the output shape. */
    protected boolean returnState;

    /**
     * Common builder state shared by recurrent block builders.
     *
     * <p>Defaults: {@code batchFirst = true}, {@code hasBiases = true}; all other fields start at their
     * Java default values.
     *
     * @param <T> the concrete builder type returned by the fluent setters
     */
    public static abstract class BaseBuilder<T extends BaseBuilder> {
        protected float dropRate;
        protected long stateSize;
        protected int numLayers;
        protected boolean batchFirst = true;
        protected boolean hasBiases = true;
        protected boolean bidirectional;
        protected boolean returnState;

        /** Activation choice; not set by any setter here — presumably configured by subclasses. */
        protected RNN.Activation activation;

        /**
         * Sets the dropout rate.
         *
         * @param dropRate the dropout rate
         * @return this builder
         */
        public T optDropRate(float dropRate) {
            this.dropRate = dropRate;
            return self();
        }

        /**
         * Sets the hidden-state size per direction.
         *
         * @param stateSize the state size
         * @return this builder
         */
        public T setStateSize(int stateSize) {
            this.stateSize = stateSize;
            return self();
        }

        /**
         * Sets the number of stacked layers.
         *
         * @param numLayers the number of layers
         * @return this builder
         */
        public T setNumLayers(int numLayers) {
            this.numLayers = numLayers;
            return self();
        }

        /**
         * Sets whether the block is bidirectional.
         *
         * @param bidirectional true to register parameters for a second direction
         * @return this builder
         */
        public T optBidirectional(boolean bidirectional) {
            this.bidirectional = bidirectional;
            return self();
        }

        /**
         * Sets whether the batch axis comes first in the input.
         *
         * @param batchFirst true if the batch size is on axis 0, false if on axis 1
         * @return this builder
         */
        public T optBatchFirst(boolean batchFirst) {
            this.batchFirst = batchFirst;
            return self();
        }

        /**
         * Sets whether bias parameters are created.
         *
         * @param hasBiases true to register {@code BIAS} parameters
         * @return this builder
         */
        public T optHasBiases(boolean hasBiases) {
            this.hasBiases = hasBiases;
            return self();
        }

        /**
         * Sets whether the block reports a state shape in addition to the output shape.
         *
         * @param returnState true to also report the state shape
         * @return this builder
         */
        public T optReturnState(boolean returnState) {
            this.returnState = returnState;
            return self();
        }

        /**
         * Returns this builder typed as the concrete subclass, enabling fluent chaining.
         *
         * @return this builder
         */
        protected abstract T self();
    }

    /**
     * Copies the builder configuration and registers this block's parameters.
     *
     * <p>Parameters are named {@code <direction>_<layer>_<projection>_<TYPE>}, where direction is
     * {@code l} (plus {@code r} when bidirectional), projection is {@code i2h} or {@code h2h}, and TYPE
     * is {@code WEIGHT} (plus {@code BIAS} when biases are enabled). They are registered in
     * layer / type / direction / projection nesting order.
     *
     * @param baseBuilder the builder holding the configuration
     */
    public RecurrentBlock(BaseBuilder<?> baseBuilder) {
        super(VERSION);
        this.stateSize = baseBuilder.stateSize;
        this.dropRate = baseBuilder.dropRate;
        this.numLayers = baseBuilder.numLayers;
        this.batchFirst = baseBuilder.batchFirst;
        this.hasBiases = baseBuilder.hasBiases;
        this.bidirectional = baseBuilder.bidirectional;
        this.returnState = baseBuilder.returnState;

        Parameter.Type[] parameterTypes = this.hasBiases
                ? new Parameter.Type[] {Parameter.Type.WEIGHT, Parameter.Type.BIAS}
                : new Parameter.Type[] {Parameter.Type.WEIGHT};
        String[] directions = this.bidirectional ? new String[] {"l", "r"} : new String[] {"l"};
        String[] projections = {"i2h", "h2h"};

        // The registration order is kept stable: it determines the parameter order of the block.
        for (int layer = 0; layer < this.numLayers; layer++) {
            for (Parameter.Type type : parameterTypes) {
                for (String direction : directions) {
                    for (String projection : projections) {
                        String name = direction + '_' + layer + '_' + projection + '_' + type.name();
                        addParameter(Parameter.builder().setName(name).setType(type).build());
                    }
                }
            }
        }
    }

    /**
     * Computes the output shapes for the given input shapes.
     *
     * <p>The first result keeps input axes 0 and 1 and sets the last axis to
     * {@code stateSize * getNumDirections()}. When {@link #returnState} is set, a second shape
     * {@code (numLayers * numDirections, batch, stateSize)} is appended, where batch is input axis 0 if
     * {@link #batchFirst} and axis 1 otherwise.
     *
     * @param inputShapes the input shapes; only the first one is read
     * @return the output shape, optionally followed by the state shape
     */
    @Override // ai.djl.nn.Block
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        Shape input = inputShapes[0];
        long numDirections = getNumDirections();
        Shape output = new Shape(input.get(0), input.get(1), this.stateSize * numDirections);
        if (!this.returnState) {
            return new Shape[] {output};
        }
        long batchSize = input.get(this.batchFirst ? 0 : 1);
        Shape state = new Shape(this.numLayers * numDirections, batchSize, this.stateSize);
        return new Shape[] {output, state};
    }

    /**
     * Runs the base-class pre-initialization, then checks that the first input's layout matches
     * (batch, time, channel).
     *
     * @param shapeArr the input shapes
     */
    @Override // ai.djl.nn.AbstractBaseBlock
    public void beforeInitialize(Shape... shapeArr) {
        super.beforeInitialize(shapeArr);
        Block.validateLayout(EXPECTED_LAYOUT, shapeArr[0].getLayout());
    }

    /**
     * Sets the shape of every direct parameter from the first input shape.
     *
     * <p>Biases get {@code (gates * stateSize)}. {@code i2h} weights get
     * {@code (gates * stateSize, inputSize)}: inputSize is input axis 2 for layer 0 and
     * {@code stateSize * getNumDirections()} for deeper layers. {@code h2h} weights get
     * {@code (gates * stateSize, stateSize)}.
     *
     * @param inputShapes the input shapes; only the first one is read
     * @throws IllegalArgumentException if a parameter name matches none of the expected patterns
     */
    @Override // ai.djl.nn.AbstractBaseBlock
    public void prepare(Shape[] inputShapes) {
        Shape input = inputShapes[0];
        long stackedStateSize = this.gates * this.stateSize;
        long deeperLayerInputSize = this.stateSize * getNumDirections();
        for (Pair<String, Parameter> entry : getDirectParameters()) {
            String name = entry.getKey();
            Parameter parameter = entry.getValue();
            // The layer index is the second '_'-separated token of the generated name.
            int layer = Integer.parseInt(name.split("_")[1]);
            long inputSize = layer > 0 ? deeperLayerInputSize : input.get(2);
            if (name.contains("BIAS")) {
                parameter.setShape(new Shape(stackedStateSize));
            } else if (name.contains("i2h")) {
                parameter.setShape(new Shape(stackedStateSize, inputSize));
            } else if (name.contains("h2h")) {
                parameter.setShape(new Shape(stackedStateSize, this.stateSize));
            } else {
                throw new IllegalArgumentException("Invalid parameter name: " + name);
            }
        }
    }

    /**
     * Loads block metadata.
     *
     * <p>For the current version the input shapes are read. For the legacy version nothing is read.
     * Any other version is rejected.
     *
     * @param loadVersion the encoding version found in the stream
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the version is not supported
     */
    @Override // ai.djl.nn.AbstractBaseBlock
    public void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedModelException {
        if (loadVersion == this.version) {
            readInputShapes(is);
        } else if (loadVersion != LEGACY_VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + ((int) loadVersion));
        }
    }

    /**
     * Returns the number of directions.
     *
     * @return 2 if bidirectional, otherwise 1
     */
    public int getNumDirections() {
        return this.bidirectional ? 2 : 1;
    }
}
