package ai.djl.nn.norm;

import ai.djl.ndarray.NDList;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.training.ParameterStore;
import ai.djl.translate.Batchifier;
import ai.djl.translate.StackBatchifier;
import ai.djl.util.PairList;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/nn/norm/GhostBatchNorm.class */
/**
 * A {@link BatchNorm} variant that normalizes each incoming batch in smaller "ghost" sub-batches.
 *
 * <p>The input is split along axis 0 into {@code ceil(batchSize / virtualBatchSize)} sub-batches
 * using a {@link StackBatchifier}. {@link BatchNorm}'s forward pass is applied to each sub-batch
 * independently, and the per-sub-batch outputs are recombined with the same batchifier.
 */
public class GhostBatchNorm extends BatchNorm {

    /** Target number of samples per sub-batch; validated to be positive by the builder. */
    private final int virtualBatchSize;

    /** Splits the input into sub-batches and recombines the normalized outputs. */
    private final Batchifier batchifier;

    /** Builder for {@link GhostBatchNorm}; obtain one via {@link GhostBatchNorm#builder()}. */
    public static class Builder extends BatchNorm.BaseBuilder<Builder> {

        private int virtualBatchSize = 128;

        Builder() {
        }

        /**
         * Sets the number of samples in each sub-batch. Defaults to 128.
         *
         * @param virtualBatchSize the sub-batch size; must be positive
         * @return this builder
         * @throws IllegalArgumentException if {@code virtualBatchSize} is zero or negative
         */
        public Builder optVirtualBatchSize(int virtualBatchSize) {
            // Reject invalid sizes here rather than failing later with a division by zero
            // (or a meaningless slice count) inside split().
            if (virtualBatchSize <= 0) {
                throw new IllegalArgumentException(
                        "virtualBatchSize must be positive, got " + virtualBatchSize);
            }
            this.virtualBatchSize = virtualBatchSize;
            return this;
        }

        /**
         * Builds a {@link GhostBatchNorm} from this builder's settings.
         *
         * @return the new block
         */
        @Override
        public GhostBatchNorm build() {
            return new GhostBatchNorm(this);
        }

        /** Returns this builder (needed by the generic {@code BaseBuilder} chain). */
        @Override
        public Builder self() {
            return this;
        }
    }

    /**
     * Creates the block from a configured builder.
     *
     * @param builder the builder holding the BatchNorm settings and the virtual batch size
     */
    protected GhostBatchNorm(Builder builder) {
        super(builder);
        this.virtualBatchSize = builder.virtualBatchSize;
        this.batchifier = new StackBatchifier();
    }

    /**
     * Splits {@code inputs} into sub-batches, runs {@link BatchNorm}'s forward pass on each one,
     * and recombines the results.
     *
     * <p>NOTE(review): the decompiler widened this from {@code protected} to {@code public}; the
     * visibility is kept as-is for compatibility.
     *
     * @param parameterStore forwarded unchanged to each per-sub-batch call
     * @param inputs the input batch; the head array's axis 0 determines the split
     * @param training forwarded unchanged to each per-sub-batch call
     * @param params forwarded unchanged to each per-sub-batch call
     * @return the recombined output of all sub-batches
     */
    @Override // ai.djl.nn.norm.BatchNorm, ai.djl.nn.AbstractBaseBlock
    public NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDList[] subBatches = split(inputs);
        for (int i = 0; i < subBatches.length; i++) {
            subBatches[i] = super.forwardInternal(parameterStore, subBatches[i], training, params);
        }
        return batchify(subBatches);
    }

    /**
     * Splits {@code list} along axis 0 into {@code ceil(size / virtualBatchSize)} sub-batches.
     *
     * @param list the batch to split; its head array's axis-0 length is the batch size
     * @return the sub-batches produced by the batchifier
     */
    protected NDList[] split(NDList list) {
        long totalSize = list.head().size(0);
        // Divide in floating point so ceil actually rounds up. Integer division would truncate
        // first (10 / 4 -> 2 instead of 3, and any batch smaller than virtualBatchSize -> 0).
        int numOfSlices = (int) Math.ceil((double) totalSize / virtualBatchSize);
        return batchifier.split(list, numOfSlices, true);
    }

    /**
     * Recombines normalized sub-batches into one batch and squeezes the extra leading axis.
     *
     * @param subBatches the per-sub-batch outputs
     * @return the combined batch
     */
    protected NDList batchify(NDList[] subBatches) {
        NDList batch = batchifier.batchify(subBatches);
        return squeezeExtraDimensions(batch);
    }

    /**
     * Replaces the single array in {@code batch} with that array squeezed on axis 0.
     *
     * <p>NOTE(review): {@code squeeze(0)} only removes axis 0 when its length is 1; confirm the
     * intended result when the batchifier stacked more than one sub-batch.
     *
     * @param batch a list expected to hold exactly one array ({@code singletonOrThrow})
     * @return {@code batch}, modified in place
     */
    protected NDList squeezeExtraDimensions(NDList batch) {
        batch.set(0, batch.singletonOrThrow().squeeze(0));
        return batch;
    }

    /**
     * Creates a new {@link Builder} for {@link GhostBatchNorm}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
