package ai.djl.engine.rust;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.BaseNDManager;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.util.NativeResource;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.stream.IntStream;

/* loaded from: input_file:META-INF/jars/tokenizers-0.31.1.jar:ai/djl/engine/rust/RsNDArray.class */
public class RsNDArray extends NativeResource<Long> implements NDArray {
    // Optional user-assigned name, read and written through getName()/setName().
    private String name;
    // Cache of the native tensor's device; filled on the first getDevice() call.
    private Device device;
    // Data type: may be supplied at construction, otherwise queried lazily in getDataType().
    private DataType dataType;
    // Cache of the native tensor's shape; filled on the first getShape() call.
    private Shape shape;
    // Manager this array is currently attached to; replaced by attach/returnResource/tempAttach/detach.
    private RsNDManager manager;
    // Helper for extended operations, created once in the constructor.
    private RsNDArrayEx ndArrayEx;
    // Java buffer holding the data when the array is not on a GPU (set by the constructor and set(Buffer)).
    // NOTE(review): presumably retained so the buffer stays reachable while native code may alias it — confirm.
    private ByteBuffer dataRef;

    /**
     * Wraps an existing native tensor handle; the data type is resolved lazily.
     *
     * @param manager the manager that will own the array
     * @param handle the native tensor handle
     */
    public RsNDArray(RsNDManager manager, long handle) {
        this(manager, handle, (DataType) null);
    }

    /**
     * Wraps an existing native tensor handle with a known (or {@code null}) data type and no
     * retained Java buffer.
     *
     * <p>Note: the decompiler reported this constructor as originally package-private.
     *
     * @param manager the manager that will own the array
     * @param handle the native tensor handle
     * @param dataType the data type, or {@code null} to query it lazily
     */
    public RsNDArray(RsNDManager manager, long handle, DataType dataType) {
        this(manager, handle, dataType, null);
    }

    /**
     * Primary constructor: wraps a native handle, attaches the array to {@code manager} and
     * registers it with the current {@link NDScope}.
     *
     * @param manager the manager that will own the array
     * @param handle the native tensor handle
     * @param dataType the data type, or {@code null} to query it lazily
     * @param data the Java buffer to keep a reference to, or {@code null}
     */
    public RsNDArray(RsNDManager manager, long handle, DataType dataType, ByteBuffer data) {
        super(handle);
        this.dataType = dataType;
        this.manager = manager;
        this.ndArrayEx = new RsNDArrayEx(this);
        this.dataRef = data;
        manager.attachInternal(getUid(), this);
        NDScope.register(this);
    }

    /**
     * Returns the manager this array is currently attached to.
     *
     * @return the owning {@link RsNDManager}
     */
    @Override // ai.djl.ndarray.NDResource
    public RsNDManager getManager() {
        RsNDManager owner = this.manager;
        return owner;
    }

    /**
     * Returns the user-assigned name.
     *
     * @return the name, or {@code null} if none was set
     */
    @Override // ai.djl.ndarray.NDArray
    public String getName() {
        String current = this.name;
        return current;
    }

    /**
     * Assigns a name to this array.
     *
     * @param name the new name (may be {@code null})
     */
    @Override // ai.djl.ndarray.NDArray
    public void setName(String name) {
        this.name = name;
    }

    /**
     * Returns the data type. The native library is queried on the first call, and the result is
     * cached after that.
     *
     * @return the data type, mapped from the native ordinal via {@code DataType.values()}
     */
    @Override // ai.djl.ndarray.NDArray
    public DataType getDataType() {
        DataType cached = this.dataType;
        if (cached == null) {
            int ordinal = RustLibrary.getDataType(getHandle().longValue());
            cached = DataType.values()[ordinal];
            this.dataType = cached;
        }
        return cached;
    }

    /**
     * Returns the device the native tensor lives on. The result is cached after the first call.
     *
     * <p>The native library reports {@code [typeCode, deviceId]}, where type code 0 is CPU, 1 is
     * GPU and 2 is "mps".
     *
     * @return the device
     * @throws EngineException if the native type code is not recognised
     */
    @Override // ai.djl.ndarray.NDArray
    public Device getDevice() {
        if (this.device != null) {
            return this.device;
        }
        int[] info = RustLibrary.getDevice(getHandle().longValue());
        int typeCode = info[0];
        String deviceType;
        if (typeCode == 0) {
            deviceType = Device.Type.CPU;
        } else if (typeCode == 1) {
            deviceType = Device.Type.GPU;
        } else if (typeCode == 2) {
            deviceType = "mps";
        } else {
            throw new EngineException("Unknown device type: " + info[0]);
        }
        this.device = Device.of(deviceType, info[1]);
        return this.device;
    }

    /**
     * Returns the shape. The native library is queried on the first call, and the result is
     * cached after that.
     *
     * @return the shape
     */
    @Override // ai.djl.ndarray.NDArray
    public Shape getShape() {
        Shape cached = this.shape;
        if (cached == null) {
            cached = new Shape(RustLibrary.getShape(getHandle().longValue()));
            this.shape = cached;
        }
        return cached;
    }

    /**
     * Rust arrays are always dense.
     *
     * @return {@link SparseFormat#DENSE}
     */
    @Override // ai.djl.ndarray.NDArray
    public SparseFormat getSparseFormat() {
        final SparseFormat format = SparseFormat.DENSE;
        return format;
    }

    /**
     * Moves (or copies) this array to {@code device}.
     *
     * @param device the target device
     * @param z when {@code true}, a new array is produced even if already on {@code device}
     * @return {@code this} when no transfer is needed, otherwise a new named array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray toDevice(Device device, boolean z) {
        if (!device.equals(getDevice()) || z) {
            long moved =
                    RustLibrary.toDevice(
                            getHandle().longValue(), device.getDeviceType(), device.getDeviceId());
            return toArray(moved, null, false, true);
        }
        return this;
    }

    /**
     * Converts this array to {@code dataType}.
     *
     * @param dataType the target data type
     * @param z when {@code true}, a new array is produced even if the type already matches
     * @return {@code this} when no conversion is needed, otherwise a new named array
     * @throws UnsupportedOperationException for an INT64 to FLOAT16 cast on a GPU array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray toType(DataType dataType, boolean z) {
        DataType current = getDataType();
        if (dataType.equals(current) && !z) {
            return this;
        }
        if (dataType == DataType.BOOLEAN) {
            return toArray(RustLibrary.toBoolean(getHandle().longValue()), dataType, false, true);
        }
        // The source is INT64 and the target is FLOAT16; the message now states the direction
        // correctly (it previously said "FP16 to I64").
        if (current == DataType.INT64 && dataType == DataType.FLOAT16 && getDevice().isGpu()) {
            throw new UnsupportedOperationException("I64 to FP16 is not supported on GPU.");
        }
        long converted =
                RustLibrary.toDataType(getHandle().longValue(), this.manager.toRustDataType(dataType));
        return toArray(converted, dataType, false, true);
    }

    /**
     * Gradients are not supported by the Rust engine.
     *
     * @param requiresGrad ignored
     * @throws UnsupportedOperationException always
     */
    @Override // ai.djl.ndarray.NDArray
    public void setRequiresGradient(boolean requiresGrad) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Gradients are not supported by the Rust engine.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray getGradient() {
        UnsupportedOperationException unsupported = new UnsupportedOperationException("Not implemented");
        throw unsupported;
    }

    /**
     * Gradients are never tracked by the Rust engine.
     *
     * @return always {@code false}
     */
    @Override // ai.djl.ndarray.NDArray
    public boolean hasGradient() {
        final boolean tracked = false;
        return tracked;
    }

    /**
     * Gradients are not supported by the Rust engine.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray stopGradient() {
        UnsupportedOperationException unsupported = new UnsupportedOperationException("Not implemented");
        throw unsupported;
    }

    /**
     * Copies the native data into a heap {@link ByteBuffer} in native byte order.
     *
     * @param z ignored by this implementation
     * @return a wrapped copy of the tensor bytes
     */
    @Override // ai.djl.ndarray.NDArray
    public ByteBuffer toByteBuffer(boolean z) {
        byte[] bytes = RustLibrary.toByteArray(getHandle().longValue());
        return ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder());
    }

    /**
     * String arrays are not supported by the Rust engine.
     *
     * @param cs ignored
     * @throws UnsupportedOperationException always
     */
    @Override // ai.djl.ndarray.NDArray
    public String[] toStringArray(Charset cs) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Replaces this array's contents with {@code buffer}.
     *
     * <p>The buffer is validated against this array's element count and data type. A direct
     * {@link ByteBuffer} is used as-is; any other buffer is copied into a newly allocated direct
     * buffer first. When the array is not on a GPU, that buffer is kept in {@code dataRef}. A new
     * array is created from it, moved to this array's device, and passed to {@code intern}.
     *
     * @param buffer the new contents
     */
    @Override // ai.djl.ndarray.NDArray
    public void set(Buffer buffer) {
        int elements = Math.toIntExact(size());
        DataType type = getDataType();
        BaseNDManager.validateBuffer(buffer, type, elements);
        this.dataRef = null;
        ByteBuffer source;
        if (buffer.isDirect() && buffer instanceof ByteBuffer) {
            source = (ByteBuffer) buffer;
        } else {
            source = this.manager.allocateDirect(elements * type.getNumOfBytes());
            BaseNDManager.copyBuffer(buffer, source);
        }
        if (!getDevice().isGpu()) {
            this.dataRef = source;
        }
        intern(this.manager.create((Buffer) source, getShape(), type).toDevice(getDevice(), false));
    }

    /**
     * Not supported by the Rust engine.
     *
     * @param index ignored
     * @param axis ignored
     * @throws UnsupportedOperationException always
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray gather(NDArray index, int axis) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported by the Rust engine.
     *
     * @param index ignored
     * @throws UnsupportedOperationException always
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray gatherNd(NDArray index) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Takes elements at the given indices. The result is owned by {@code nDManager}.
     *
     * <p>The call runs in a temporary {@link NDScope}, so any array constructed during it (for
     * example a conversion produced by {@code manager.from}) is released. The result is
     * explicitly unregistered from that scope so it survives.
     *
     * @param nDManager the manager that will own the result (must be an {@link RsNDManager})
     * @param indices the indices to take
     * @return the new array
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray take(NDManager nDManager, NDArray indices) {
        try (NDScope ignored = new NDScope()) {
            RsNDManager owner = (RsNDManager) nDManager;
            long source = getHandle().longValue();
            long index = this.manager.from(indices).getHandle().longValue();
            RsNDArray taken = new RsNDArray(owner, RustLibrary.take(source, index));
            NDScope.unregister(taken);
            return taken;
        }
    }

    /**
     * Returns a copy of this array with {@code values} written at {@code indices}.
     *
     * <p>The call runs in a temporary {@link NDScope}. The {@code true} flag passed to
     * {@code toArray} presumably keeps the result out of that scope — confirm.
     *
     * @param indices the positions to write
     * @param values the values to write
     * @return the new array
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray put(NDArray indices, NDArray values) {
        try (NDScope ignored = new NDScope()) {
            long self = getHandle().longValue();
            long idx = this.manager.from(indices).getHandle().longValue();
            long val = this.manager.from(values).getHandle().longValue();
            return toArray(RustLibrary.put(self, idx, val), true);
        }
    }

    /**
     * Not supported by the Rust engine.
     *
     * @param index ignored
     * @param value ignored
     * @param axis ignored
     * @throws UnsupportedOperationException always
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray scatter(NDArray index, NDArray value, int axis) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Detaches from the current manager and attaches to {@code nDManager}.
     *
     * @param nDManager the new owner (must be an {@link RsNDManager})
     */
    @Override // ai.djl.ndarray.NDResource
    public void attach(NDManager nDManager) {
        detach();
        this.manager = (RsNDManager) nDManager;
        this.manager.attachInternal(getUid(), this);
    }

    /**
     * Detaches from the current manager and attaches to {@code nDManager} through its
     * {@code attachUncappedInternal} entry point.
     *
     * @param nDManager the new owner (must be an {@link RsNDManager})
     */
    @Override // ai.djl.ndarray.NDResource
    public void returnResource(NDManager nDManager) {
        detach();
        this.manager = (RsNDManager) nDManager;
        this.manager.attachUncappedInternal(getUid(), this);
    }

    /**
     * Temporarily attaches to {@code nDManager}, passing along the previous owner.
     *
     * @param nDManager the temporary owner (must be an {@link RsNDManager})
     */
    @Override // ai.djl.ndarray.NDResource
    public void tempAttach(NDManager nDManager) {
        RsNDManager previous = this.manager;
        detach();
        this.manager = (RsNDManager) nDManager;
        nDManager.tempAttachInternal(previous, getUid(), this);
    }

    /** Removes this array from its manager and re-homes it under the system manager. */
    @Override // ai.djl.ndarray.NDResource
    public void detach() {
        RsNDManager current = this.manager;
        current.detachInternal(getUid());
        this.manager = RsNDManager.getSystemManager();
    }

    /**
     * Returns a native copy of this array that carries over this array's name.
     *
     * @return the copy
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray duplicate() {
        long copy = RustLibrary.duplicate(getHandle().longValue());
        return toArray(copy, this.dataType, false, true);
    }

    /**
     * Not supported by the Rust engine.
     *
     * @param index ignored
     * @param axis ignored
     * @throws UnsupportedOperationException always
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray booleanMask(NDArray index, int axis) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not yet supported by the Rust engine.
     *
     * @param sequenceLength ignored
     * @param value ignored
     * @throws UnsupportedOperationException always
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray sequenceMask(NDArray sequenceLength, float value) {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /**
     * Not yet supported by the Rust engine.
     *
     * @param sequenceLength ignored
     * @throws UnsupportedOperationException always
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray sequenceMask(NDArray sequenceLength) {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /**
     * Compares this array's content with a scalar created from {@code number}.
     *
     * <p>The temporary scalar array is closed before returning, even if the comparison throws.
     * Previously it was never closed and stayed attached to the manager.
     *
     * @param number the scalar to compare with
     * @return the result of {@code contentEquals(NDArray)} against the scalar array
     */
    @Override // ai.djl.ndarray.NDArray
    public boolean contentEquals(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return contentEquals(scalar);
        }
    }

    /**
     * Returns whether {@code nDArray} has the same shape, data type and content as this array.
     *
     * <p>The native comparison runs inside a temporary {@link NDScope}, as the other binary
     * operations in this class do. RsNDArray instances register with the current scope on
     * construction, so if {@code manager.from(...)} has to build a converted copy of a non-Rust
     * argument, that copy is released instead of leaking.
     *
     * @param nDArray the array to compare with (may be {@code null})
     * @return {@code true} if the arrays are equal
     */
    @Override // ai.djl.ndarray.NDArray
    public boolean contentEquals(NDArray nDArray) {
        if (nDArray == null || !shapeEquals(nDArray) || getDataType() != nDArray.getDataType()) {
            return false;
        }
        try (NDScope ignored = new NDScope()) {
            long self = getHandle().longValue();
            long other = this.manager.from(nDArray).getHandle().longValue();
            return RustLibrary.contentEqual(self, other);
        }
    }

    /**
     * Element-wise equality with a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code eq(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray eq(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return eq(scalar);
        }
    }

    /**
     * Element-wise equality, returning a {@link DataType#BOOLEAN} array.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the right-hand operand
     * @return the new boolean array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray eq(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.eq(lhs, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /**
     * Element-wise inequality with a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code neq(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray neq(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return neq(scalar);
        }
    }

    /**
     * Element-wise inequality, returning a {@link DataType#BOOLEAN} array.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the right-hand operand
     * @return the new boolean array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray neq(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.neq(lhs, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /**
     * Element-wise greater-than with a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code gt(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray gt(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return gt(scalar);
        }
    }

    /**
     * Element-wise greater-than, returning a {@link DataType#BOOLEAN} array.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the right-hand operand
     * @return the new boolean array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray gt(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.gt(lhs, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /**
     * Element-wise greater-or-equal with a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code gte(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray gte(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return gte(scalar);
        }
    }

    /**
     * Element-wise greater-or-equal, returning a {@link DataType#BOOLEAN} array.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the right-hand operand
     * @return the new boolean array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray gte(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.gte(lhs, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /**
     * Element-wise less-than with a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code lt(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray lt(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return lt(scalar);
        }
    }

    /**
     * Element-wise less-than, returning a {@link DataType#BOOLEAN} array.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the right-hand operand
     * @return the new boolean array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray lt(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.lt(lhs, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /**
     * Element-wise less-or-equal with a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code lte(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray lte(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return lte(scalar);
        }
    }

    /**
     * Element-wise less-or-equal, returning a {@link DataType#BOOLEAN} array.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the right-hand operand
     * @return the new boolean array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray lte(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.lte(lhs, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /**
     * Adds a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code add(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray add(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return add(scalar);
        }
    }

    /**
     * Element-wise addition computed natively.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the right-hand operand
     * @return the new array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray add(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.add(lhs, rhs), true);
        }
    }

    /**
     * Subtracts a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code sub(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray sub(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return sub(scalar);
        }
    }

    /**
     * Element-wise subtraction computed natively.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the right-hand operand
     * @return the new array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray sub(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.sub(lhs, rhs), true);
        }
    }

    /**
     * Multiplies by a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code mul(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray mul(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return mul(scalar);
        }
    }

    /**
     * Element-wise multiplication computed natively.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the right-hand operand
     * @return the new array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray mul(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.mul(lhs, rhs), true);
        }
    }

    /**
     * Divides by a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code div(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray div(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return div(scalar);
        }
    }

    /**
     * Element-wise division computed natively.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the right-hand operand
     * @return the new array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray div(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.div(lhs, rhs), true);
        }
    }

    /**
     * Remainder with a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code mod(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray mod(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return mod(scalar);
        }
    }

    /**
     * Element-wise remainder, computed by {@code RustLibrary.remainder}.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the right-hand operand
     * @return the new array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray mod(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.remainder(lhs, rhs), true);
        }
    }

    /**
     * Raises to a scalar power. The temporary scalar array is always closed.
     *
     * @param number the scalar exponent
     * @return the result of {@code pow(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray pow(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return pow(scalar);
        }
    }

    /**
     * Element-wise power computed natively.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the exponent array
     * @return the new array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray pow(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.pow(lhs, rhs), true);
        }
    }

    /**
     * Element-wise {@code xlogy}, computed natively. Scalar operands are rejected.
     *
     * @param other the right-hand operand
     * @return the new array
     * @throws IllegalArgumentException if either operand is a scalar
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray xlogy(NDArray other) {
        boolean anyScalar = isScalar() || other.isScalar();
        if (anyScalar) {
            throw new IllegalArgumentException("scalar is not allowed for xlogy()");
        }
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.xlogy(lhs, rhs), true);
        }
    }

    /**
     * In-place addition of a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code addi(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray addi(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return addi(scalar);
        }
    }

    /**
     * Computes {@code add(other)}, hands the result to {@code intern}, and returns this array.
     *
     * @param other the right-hand operand
     * @return {@code this}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray addi(NDArray other) {
        RsNDArray sum = add(other);
        intern(sum);
        return this;
    }

    /**
     * In-place subtraction of a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code subi(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray subi(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return subi(scalar);
        }
    }

    /**
     * Computes {@code sub(other)}, hands the result to {@code intern}, and returns this array.
     *
     * @param other the right-hand operand
     * @return {@code this}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray subi(NDArray other) {
        RsNDArray difference = sub(other);
        intern(difference);
        return this;
    }

    /**
     * In-place multiplication by a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code muli(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray muli(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return muli(scalar);
        }
    }

    /**
     * Computes {@code mul(other)}, hands the result to {@code intern}, and returns this array.
     *
     * @param other the right-hand operand
     * @return {@code this}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray muli(NDArray other) {
        RsNDArray product = mul(other);
        intern(product);
        return this;
    }

    /**
     * In-place division by a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code divi(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray divi(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return divi(scalar);
        }
    }

    /**
     * Computes {@code div(other)}, hands the result to {@code intern}, and returns this array.
     *
     * @param other the right-hand operand
     * @return {@code this}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray divi(NDArray other) {
        RsNDArray quotient = div(other);
        intern(quotient);
        return this;
    }

    /**
     * In-place remainder with a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code modi(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray modi(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return modi(scalar);
        }
    }

    /**
     * Computes {@code mod(other)}, hands the result to {@code intern}, and returns this array.
     *
     * @param other the right-hand operand
     * @return {@code this}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray modi(NDArray other) {
        RsNDArray remainder = mod(other);
        intern(remainder);
        return this;
    }

    /**
     * In-place power with a scalar exponent. The temporary scalar array is always closed.
     *
     * @param number the scalar exponent
     * @return the result of {@code powi(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray powi(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return powi(scalar);
        }
    }

    /**
     * Computes {@code pow(other)}, hands the result to {@code intern}, and returns this array.
     *
     * @param other the exponent array
     * @return {@code this}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray powi(NDArray other) {
        RsNDArray power = pow(other);
        intern(power);
        return this;
    }

    /**
     * Computes {@code sign()}, hands the result to {@code intern}, and returns this array.
     *
     * @return {@code this}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray signi() {
        RsNDArray signs = sign();
        intern(signs);
        return this;
    }

    /**
     * Computes {@code neg()}, hands the result to {@code intern}, and returns this array.
     *
     * @return {@code this}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray negi() {
        RsNDArray negated = neg();
        intern(negated);
        return this;
    }

    /**
     * Element-wise sign, computed natively.
     *
     * @return the new array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray sign() {
        long result = RustLibrary.sign(getHandle().longValue());
        return toArray(result);
    }

    /**
     * Element-wise maximum with a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code maximum(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray maximum(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return maximum(scalar);
        }
    }

    /**
     * Element-wise maximum computed natively.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the right-hand operand
     * @return the new array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray maximum(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.maximum(lhs, rhs), true);
        }
    }

    /**
     * Element-wise minimum with a scalar. The temporary scalar array is always closed.
     *
     * @param number the scalar operand
     * @return the result of {@code minimum(NDArray)}
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray minimum(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return minimum(scalar);
        }
    }

    /**
     * Element-wise minimum computed natively.
     *
     * <p>Runs in a temporary {@link NDScope}. The {@code true} flag passed to {@code toArray}
     * presumably keeps the result out of that scope — confirm.
     *
     * @param other the right-hand operand
     * @return the new array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray minimum(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle().longValue();
            long rhs = this.manager.from(other).getHandle().longValue();
            return toArray(RustLibrary.minimum(lhs, rhs), true);
        }
    }

    /**
     * Returns a scalar boolean array that is {@code true} when every element is non-zero.
     *
     * <p>The intermediate {@code countNonzero()} array is now closed via try-with-resources.
     * Previously it leaked if {@code getLong} or {@code create} threw.
     *
     * @return the scalar result array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray all() {
        try (NDArray nonZero = countNonzero()) {
            return (RsNDArray) this.manager.create(nonZero.getLong() == size());
        }
    }

    /**
     * Returns a scalar boolean array that is {@code true} when at least one element is non-zero.
     *
     * <p>The intermediate {@code countNonzero()} array is now closed via try-with-resources.
     * Previously it leaked if {@code getLong} or {@code create} threw.
     *
     * @return the scalar result array
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray any() {
        try (NDArray nonZero = countNonzero()) {
            return (RsNDArray) this.manager.create(nonZero.getLong() > 0);
        }
    }

    @Override // ai.djl.ndarray.NDArray
    public RsNDArray none() {
        NDArray countNonzero = countNonzero();
        RsNDArray rsNDArray = (RsNDArray) this.manager.create(countNonzero.getLong(new long[0]) == 0);
        countNonzero.close();
        return rsNDArray;
    }

    /** Counts non-zero elements over the whole array via {@code RustLibrary.countNonzero}. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray countNonzero() {
        try (NDScope ignored = new NDScope()) {
            long ptr = getHandle();
            // unregistered so closing the temporary scope does not close the result
            return toArray(RustLibrary.countNonzero(ptr), true);
        }
    }

    /** Counts non-zero elements along {@code axis} via {@code RustLibrary.countNonzeroWithAxis}. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray countNonzero(int axis) {
        try (NDScope ignored = new NDScope()) {
            long ptr = getHandle();
            return toArray(RustLibrary.countNonzeroWithAxis(ptr, axis), true);
        }
    }

    /** See {@link NDArray#neg()}; computed by {@code RustLibrary.neg}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray neg() {
        long ptr = getHandle();
        return toArray(RustLibrary.neg(ptr));
    }

    /** See {@link NDArray#abs()}; computed by {@code RustLibrary.abs}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray abs() {
        long ptr = getHandle();
        return toArray(RustLibrary.abs(ptr));
    }

    /** See {@link NDArray#square()}; computed by {@code RustLibrary.square}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray square() {
        long ptr = getHandle();
        return toArray(RustLibrary.square(ptr));
    }

    /** See {@link NDArray#sqrt()}; computed by {@code RustLibrary.sqrt}. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray sqrt() {
        long ptr = getHandle();
        return toArray(RustLibrary.sqrt(ptr));
    }

    /**
     * Element-wise cube root, computed as {@code RustLibrary.pow(this, 0.333...)} against a
     * temporary scalar exponent that is closed on both the normal and the exceptional path.
     *
     * <p>NOTE(review): a real-valued power with a fractional exponent commonly yields NaN for
     * negative bases, whereas a true cube root of -8 is -2. Confirm how the native pow treats
     * negative inputs. The exponent is created from a double; presumably the native side
     * reconciles its dtype with this array's. TODO confirm.
     *
     * <p>NOTE(review): the result is built with {@code toArray(..., true)}, which unregisters it
     * from the currently active NDScope even though this method opens no scope of its own.
     * Unlike the other unary ops in this class, the result therefore escapes an enclosing scope.
     * Confirm that this is intended.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray cbrt() {
        RsNDArray rsNDArray = (RsNDArray) this.manager.create(0.3333333333333333d);
        try {
            RsNDArray array = toArray(RustLibrary.pow(getHandle().longValue(), rsNDArray.getHandle().longValue()), true);
            if (rsNDArray != null) {
                rsNDArray.close();
            }
            return array;
        } catch (Throwable th) {
            // close the temporary exponent, keeping the original failure as the primary exception
            if (rsNDArray != null) {
                try {
                    rsNDArray.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    /** See {@link NDArray#floor()}; computed by {@code RustLibrary.floor}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray floor() {
        long ptr = getHandle();
        return toArray(RustLibrary.floor(ptr));
    }

    /** See {@link NDArray#ceil()}; computed by {@code RustLibrary.ceil}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray ceil() {
        long ptr = getHandle();
        return toArray(RustLibrary.ceil(ptr));
    }

    /** See {@link NDArray#round()}; computed by {@code RustLibrary.round}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray round() {
        long ptr = getHandle();
        return toArray(RustLibrary.round(ptr));
    }

    /** See {@link NDArray#trunc()}; computed by {@code RustLibrary.trunc}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray trunc() {
        long ptr = getHandle();
        return toArray(RustLibrary.trunc(ptr));
    }

    /** See {@link NDArray#exp()}; computed by {@code RustLibrary.exp}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray exp() {
        long ptr = getHandle();
        return toArray(RustLibrary.exp(ptr));
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray gammaln() {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** See {@link NDArray#log()}; computed by {@code RustLibrary.log}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray log() {
        long ptr = getHandle();
        return toArray(RustLibrary.log(ptr));
    }

    /** See {@link NDArray#log10()}; computed by {@code RustLibrary.log10}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray log10() {
        long ptr = getHandle();
        return toArray(RustLibrary.log10(ptr));
    }

    /** See {@link NDArray#log2()}; computed by {@code RustLibrary.log2}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray log2() {
        long ptr = getHandle();
        return toArray(RustLibrary.log2(ptr));
    }

    /** See {@link NDArray#sin()}; computed by {@code RustLibrary.sin}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray sin() {
        long ptr = getHandle();
        return toArray(RustLibrary.sin(ptr));
    }

    /** See {@link NDArray#cos()}; computed by {@code RustLibrary.cos}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray cos() {
        long ptr = getHandle();
        return toArray(RustLibrary.cos(ptr));
    }

    /** See {@link NDArray#tan()}; computed by {@code RustLibrary.tan}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray tan() {
        long ptr = getHandle();
        return toArray(RustLibrary.tan(ptr));
    }

    /** See {@link NDArray#asin()}; computed by {@code RustLibrary.asin}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray asin() {
        long ptr = getHandle();
        return toArray(RustLibrary.asin(ptr));
    }

    /** See {@link NDArray#acos()}; computed by {@code RustLibrary.acos}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray acos() {
        long ptr = getHandle();
        return toArray(RustLibrary.acos(ptr));
    }

    /** See {@link NDArray#atan()}; computed by {@code RustLibrary.atan}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray atan() {
        long ptr = getHandle();
        return toArray(RustLibrary.atan(ptr));
    }

    /**
     * Element-wise two-argument arctangent of this array and {@code other}, via
     * {@code RustLibrary.atan2}. Any array {@code manager.from(other)} may allocate is released
     * with the temporary scope; the result is unregistered so it survives.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray atan2(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle();
            long rhs = this.manager.from(other).getHandle();
            return toArray(RustLibrary.atan2(lhs, rhs), true);
        }
    }

    /** See {@link NDArray#sinh()}; computed by {@code RustLibrary.sinh}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray sinh() {
        long ptr = getHandle();
        return toArray(RustLibrary.sinh(ptr));
    }

    /** See {@link NDArray#cosh()}; computed by {@code RustLibrary.cosh}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray cosh() {
        long ptr = getHandle();
        return toArray(RustLibrary.cosh(ptr));
    }

    /** See {@link NDArray#tanh()}; computed by {@code RustLibrary.tanh}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray tanh() {
        long ptr = getHandle();
        return toArray(RustLibrary.tanh(ptr));
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray asinh() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray acosh() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray atanh() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Converts radians to degrees element-wise as {@code (x * 180) / pi}.
     *
     * <p>The intermediate product is closed once the division has produced the result, instead of
     * being left attached to the manager.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray toDegrees() {
        try (RsNDArray scaled = mul((Number) Double.valueOf(180.0d))) {
            return scaled.div((Number) Double.valueOf(Math.PI));
        }
    }

    /**
     * Converts degrees to radians element-wise as {@code (x * pi) / 180}.
     *
     * <p>The intermediate product is closed once the division has produced the result.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray toRadians() {
        try (RsNDArray scaled = mul((Number) Double.valueOf(Math.PI))) {
            return scaled.div((Number) Double.valueOf(180.0d));
        }
    }

    /**
     * Returns the maximum over all elements as a new array.
     *
     * <p>A scalar is its own maximum. It is duplicated rather than returned directly, so that
     * closing the result can never close this array.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray max() {
        if (isScalar()) {
            return (RsNDArray) duplicate();
        }
        return toArray(RustLibrary.max(getHandle().longValue()));
    }

    /**
     * Maximum reduction along exactly one axis.
     *
     * @param axes a single-element array holding the axis to reduce
     * @param keepDims whether the reduced axis is kept with length 1
     * @return the reduced array
     * @throws IllegalArgumentException if {@code axes} is empty
     * @throws UnsupportedOperationException if more than one axis is given
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray max(int[] axes, boolean keepDims) {
        if (axes.length == 0) {
            // previously surfaced as an ArrayIndexOutOfBoundsException from axes[0]
            throw new IllegalArgumentException("At least 1 axis is required");
        }
        if (axes.length > 1) {
            throw new UnsupportedOperationException("Only 1 axis is support!");
        }
        return toArray(RustLibrary.maxWithAxis(getHandle().longValue(), axes[0], keepDims));
    }

    /**
     * Returns the minimum over all elements as a new array.
     *
     * <p>A scalar is duplicated rather than returned directly, so that closing the result can
     * never close this array.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray min() {
        if (isScalar()) {
            return (RsNDArray) duplicate();
        }
        return toArray(RustLibrary.min(getHandle().longValue()));
    }

    /**
     * Minimum reduction along exactly one axis.
     *
     * @param axes a single-element array holding the axis to reduce
     * @param keepDims whether the reduced axis is kept with length 1
     * @return the reduced array
     * @throws IllegalArgumentException if {@code axes} is empty
     * @throws UnsupportedOperationException if more than one axis is given
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray min(int[] axes, boolean keepDims) {
        if (axes.length == 0) {
            throw new IllegalArgumentException("At least 1 axis is required");
        }
        if (axes.length > 1) {
            throw new UnsupportedOperationException("Only 1 axis is support!");
        }
        return toArray(RustLibrary.minWithAxis(getHandle().longValue(), axes[0], keepDims));
    }

    /**
     * Returns the sum over all elements as a new array.
     *
     * <p>A scalar is duplicated rather than returned directly, so that closing the result can
     * never close this array.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray sum() {
        if (isScalar()) {
            return (RsNDArray) duplicate();
        }
        return toArray(RustLibrary.sum(getHandle().longValue()));
    }

    /** Sum reduction over all listed axes, delegated to {@code RustLibrary.sumWithAxis}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray sum(int[] axes, boolean keepDims) {
        long ptr = getHandle();
        return toArray(RustLibrary.sumWithAxis(ptr, axes, keepDims));
    }

    /** Cumulative product along {@code axis}, via {@code RustLibrary.cumProd}. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray cumProd(int axis) {
        long ptr = getHandle();
        return toArray(RustLibrary.cumProd(ptr, axis));
    }

    /**
     * Cumulative product along {@code axis} with an explicit output type. The type is passed to
     * {@code RustLibrary.cumProdWithType} as its enum ordinal.
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray cumProd(int axis, DataType dataType) {
        long ptr = getHandle();
        int typeCode = dataType.ordinal();
        return toArray(RustLibrary.cumProdWithType(ptr, axis, typeCode));
    }

    /** Product over all elements, via {@code RustLibrary.prod}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray prod() {
        long ptr = getHandle();
        return toArray(RustLibrary.prod(ptr));
    }

    /**
     * Product reduction along exactly one axis.
     *
     * <p>NOTE(review): this delegates to the native {@code cumProdWithAxis}. Its
     * {@code keepDims} argument suggests it performs a reduction rather than a cumulative
     * product. Confirm the native behavior.
     *
     * @throws IllegalArgumentException if {@code axes} is empty
     * @throws UnsupportedOperationException if more than one axis is given
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray prod(int[] axes, boolean keepDims) {
        if (axes.length == 0) {
            throw new IllegalArgumentException("At least 1 axis is required");
        }
        if (axes.length > 1) {
            throw new UnsupportedOperationException("Only 1 axis is support!");
        }
        return toArray(RustLibrary.cumProdWithAxis(getHandle().longValue(), axes[0], keepDims));
    }

    /** Mean over all elements, via {@code RustLibrary.mean}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray mean() {
        long ptr = getHandle();
        return toArray(RustLibrary.mean(ptr));
    }

    /** Mean over the listed axes, via {@code RustLibrary.meanWithAxis}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray mean(int[] axes, boolean keepDims) {
        long ptr = getHandle();
        return toArray(RustLibrary.meanWithAxis(ptr, axes, keepDims));
    }

    /** Forwards all three arguments unchanged to {@code RustLibrary.normalize}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray normalize(double exponent, long dim, double eps) {
        long ptr = getHandle();
        return toArray(RustLibrary.normalize(ptr, exponent, dim, eps));
    }

    /**
     * Rotates via {@code RustLibrary.rot90}.
     *
     * @throws IllegalArgumentException unless exactly two axes are given
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray rotate90(int times, int[] axes) {
        boolean twoAxes = axes.length == 2;
        if (!twoAxes) {
            throw new IllegalArgumentException("Axes must be 2");
        }
        long ptr = getHandle();
        return toArray(RustLibrary.rot90(ptr, times, axes));
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray trace(int offset, int axis1, int axis2) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Splits this array along {@code axis} at the given indices.
     *
     * <p>If the last index does not already equal the length of {@code axis}, that length is
     * appended as a final boundary before calling {@code RustLibrary.split}. A negative
     * {@code axis} is counted from the last dimension. With no indices, the result is a list
     * holding this array.
     *
     * @param indices split points along {@code axis}; not modified
     * @param axis the axis to split, possibly negative
     * @return the pieces as an {@link NDList}
     */
    @Override // ai.djl.ndarray.NDArray
    public NDList split(long[] indices, int axis) {
        if (indices.length == 0) {
            return new NDList(this);
        }
        Shape current = getShape();
        // Shape#get does not accept negative dimensions, so normalize first.
        int dim = axis < 0 ? axis + current.dimension() : axis;
        long axisLength = current.get(dim);
        long[] bounds = indices;
        if (indices[indices.length - 1] != axisLength) {
            bounds = Arrays.copyOf(indices, indices.length + 1);
            bounds[indices.length] = axisLength;
        }
        return toList(RustLibrary.split(getHandle().longValue(), bounds, dim));
    }

    /** Collapses the array into one dimension, via {@code RustLibrary.flatten}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray flatten() {
        long ptr = getHandle();
        return toArray(RustLibrary.flatten(ptr));
    }

    /** Flattens the dimension range {@code startDim}..{@code endDim} via {@code RustLibrary.flattenWithDims}. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray flatten(int startDim, int endDim) {
        long ptr = getHandle();
        return toArray(RustLibrary.flattenWithDims(ptr, startDim, endDim));
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray fft(long length, long axis) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray rfft(long length, long axis) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray ifft(long length, long axis) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray irfft(long length, long axis) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray stft(
            long nFft,
            long hopLength,
            boolean center,
            NDArray window,
            boolean normalize,
            boolean returnComplex) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray fft2(long[] sizes, long[] axes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray pad(Shape padding, double value) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray ifft2(long[] sizes, long[] axes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Reshapes this array to {@code shape}. At most one dimension may be negative, and that
     * dimension is inferred from the element count.
     *
     * <p>The inferred size is written into a private copy of the dimensions. {@code
     * Shape#getShape()} may expose the caller's backing array, so resolving the negative entry
     * in place would silently rewrite the caller's (possibly shared) {@code Shape}.
     *
     * @param shape the target shape; never modified
     * @return the reshaped array
     * @throws IllegalArgumentException if more than one dimension is negative, or the element
     *     count cannot be divided evenly by the known dimensions (including a zero known product)
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray reshape(Shape shape) {
        long[] dims = shape.getShape().clone();
        long knownProduct = 1;
        int inferredAxis = -1;
        for (int d = 0; d < dims.length; d++) {
            if (dims[d] >= 0) {
                knownProduct *= dims[d];
            } else if (inferredAxis != -1) {
                throw new IllegalArgumentException("only 1 negative axis is allowed");
            } else {
                inferredAxis = d;
            }
        }
        if (inferredAxis != -1) {
            long size = getShape().size();
            // a zero known product would otherwise raise ArithmeticException on the modulo
            if (knownProduct == 0 || size % knownProduct != 0) {
                throw new IllegalArgumentException("unsupported dimensions");
            }
            dims[inferredAxis] = size / knownProduct;
        }
        return toArray(RustLibrary.reshape(getHandle().longValue(), dims));
    }

    /** Inserts a new axis at {@code axis}, via {@code RustLibrary.expandDims}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray expandDims(int axis) {
        long ptr = getHandle();
        return toArray(RustLibrary.expandDims(ptr, axis));
    }

    /** Removes the listed axes, via {@code RustLibrary.squeeze}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray squeeze(int[] axes) {
        long ptr = getHandle();
        return toArray(RustLibrary.squeeze(ptr, axes));
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public NDList unique(Integer dim, boolean sorted, boolean returnInverse, boolean returnCounts) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Element-wise logical AND via {@code RustLibrary.logicalAnd}. A temporary scope releases
     * anything {@code manager.from(other)} may allocate; the result is unregistered from it.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray logicalAnd(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle();
            long rhs = this.manager.from(other).getHandle();
            return toArray(RustLibrary.logicalAnd(lhs, rhs), true);
        }
    }

    /** Element-wise logical OR via {@code RustLibrary.logicalOr}; same scoping as logicalAnd. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray logicalOr(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle();
            long rhs = this.manager.from(other).getHandle();
            return toArray(RustLibrary.logicalOr(lhs, rhs), true);
        }
    }

    /** Element-wise logical XOR via {@code RustLibrary.logicalXor}; same scoping as logicalAnd. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray logicalXor(NDArray other) {
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle();
            long rhs = this.manager.from(other).getHandle();
            return toArray(RustLibrary.logicalXor(lhs, rhs), true);
        }
    }

    /** Element-wise logical NOT, via {@code RustLibrary.logicalNot}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray logicalNot() {
        long ptr = getHandle();
        return toArray(RustLibrary.logicalNot(ptr));
    }

    /** Indices that would sort along {@code axis}, via {@code RustLibrary.argSort}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray argSort(int axis, boolean ascending) {
        long ptr = getHandle();
        return toArray(RustLibrary.argSort(ptr, axis, ascending));
    }

    /** Sorts along the last axis. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray sort() {
        final int lastAxis = -1;
        return sort(lastAxis);
    }

    /** Sorts along {@code axis} via {@code RustLibrary.sort}, passing {@code false} as its flag. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray sort(int axis) {
        long ptr = getHandle();
        return toArray(RustLibrary.sort(ptr, axis, false));
    }

    /**
     * Softmax along {@code axis}. Scalars and zero-element arrays are returned as a copy without
     * calling the native kernel.
     *
     * <p>The fetched {@link Shape} is used for both checks, instead of relying on the cached
     * {@code shape} field having been populated as a side effect of {@code getShape()}.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray softmax(int axis) {
        Shape current = getShape();
        if (current.isScalar() || current.size() == 0) {
            return (RsNDArray) duplicate();
        }
        return toArray(RustLibrary.softmax(getHandle().longValue(), axis));
    }

    /** Log-softmax along {@code axis}, via {@code RustLibrary.logSoftmax}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray logSoftmax(int axis) {
        long ptr = getHandle();
        return toArray(RustLibrary.logSoftmax(ptr, axis));
    }

    /**
     * Cumulative sum along axis 0. A scalar is reshaped to length 1 and an empty array to length
     * 0 instead of being sent to the native kernel.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray cumSum() {
        if (isScalar()) {
            return (RsNDArray) reshape(1);
        }
        if (isEmpty()) {
            return (RsNDArray) reshape(0);
        }
        return cumSum(0);
    }

    /**
     * Cumulative sum along {@code axis}, via {@code RustLibrary.cumSum}.
     *
     * @throws UnsupportedOperationException for arrays of more than 3 dimensions
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray cumSum(int axis) {
        boolean tooManyDims = getShape().dimension() > 3;
        if (tooManyDims) {
            throw new UnsupportedOperationException("Only 3 dimensions or less is supported");
        }
        long ptr = getHandle();
        return toArray(RustLibrary.cumSum(ptr, axis));
    }

    /**
     * Makes this array take over the native tensor of {@code replaced}, releases this array's
     * previous tensor, and closes {@code replaced}.
     *
     * <p>The previous handle is deleted only when present and not {@code -1}, mirroring
     * {@link #close()}. The cached data type and shape are taken from {@code replaced} because
     * they described the old tensor. {@code replaced}'s {@code dataRef} buffer is adopted as well,
     * presumably because the tensor may reference it (TODO confirm). Interning an array into
     * itself is a no-op.
     *
     * @param replaced the array whose tensor is adopted; closed on return
     */
    @Override // ai.djl.ndarray.NDArray
    public void intern(NDArray replaced) {
        RsNDArray other = (RsNDArray) replaced;
        if (other == this) {
            return;
        }
        Long previous = (Long) this.handle.getAndSet((Long) other.handle.getAndSet(null));
        if (previous != null && previous.longValue() != -1) {
            RustLibrary.deleteTensor(previous.longValue());
        }
        // Must happen before other.close(), which clears other's dataRef.
        this.dataRef = other.dataRef;
        this.dataType = other.dataType;
        this.shape = other.shape;
        other.close();
    }

    /** Element-wise infinity test, via {@code RustLibrary.isInf}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray isInfinite() {
        long ptr = getHandle();
        return toArray(RustLibrary.isInf(ptr));
    }

    /** Element-wise NaN test, via {@code RustLibrary.isNaN}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray isNaN() {
        long ptr = getHandle();
        return toArray(RustLibrary.isNaN(ptr));
    }

    /**
     * Tiles every axis {@code repeats} times. A scalar is treated as one axis; an empty array is
     * returned as a copy.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray tile(long repeats) {
        if (isEmpty()) {
            return (RsNDArray) duplicate();
        }
        int rank = isScalar() ? 1 : getShape().dimension();
        long[] perAxis = new long[rank];
        Arrays.fill(perAxis, repeats);
        return tile(perAxis);
    }

    /** Tiles a single axis, via {@code RustLibrary.tileWithAxis}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray tile(int axis, long repeats) {
        long ptr = getHandle();
        return toArray(RustLibrary.tileWithAxis(ptr, axis, repeats));
    }

    /** Tiles with per-axis counts, via {@code RustLibrary.tile}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray tile(long[] repeats) {
        long ptr = getHandle();
        return toArray(RustLibrary.tile(ptr, repeats));
    }

    /** Tiles toward {@code desiredShape}, via {@code RustLibrary.tileWithShape}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray tile(Shape desiredShape) {
        long ptr = getHandle();
        long[] dims = desiredShape.getShape();
        return toArray(RustLibrary.tileWithShape(ptr, dims));
    }

    /**
     * Repeats along every axis {@code repeats} times. A scalar is treated as one axis; an empty
     * array is returned as a copy.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray repeat(long repeats) {
        if (isEmpty()) {
            return (RsNDArray) duplicate();
        }
        int rank = isScalar() ? 1 : getShape().dimension();
        long[] perAxis = new long[rank];
        Arrays.fill(perAxis, repeats);
        return repeat(perAxis);
    }

    /** Repeats along one axis. Note that the native {@code repeat} takes the count before the axis. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray repeat(int axis, long repeats) {
        long ptr = getHandle();
        return toArray(RustLibrary.repeat(ptr, repeats, axis));
    }

    /**
     * Repeats axis by axis: {@code repeats[k]} is applied along axis {@code k}.
     *
     * <p>Each intermediate result is closed as soon as the next axis has been applied. With no
     * repeats at all a copy is returned, so the caller never receives this array itself; closing
     * the result cannot close the receiver.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray repeat(long[] repeats) {
        if (repeats.length == 0) {
            return (RsNDArray) duplicate();
        }
        RsNDArray current = this;
        for (int axis = 0; axis < repeats.length; axis++) {
            RsNDArray next = current.repeat(axis, repeats[axis]);
            if (current != this) {
                current.close();
            }
            current = next;
        }
        return current;
    }

    /** Repeats until this array matches {@code desiredShape}; each desired dimension must be a multiple. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray repeat(Shape desiredShape) {
        long[] perAxis = repeatsToMatchShape(desiredShape);
        return repeat(perAxis);
    }

    /**
     * Computes per-axis repeat counts that turn this array's shape into {@code desired}.
     *
     * <p>A desired shape with fewer dimensions is left-padded with this array's own leading
     * dimensions, which yields a repeat count of 1 on those axes.
     *
     * @throws IllegalArgumentException if {@code desired} has more dimensions than this array, or
     *     any desired dimension is not a multiple of the current one (or the current one is 0)
     */
    private long[] repeatsToMatchShape(Shape desired) {
        Shape current = getShape();
        int rank = current.dimension();
        if (desired.dimension() > rank) {
            throw new IllegalArgumentException("The desired shape has too many dimensions");
        }
        if (desired.dimension() < rank) {
            desired = current.slice(0, rank - desired.dimension()).addAll(desired);
        }
        long[] repeats = new long[rank];
        for (int i = 0; i < rank; i++) {
            long have = current.get(i);
            long want = desired.get(i);
            if (have == 0 || want % have != 0) {
                throw new IllegalArgumentException("The desired shape is not a multiple of the original shape");
            }
            // Divisibility was checked above, so plain long division is exact. There is no need
            // for a floating-point ceil, which would also lose precision beyond 2^53.
            repeats[i] = want / have;
        }
        return repeats;
    }

    /**
     * Dot product via {@code RustLibrary.dot}. Both operands must have the same rank, at most 2.
     *
     * @throws UnsupportedOperationException on rank mismatch or rank above 2
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray dot(NDArray other) {
        int rank = getShape().dimension();
        boolean supported = rank == other.getShape().dimension() && rank <= 2;
        if (!supported) {
            throw new UnsupportedOperationException("Dimension mismatch or dimension is greater than 2.  Dot product is only applied on two 1D vectors. For high dimensions, please use .matMul instead.");
        }
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle();
            long rhs = this.manager.from(other).getHandle();
            return toArray(RustLibrary.dot(lhs, rhs), true);
        }
    }

    /**
     * Matrix product via {@code RustLibrary.matmul}.
     *
     * <p>Both this array and {@code other} must have at least 2 dimensions. Anything {@code
     * manager.from(other)} may allocate is released with the temporary scope, and the result is
     * unregistered from it.
     *
     * @throws IllegalArgumentException if either operand has fewer than 2 dimensions
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray matMul(NDArray other) {
        // Validate both operands; a rank check on this array alone lets a 1-D/scalar rhs through.
        if (getShape().dimension() < 2 || other.getShape().dimension() < 2) {
            throw new IllegalArgumentException("only 2d tensors are supported for matMul()");
        }
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle();
            long rhs = this.manager.from(other).getHandle();
            return toArray(RustLibrary.matmul(lhs, rhs), true);
        }
    }

    /**
     * Batched matrix product via {@code RustLibrary.batchMatMul}.
     *
     * <p>Both this array and {@code other} must be exactly 3-dimensional.
     *
     * @throws IllegalArgumentException if either operand is not 3-D
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray batchMatMul(NDArray other) {
        // Validate both operands, not just this array.
        if (getShape().dimension() != 3 || other.getShape().dimension() != 3) {
            throw new IllegalArgumentException("only 3d tensors are allowed for batchMatMul()");
        }
        try (NDScope ignored = new NDScope()) {
            long lhs = getHandle();
            long rhs = this.manager.from(other).getHandle();
            return toArray(RustLibrary.batchMatMul(lhs, rhs), true);
        }
    }

    /** Clamps into [{@code min}, {@code max}]; both bounds are passed as doubles to {@code RustLibrary.clip}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray clip(Number min, Number max) {
        long ptr = getHandle();
        double lo = min.doubleValue();
        double hi = max.doubleValue();
        return toArray(RustLibrary.clip(ptr, lo, hi));
    }

    /** Exchanges two axes, implemented with {@code RustLibrary.transpose}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray swapAxes(int axis1, int axis2) {
        long ptr = getHandle();
        return toArray(RustLibrary.transpose(ptr, axis1, axis2));
    }

    /** Reverses element order along the given axes, via {@code RustLibrary.flip}. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray flip(int... axes) {
        long ptr = getHandle();
        return toArray(RustLibrary.flip(ptr, axes));
    }

    /** Reverses the order of all axes. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray transpose() {
        int rank = getShape().dimension();
        int[] reversed = new int[rank];
        for (int k = 0; k < rank; k++) {
            reversed[k] = rank - 1 - k;
        }
        return transpose(reversed);
    }

    /**
     * Permutes axes via {@code RustLibrary.permute}.
     *
     * @throws IllegalArgumentException if this is a scalar and any axis is given
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray transpose(int... axes) {
        if (isScalar() && axes.length > 0) {
            throw new IllegalArgumentException("axes don't match NDArray");
        }
        long ptr = getHandle();
        return toArray(RustLibrary.permute(ptr, axes));
    }

    /** Broadcasts to {@code shape}, via {@code RustLibrary.broadcast}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray broadcast(Shape shape) {
        long ptr = getHandle();
        long[] target = shape.getShape();
        return toArray(RustLibrary.broadcast(ptr, target));
    }

    /**
     * Index of the maximum over all elements. A scalar yields a new {@code 0L} scalar.
     *
     * @throws IllegalArgumentException if this array is empty
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray argMax() {
        if (isEmpty()) {
            throw new IllegalArgumentException("attempt to get argMax of an empty NDArray");
        }
        if (isScalar()) {
            return (RsNDArray) this.manager.create(0L);
        }
        long ptr = getHandle();
        return toArray(RustLibrary.argMax(ptr));
    }

    /** Index of the maximum along {@code axis} (native flag {@code false}); a scalar yields {@code 0L}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray argMax(int axis) {
        if (isScalar()) {
            return (RsNDArray) this.manager.create(0L);
        }
        long ptr = getHandle();
        return toArray(RustLibrary.argMaxWithAxis(ptr, axis, false));
    }

    /** Top-k along {@code axis}; the native call returns several handles, wrapped as an {@link NDList}. */
    @Override // ai.djl.ndarray.NDArray
    public NDList topK(int k, int axis, boolean largest, boolean sorted) {
        long ptr = getHandle();
        return toList(RustLibrary.topK(ptr, k, axis, largest, sorted));
    }

    /**
     * Index of the minimum over all elements. A scalar yields a new {@code 0L} scalar.
     *
     * @throws IllegalArgumentException if this array is empty
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray argMin() {
        if (isEmpty()) {
            throw new IllegalArgumentException("attempt to get argMin of an empty NDArray");
        }
        if (isScalar()) {
            return (RsNDArray) this.manager.create(0L);
        }
        long ptr = getHandle();
        return toArray(RustLibrary.argMin(ptr));
    }

    /** Index of the minimum along {@code axis} (native flag {@code false}); a scalar yields {@code 0L}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray argMin(int axis) {
        if (isScalar()) {
            return (RsNDArray) this.manager.create(0L);
        }
        long ptr = getHandle();
        return toArray(RustLibrary.argMinWithAxis(ptr, axis, false));
    }

    /**
     * Percentile over all elements, delegated to the native {@code RustLibrary.percentile}.
     *
     * <p>NOTE(review): {@code number} (the requested percentile) is never passed to the native
     * call. By contrast, {@link #percentile(Number, int[])} forwards {@code number.doubleValue()}.
     * Either the native side uses a fixed percentile or the argument is silently ignored. Confirm
     * against RustLibrary before relying on this overload.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray percentile(Number number) {
        return toArray(RustLibrary.percentile(getHandle().longValue()));
    }

    /** Percentile over the listed axes; the percentile is passed as a double to {@code RustLibrary.percentileWithAxes}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray percentile(Number percentile, int[] axes) {
        long ptr = getHandle();
        double q = percentile.doubleValue();
        return toArray(RustLibrary.percentileWithAxes(ptr, q, axes));
    }

    /**
     * Median along the last axis.
     *
     * <p>NOTE(review): for arrays with more than one dimension this reduces only the last axis,
     * not all elements. Confirm that this matches the intended contract.
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray median() {
        int[] lastAxis = {-1};
        return median(lastAxis);
    }

    /**
     * Median along a single axis via {@code RustLibrary.median}. The native call yields two
     * arrays; only the first is returned, and the second is closed immediately.
     *
     * @throws UnsupportedOperationException unless exactly one axis is given
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray median(int[] axes) {
        boolean singleAxis = axes.length == 1;
        if (!singleAxis) {
            throw new UnsupportedOperationException("Not supporting zero or multi-dimension median");
        }
        long ptr = getHandle();
        NDList outputs = toList(RustLibrary.median(ptr, axes[0], false));
        NDArray discarded = outputs.get(1);
        discarded.close();
        return (RsNDArray) outputs.get(0);
    }

    /** Arrays here are already dense, so this returns a copy. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray toDense() {
        NDArray copy = duplicate();
        return (RsNDArray) copy;
    }

    /** Sparse formats are not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray toSparse(SparseFormat fmt) {
        throw new UnsupportedOperationException("Not supported");
    }

    /** Indices of non-zero elements, via {@code RustLibrary.nonZero}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray nonzero() {
        long ptr = getHandle();
        return toArray(RustLibrary.nonZero(ptr));
    }

    /** Element-wise inverse error function, via {@code RustLibrary.erfinv}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray erfinv() {
        long ptr = getHandle();
        return toArray(RustLibrary.erfinv(ptr));
    }

    /** Element-wise error function, via {@code RustLibrary.erf}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray erf() {
        long ptr = getHandle();
        return toArray(RustLibrary.erf(ptr));
    }

    /** Matrix inverse, via {@code RustLibrary.inverse}. */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArray inverse() {
        long ptr = getHandle();
        return toArray(RustLibrary.inverse(ptr));
    }

    /** Order-2 norm with an empty axes array, via {@code RustLibrary.norm}. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray norm(boolean keepDims) {
        long ptr = getHandle();
        int[] noAxes = new int[0];
        return toArray(RustLibrary.norm(ptr, 2, noAxes, keepDims));
    }

    /** Norm of order {@code ord} over {@code axes}, via {@code RustLibrary.norm}. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray norm(int ord, int[] axes, boolean keepDims) {
        long ptr = getHandle();
        return toArray(RustLibrary.norm(ptr, ord, axes, keepDims));
    }

    /** One-hot encoding; the output type is passed as its enum ordinal to {@code RustLibrary.oneHot}. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) {
        long ptr = getHandle();
        int typeCode = dataType.ordinal();
        return toArray(RustLibrary.oneHot(ptr, depth, onValue, offValue, typeCode));
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray batchDot(NDArray other) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Delegates to {@code RustLibrary.complex}. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray complex() {
        long ptr = getHandle();
        return toArray(RustLibrary.complex(ptr));
    }

    /** Delegates to {@code RustLibrary.real}. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray real() {
        long ptr = getHandle();
        return toArray(RustLibrary.real(ptr));
    }

    /** Not available in the Rust engine; always throws. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray conj() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns the engine-internal operator helper.
     *
     * @throws UnsupportedOperationException if no helper is attached
     */
    @Override // ai.djl.ndarray.NDArray
    public RsNDArrayEx getNDArrayInternal() {
        RsNDArrayEx ex = this.ndArrayEx;
        if (ex != null) {
            return ex;
        }
        throw new UnsupportedOperationException("NDArray operation is not supported for String tensor");
    }

    /** Returns the debug rendering, or a fixed message once the array has been released. */
    @Override
    public String toString() {
        if (isReleased()) {
            return "This array is already closed";
        }
        return toDebugString();
    }

    /** Two arrays are equal when {@code obj} is an {@link NDArray} with equal contents. */
    @Override
    public boolean equals(Object obj) {
        return obj instanceof NDArray && contentEquals((NDArray) obj);
    }

    /**
     * Returns a constant hash code.
     *
     * <p>Equality is content-based, so a constant keeps the equals/hashCode contract without
     * reading tensor data. The cost is that hash-based collections of arrays degrade to linear
     * lookups.
     */
    @Override
    public int hashCode() {
        return 0;
    }

    /**
     * Releases the native tensor and detaches this array from its manager.
     *
     * <p>The handle is atomically swapped to {@code null} before deletion, so a repeated call
     * does not delete the tensor again. A handle value of {@code -1} is never passed to
     * {@code deleteTensor}. NOTE(review): -1 presumably marks a handle this array does not own;
     * confirm. The {@code dataRef} buffer reference is dropped last.
     */
    @Override // ai.djl.util.NativeResource, java.lang.AutoCloseable, ai.djl.ndarray.NDArray, ai.djl.ndarray.NDResource
    public void close() {
        onClose();
        // claim the handle exactly once; later calls observe null and skip deletion
        Long l = (Long) this.handle.getAndSet(null);
        if (l != null && l.longValue() != -1) {
            RustLibrary.deleteTensor(l.longValue());
        }
        this.manager.detachInternal(getUid());
        this.dataRef = null;
    }

    /** Wraps a native handle; the new array stays registered with the active NDScope. */
    private RsNDArray toArray(long ptr) {
        return toArray(ptr, false);
    }

    /**
     * Wraps a native handle. When {@code unregister} is true, the new array is removed from the
     * active NDScope so that closing that scope does not close it.
     */
    private RsNDArray toArray(long ptr, boolean unregister) {
        return toArray(ptr, null, unregister, false);
    }

    /**
     * Wraps a native handle in a new {@link RsNDArray} owned by this array's manager.
     *
     * @param ptr native tensor handle
     * @param dataType known data type, or {@code null} when not supplied
     * @param unregister whether to remove the new array from the active NDScope
     * @param copyName whether the new array inherits this array's name
     */
    private RsNDArray toArray(long ptr, DataType dataType, boolean unregister, boolean copyName) {
        RsNDArray wrapped = new RsNDArray(this.manager, ptr, dataType);
        if (copyName) {
            wrapped.setName(getName());
        }
        if (unregister) {
            NDScope.unregister(wrapped);
        }
        return wrapped;
    }

    /** Wraps each native handle in order, collecting the new arrays into an {@link NDList}. */
    private NDList toList(long[] ptrs) {
        NDList wrapped = new NDList(ptrs.length);
        for (int k = 0; k < ptrs.length; k++) {
            wrapped.add(new RsNDArray(this.manager, ptrs[k]));
        }
        return wrapped;
    }
}
