/*
 * Decompiled with CFR 0.152.
 */
package com.mongodb.internal.connection.tlschannel.impl;

import com.mongodb.internal.connection.tlschannel.NeedsReadException;
import com.mongodb.internal.connection.tlschannel.NeedsTaskException;
import com.mongodb.internal.connection.tlschannel.NeedsWriteException;
import com.mongodb.internal.connection.tlschannel.TlsChannelCallbackException;
import com.mongodb.internal.connection.tlschannel.TrackingAllocator;
import com.mongodb.internal.connection.tlschannel.WouldBlockException;
import com.mongodb.internal.connection.tlschannel.impl.BufferHolder;
import com.mongodb.internal.connection.tlschannel.impl.ByteBufferSet;
import com.mongodb.internal.connection.tlschannel.util.Util;
import com.mongodb.internal.diagnostics.logging.Logger;
import com.mongodb.internal.diagnostics.logging.Loggers;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.util.Optional;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;

/**
 * TLS {@link ByteChannel} implementation that layers an {@link SSLEngine} over a pair of underlying channels
 * carrying encrypted bytes.
 *
 * <p>Data flow:
 * <ul>
 * <li>Plaintext given to {@link #write(ByteBufferSet)} is wrapped by the engine into {@code outEncrypted}, which is
 * then flushed to {@code writeChannel}.</li>
 * <li>Ciphertext read from {@code readChannel} accumulates in {@code inEncrypted}. It is unwrapped straight into
 * the caller's buffers or, when those are too small, into {@code inPlain}.</li>
 * </ul>
 *
 * <p>Locking: {@code initLock} serializes handshake start-up; {@code readLock} and {@code writeLock} guard the read
 * and write sides. Paths needing both always take {@code readLock} before {@code writeLock}.
 *
 * <p>Non-blocking channels: when an underlying channel cannot make progress, a {@link NeedsReadException} or
 * {@link NeedsWriteException} is thrown and the operation may be retried later. When {@code runTasks} is
 * {@code false}, a {@link NeedsTaskException} hands delegated engine tasks to the caller.
 *
 * <p>The {@code (Buffer)} casts before {@code flip()} call the {@link Buffer} method directly. This keeps bytecode
 * compiled on JDK 9+ runnable on Java 8, where {@code ByteBuffer.flip()} has no covariant override.
 */
public class TlsChannelImpl implements ByteChannel {
    private static final Logger LOGGER = Loggers.getLogger("connection.tls");

    /** Initial capacity, in bytes, of every internal {@link BufferHolder} created by this class. */
    public static final int buffersInitialSize = 4096;
    /** Maximum capacity, in bytes, that every internal {@link BufferHolder} created by this class may grow to. */
    public static final int maxTlsPacketSize = 17408;

    private final ReadableByteChannel readChannel;
    private final WritableByteChannel writeChannel;
    private final SSLEngine engine;
    private final Consumer<SSLSession> initSessionCallback;
    private final boolean runTasks;
    private final TrackingAllocator encryptedBufAllocator;
    private final TrackingAllocator plainBufAllocator;
    private final boolean waitForCloseConfirmation;

    // Acquisition order: initLock -> readLock -> writeLock.
    private final Lock initLock = new ReentrantLock();
    private final Lock readLock = new ReentrantLock();
    private final Lock writeLock = new ReentrantLock();

    private volatile boolean negotiated = false;
    // Set after an I/O or SSL failure; read/write/handshake/shutdown then throw ClosedChannelException.
    private volatile boolean invalid = false;
    private volatile boolean shutdownSent = false;
    private volatile boolean shutdownReceived = false;

    // Set to null by freeBuffers().
    private BufferHolder inEncrypted;   // ciphertext read from readChannel, not yet unwrapped
    private BufferHolder inPlain;       // plaintext unwrapped but not yet handed to the caller
    private BufferHolder outEncrypted;  // ciphertext produced by wrap, not yet written to writeChannel

    // Zero-length source for wrap() calls that only emit handshake or close_notify records.
    private final ByteBufferSet dummyOut = new ByteBufferSet(new ByteBuffer[] {ByteBuffer.allocate(0)});

    /**
     * Creates the channel. The constructor performs no I/O; the handshake happens on first use.
     *
     * @param readChannel channel from which encrypted bytes are read
     * @param writeChannel channel to which encrypted bytes are written
     * @param engine engine implementing the TLS protocol
     * @param inEncrypted existing holder to reuse for incoming ciphertext; a new one is created when empty
     * @param initSessionCallback invoked with the engine's session after every completed handshake
     * @param runTasks {@code true} to run delegated engine tasks on the calling thread; {@code false} to throw
     *     {@link NeedsTaskException} so the caller runs them
     * @param plainBufAllocator allocator for the {@code inPlain} buffer
     * @param encryptedBufAllocator allocator for the {@code inEncrypted} and {@code outEncrypted} buffers
     * @param releaseBuffers passed as the last constructor argument of every {@link BufferHolder} created here
     *     (presumably lets idle buffers be released; confirm in BufferHolder)
     * @param waitForCloseConfirmation {@code true} to make {@link #close()} issue a second {@link #shutdown()}
     *     that reads the peer's close_notify
     */
    public TlsChannelImpl(ReadableByteChannel readChannel, WritableByteChannel writeChannel, SSLEngine engine,
            Optional<BufferHolder> inEncrypted, Consumer<SSLSession> initSessionCallback, boolean runTasks,
            TrackingAllocator plainBufAllocator, TrackingAllocator encryptedBufAllocator, boolean releaseBuffers,
            boolean waitForCloseConfirmation) {
        this.readChannel = readChannel;
        this.writeChannel = writeChannel;
        this.engine = engine;
        this.initSessionCallback = initSessionCallback;
        this.runTasks = runTasks;
        this.plainBufAllocator = plainBufAllocator;
        this.encryptedBufAllocator = encryptedBufAllocator;
        this.waitForCloseConfirmation = waitForCloseConfirmation;
        this.inEncrypted = inEncrypted.orElseGet(() -> new BufferHolder("inEncrypted", Optional.empty(),
                encryptedBufAllocator, buffersInitialSize, maxTlsPacketSize, false, releaseBuffers));
        this.inPlain = new BufferHolder("inPlain", Optional.empty(), plainBufAllocator, buffersInitialSize,
                maxTlsPacketSize, true, releaseBuffers);
        this.outEncrypted = new BufferHolder("outEncrypted", Optional.empty(), encryptedBufAllocator,
                buffersInitialSize, maxTlsPacketSize, false, releaseBuffers);
    }

    /** @return the callback invoked with the session after each handshake */
    public Consumer<SSLSession> getSessionInitCallback() {
        return initSessionCallback;
    }

    /** @return the allocator used for the plaintext buffer */
    public TrackingAllocator getPlainBufferAllocator() {
        return plainBufAllocator;
    }

    /** @return the allocator used for the ciphertext buffers */
    public TrackingAllocator getEncryptedBufferAllocator() {
        return encryptedBufAllocator;
    }

    /**
     * Reads and decrypts application data into {@code dest}, performing the initial handshake first if needed.
     *
     * @param dest destination buffers; must not be read-only
     * @return the number of bytes stored in {@code dest}. Returns {@code 0} if {@code dest} has no space, and
     *     {@code -1} once the peer's close_notify was received or the underlying channel hit end-of-stream.
     * @throws IllegalArgumentException if {@code dest} is read-only
     * @throws ClosedChannelException if the channel is invalid or shutdown was already sent
     * @throws NeedsReadException if a non-blocking read channel has no data available
     * @throws NeedsWriteException if a non-blocking write channel cannot accept handshake data
     * @throws NeedsTaskException if a delegated task must be run and {@code runTasks} is {@code false}
     */
    public long read(ByteBufferSet dest) throws IOException {
        checkReadBuffer(dest);
        if (!dest.hasRemaining()) {
            return 0L;
        }
        handshake();
        readLock.lock();
        try {
            if (invalid || shutdownSent) {
                throw new ClosedChannelException();
            }
            SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus();
            // Plaintext left over from an earlier call is delivered before anything new is read.
            int bytesToReturn = inPlain.nullOrEmpty() ? 0 : inPlain.buffer.position();
            while (true) {
                if (bytesToReturn > 0) {
                    if (inPlain.nullOrEmpty()) {
                        // The bytes were unwrapped directly into dest.
                        return bytesToReturn;
                    }
                    // The bytes sit in inPlain (dest was too small, or they are left over): copy them out.
                    return transferPendingPlain(dest);
                }
                if (shutdownReceived) {
                    return -1L;
                }
                Util.assertTrue(inPlain.nullOrEmpty());
                switch (handshakeStatus) {
                    case NEED_UNWRAP:
                    case NEED_WRAP:
                        // A handshake is in progress (e.g. a renegotiation); it may also yield application data.
                        bytesToReturn = handshake(Optional.of(dest), Optional.of(handshakeStatus));
                        handshakeStatus = SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
                        break;
                    case NOT_HANDSHAKING:
                    case FINISHED: {
                        UnwrapResult result = readAndUnwrap(Optional.of(dest));
                        if (result.wasClosed) {
                            return -1L;
                        }
                        bytesToReturn = result.bytesProduced;
                        handshakeStatus = result.lastHandshakeStatus;
                        break;
                    }
                    case NEED_TASK:
                        handleTask();
                        handshakeStatus = engine.getHandshakeStatus();
                        break;
                    default:
                        // Any other status (e.g. NEED_UNWRAP_AGAIN) is not handled.
                        return -1L;
                }
            }
        } catch (EofException e) {
            return -1L;
        } finally {
            readLock.unlock();
        }
    }

    /**
     * Runs one delegated engine task on the calling thread.
     * When {@code runTasks} is {@code false}, the task is instead handed to the caller inside a
     * {@link NeedsTaskException}.
     */
    private void handleTask() throws NeedsTaskException {
        if (!runTasks) {
            throw new NeedsTaskException(engine.getDelegatedTask());
        }
        engine.getDelegatedTask().run();
    }

    /**
     * Copies as much plaintext buffered in {@code inPlain} into {@code dest} as fits.
     *
     * @return the number of bytes copied
     */
    private int transferPendingPlain(ByteBufferSet dest) {
        ((Buffer) inPlain.buffer).flip(); // switch to draining
        int copied = dest.putRemaining(inPlain.buffer);
        inPlain.buffer.compact(); // keep the leftover bytes, switch back to filling
        boolean released = inPlain.release();
        if (!released) {
            // The holder kept its buffer; presumably this wipes the free space so no stale plaintext lingers.
            inPlain.zeroRemaining();
        }
        return copied;
    }

    /**
     * Calls {@link SSLEngine#unwrap} repeatedly until one of these happens:
     * <ul>
     * <li>plaintext is produced;</li>
     * <li>more ciphertext is needed;</li>
     * <li>the engine reports CLOSED;</li>
     * <li>the handshake status differs from {@code originalStatus}.</li>
     * </ul>
     *
     * @param dest caller buffers to unwrap into. When empty, or once they prove too small, {@code inPlain} is
     *     used instead.
     * @param originalStatus handshake status before the operation; a change ends the loop
     */
    private UnwrapResult unwrapLoop(Optional<ByteBufferSet> dest, SSLEngineResult.HandshakeStatus originalStatus)
            throws SSLException {
        ByteBufferSet target = dest.orElseGet(() -> {
            inPlain.prepare();
            return new ByteBufferSet(inPlain.buffer);
        });
        while (true) {
            Util.assertTrue(inPlain.nullOrEmpty());
            SSLEngineResult result = callEngineUnwrap(target);
            SSLEngineResult.Status status = result.getStatus();
            if (result.bytesProduced() > 0
                    || status == SSLEngineResult.Status.BUFFER_UNDERFLOW
                    || status == SSLEngineResult.Status.CLOSED
                    || result.getHandshakeStatus() != originalStatus) {
                return new UnwrapResult(result.bytesProduced(), result.getHandshakeStatus(),
                        status == SSLEngineResult.Status.CLOSED);
            }
            if (status != SSLEngineResult.Status.BUFFER_OVERFLOW) {
                continue;
            }
            if (dest.isPresent() && target == dest.get()) {
                // The caller's buffers are too small. Switch to inPlain, grown to at least
                // min(2 x their remaining space, maxTlsPacketSize). The math is done in long to avoid int overflow.
                inPlain.prepare();
                ensureInPlainCapacity((int) Math.min(dest.get().remaining() * 2, maxTlsPacketSize));
            } else {
                inPlain.enlarge();
            }
            // inPlain.buffer may have been replaced: re-wrap it.
            target = new ByteBufferSet(inPlain.buffer);
        }
    }

    /**
     * Performs a single {@link SSLEngine#unwrap} from {@code inEncrypted} into {@code dest}.
     * An {@link SSLException} marks the channel invalid before it is rethrown. {@code inEncrypted} is compacted
     * back to filling mode in every case.
     */
    private SSLEngineResult callEngineUnwrap(ByteBufferSet dest) throws SSLException {
        ((Buffer) inEncrypted.buffer).flip();
        try {
            SSLEngineResult result = engine.unwrap(inEncrypted.buffer, dest.array, dest.offset, dest.length);
            if (LOGGER.isTraceEnabled()) {
                LOGGER.trace(String.format(
                        "engine.unwrap() result [%s]. Engine status: %s; inEncrypted %s; inPlain: %s",
                        Util.resultToString(result), result.getHandshakeStatus(), inEncrypted, dest));
            }
            return result;
        } catch (SSLException e) {
            invalid = true;
            throw e;
        } finally {
            inEncrypted.buffer.compact();
        }
    }

    /**
     * Reads more ciphertext from {@code readChannel} into {@code inEncrypted}.
     * Would-block signals pass through untouched. Any other {@link IOException} marks the channel invalid before
     * it is rethrown.
     */
    private int readFromChannel() throws IOException, EofException {
        try {
            return readFromChannel(readChannel, inEncrypted.buffer);
        } catch (WouldBlockException e) {
            throw e; // not a failure: the operation can be retried
        } catch (IOException e) {
            invalid = true;
            throw e;
        }
    }

    /**
     * Performs one read from {@code channel} into {@code buffer}, which must have space remaining.
     *
     * @return the number of bytes read (always positive)
     * @throws EofException if the channel reached end-of-stream
     * @throws NeedsReadException if the read returned no bytes
     */
    public static int readFromChannel(ReadableByteChannel channel, ByteBuffer buffer)
            throws IOException, EofException {
        Util.assertTrue(buffer.hasRemaining());
        LOGGER.trace("Reading from channel");
        int count = channel.read(buffer);
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace(String.format("Read from channel; response: %s, buffer: %s", count, buffer));
        }
        if (count == -1) {
            throw new EofException();
        }
        if (count == 0) {
            throw new NeedsReadException();
        }
        return count;
    }

    /**
     * Encrypts and writes every remaining byte of {@code source}, performing the initial handshake first if needed.
     *
     * @return the number of plaintext bytes consumed
     * @throws ClosedChannelException if the channel is invalid or shutdown was already sent
     * @throws NeedsWriteException if a non-blocking write channel cannot accept more data
     */
    public long write(ByteBufferSet source) throws IOException {
        handshake();
        writeLock.lock();
        try {
            if (invalid || shutdownSent) {
                throw new ClosedChannelException();
            }
            return wrapAndWrite(source);
        } finally {
            writeLock.unlock();
        }
    }

    /**
     * Alternates between flushing {@code outEncrypted} and wrapping more of {@code source}.
     * Stops once all of {@code source} has been consumed and the resulting ciphertext flushed. Ciphertext already
     * pending in {@code outEncrypted} is flushed first.
     */
    private long wrapAndWrite(ByteBufferSet source) throws IOException {
        long toConsume = source.remaining();
        long consumed = 0L;
        outEncrypted.prepare();
        try {
            while (true) {
                writeToChannel();
                if (consumed == toConsume) {
                    return toConsume;
                }
                consumed += wrapLoop(source).bytesConsumed;
            }
        } finally {
            outEncrypted.release();
        }
    }

    /**
     * Calls {@link SSLEngine#wrap} until it reports OK or CLOSED, growing {@code outEncrypted} after each
     * BUFFER_OVERFLOW.
     */
    private WrapResult wrapLoop(ByteBufferSet source) throws SSLException {
        while (true) {
            SSLEngineResult result = callEngineWrap(source);
            switch (result.getStatus()) {
                case OK:
                case CLOSED:
                    return new WrapResult(result.bytesConsumed(), result.getHandshakeStatus());
                case BUFFER_OVERFLOW:
                    Util.assertTrue(result.bytesConsumed() == 0);
                    outEncrypted.enlarge();
                    break;
                case BUFFER_UNDERFLOW:
                    // wrap() is not expected to need more input.
                    throw new IllegalStateException();
                default:
                    break;
            }
        }
    }

    /**
     * Performs a single {@link SSLEngine#wrap} from {@code source} into {@code outEncrypted}.
     * An {@link SSLException} marks the channel invalid before it is rethrown.
     */
    private SSLEngineResult callEngineWrap(ByteBufferSet source) throws SSLException {
        try {
            SSLEngineResult result = engine.wrap(source.array, source.offset, source.length, outEncrypted.buffer);
            if (LOGGER.isTraceEnabled()) {
                LOGGER.trace(String.format(
                        "engine.wrap() result: [%s]; engine status: %s; srcBuffer: %s, outEncrypted: %s",
                        Util.resultToString(result), result.getHandshakeStatus(), source, outEncrypted));
            }
            return result;
        } catch (SSLException e) {
            invalid = true;
            throw e;
        }
    }

    /** Grows {@code inPlain} to at least {@code minCapacity} bytes; never shrinks it. */
    private void ensureInPlainCapacity(int minCapacity) {
        if (inPlain.buffer.capacity() < minCapacity) {
            if (LOGGER.isTraceEnabled()) {
                LOGGER.trace(String.format("inPlain buffer too small, increasing from %s to %s",
                        inPlain.buffer.capacity(), minCapacity));
            }
            inPlain.resize(minCapacity);
        }
    }

    /**
     * Flushes pending ciphertext from {@code outEncrypted} to {@code writeChannel}; does nothing if none is pending.
     * Would-block signals pass through untouched. Any other {@link IOException} marks the channel invalid before
     * it is rethrown.
     */
    private void writeToChannel() throws IOException {
        if (outEncrypted.buffer.position() == 0) {
            return;
        }
        ((Buffer) outEncrypted.buffer).flip();
        try {
            writeToChannel(writeChannel, outEncrypted.buffer);
        } catch (WouldBlockException e) {
            throw e; // not a failure: the operation can be retried
        } catch (IOException e) {
            invalid = true;
            throw e;
        } finally {
            outEncrypted.buffer.compact();
        }
    }

    /**
     * Writes all remaining bytes of {@code buffer} to {@code channel}.
     *
     * @throws NeedsWriteException if a write call accepts no bytes
     */
    private static void writeToChannel(WritableByteChannel channel, ByteBuffer buffer) throws IOException {
        while (buffer.hasRemaining()) {
            if (LOGGER.isTraceEnabled()) {
                LOGGER.trace("Writing to channel: " + buffer);
            }
            if (channel.write(buffer) == 0) {
                throw new NeedsWriteException();
            }
        }
    }

    /**
     * Forces a new handshake on an established session.
     *
     * @throws SSLException if the session protocol is TLS 1.3 or later, which has no renegotiation
     * @throws ClosedChannelException if the channel is closed, invalid, or reaches end-of-stream during the handshake
     */
    public void renegotiate() throws IOException {
        // Relies on the lexicographic ordering of protocol names ("TLSv1.2" < "TLSv1.3").
        if (engine.getSession().getProtocol().compareTo("TLSv1.3") >= 0) {
            throw new SSLException("renegotiation not supported in TLS 1.3 or latter");
        }
        try {
            doHandshake(true);
        } catch (EofException e) {
            throw new ClosedChannelException();
        }
    }

    /**
     * Performs the initial handshake if it has not completed yet; otherwise returns immediately.
     *
     * @throws ClosedChannelException if the channel is closed, invalid, or reaches end-of-stream during the handshake
     */
    public void handshake() throws IOException {
        try {
            doHandshake(false);
        } catch (EofException e) {
            throw new ClosedChannelException();
        }
    }

    /**
     * Runs a handshake under {@code initLock}, then invokes the session callback and marks the channel negotiated.
     *
     * @param force {@code true} to handshake even if already negotiated (renegotiation)
     * @throws TlsChannelCallbackException if the session initialization callback throws
     */
    private void doHandshake(boolean force) throws IOException, EofException {
        if (!force && negotiated) {
            return; // fast path, no locking
        }
        initLock.lock();
        try {
            if (invalid || shutdownSent) {
                throw new ClosedChannelException();
            }
            // Re-check under the lock: another thread may have finished the handshake meanwhile.
            if (force || !negotiated) {
                engine.beginHandshake();
                LOGGER.trace("Called engine.beginHandshake()");
                handshake(Optional.empty(), Optional.empty());
                try {
                    initSessionCallback.accept(engine.getSession());
                } catch (Exception e) {
                    LOGGER.trace("client code threw exception in session initialization callback", e);
                    throw new TlsChannelCallbackException("session initialization callback failed", e);
                }
                negotiated = true;
            }
        } finally {
            initLock.unlock();
        }
    }

    /**
     * Drives the handshake state machine while holding both the read and the write lock.
     *
     * <p>Steps:
     * <ol>
     * <li>Take {@code readLock}, then {@code writeLock}. {@link #read(ByteBufferSet)} already holds
     * {@code readLock}; the lock is reentrant.</li>
     * <li>Prepare {@code outEncrypted}, which {@link #handshakeLoop} writes into.</li>
     * <li>Flush ciphertext left over from earlier writes.</li>
     * <li>Run the loop.</li>
     * </ol>
     *
     * <p>NOTE(review): the decompiler could not recover this body. It was reconstructed to match the buffer and lock
     * discipline of {@link #wrapAndWrite} and {@link #shutdown}; verify against the upstream tls-channel source.
     *
     * @param dest caller buffers that may receive application data arriving during the handshake
     * @param handshakeStatus status to start from; defaults to the engine's current status
     * @return the number of plaintext bytes produced during the handshake, or {@code 0}
     */
    private int handshake(Optional<ByteBufferSet> dest, Optional<SSLEngineResult.HandshakeStatus> handshakeStatus)
            throws IOException, EofException {
        readLock.lock();
        try {
            writeLock.lock();
            try {
                Util.assertTrue(inPlain.nullOrEmpty());
                outEncrypted.prepare();
                try {
                    writeToChannel();
                    return handshakeLoop(dest, handshakeStatus);
                } finally {
                    outEncrypted.release();
                }
            } finally {
                writeLock.unlock();
            }
        } finally {
            readLock.unlock();
        }
    }

    /**
     * Steps the engine through wrap, unwrap and task stages until the handshake ends or application data appears.
     *
     * @return the number of plaintext bytes produced by an unwrap, or {@code 0} when the handshake finished
     */
    private int handshakeLoop(Optional<ByteBufferSet> dest, Optional<SSLEngineResult.HandshakeStatus> initialStatus)
            throws IOException, EofException {
        Util.assertTrue(inPlain.nullOrEmpty());
        SSLEngineResult.HandshakeStatus status = initialStatus.orElseGet(() -> engine.getHandshakeStatus());
        while (true) {
            switch (status) {
                case NEED_WRAP:
                    Util.assertTrue(outEncrypted.nullOrEmpty());
                    status = wrapLoop(dummyOut).lastHandshakeStatus;
                    writeToChannel();
                    break;
                case NEED_UNWRAP: {
                    UnwrapResult result = readAndUnwrap(dest);
                    status = result.lastHandshakeStatus;
                    if (result.bytesProduced > 0) {
                        return result.bytesProduced;
                    }
                    break;
                }
                case NEED_TASK:
                    handleTask();
                    status = engine.getHandshakeStatus();
                    break;
                case NOT_HANDSHAKING:
                case FINISHED:
                    return 0;
                default:
                    // Any other status (e.g. NEED_UNWRAP_AGAIN) is not handled.
                    return 0;
            }
        }
    }

    /**
     * Reads from {@code readChannel} and unwraps until plaintext is produced, the handshake status changes, or the
     * engine reports the session closed. In the last case {@code shutdownReceived} is set.
     */
    private UnwrapResult readAndUnwrap(Optional<ByteBufferSet> dest) throws IOException, EofException {
        // The status before the operation; a change ends the loop.
        SSLEngineResult.HandshakeStatus originalStatus = engine.getHandshakeStatus();
        inEncrypted.prepare();
        try {
            while (true) {
                Util.assertTrue(inPlain.nullOrEmpty());
                UnwrapResult result = unwrapLoop(dest, originalStatus);
                if (result.bytesProduced > 0 || result.lastHandshakeStatus != originalStatus || result.wasClosed) {
                    if (result.wasClosed) {
                        shutdownReceived = true;
                    }
                    return result;
                }
                if (!inEncrypted.buffer.hasRemaining()) {
                    inEncrypted.enlarge();
                }
                readFromChannel();
            }
        } finally {
            inEncrypted.release();
        }
    }

    /**
     * Closes the channel in three steps:
     * <ol>
     * <li>Attempt a best-effort TLS shutdown.</li>
     * <li>Close both underlying channels.</li>
     * <li>Free internal buffers once in-flight reads and writes release their locks.</li>
     * </ol>
     */
    @Override
    public void close() throws IOException {
        tryShutdown();
        writeChannel.close();
        readChannel.close();
        readLock.lock();
        try {
            writeLock.lock();
            try {
                freeBuffers();
            } finally {
                writeLock.unlock();
            }
        } finally {
            readLock.unlock();
        }
    }

    /**
     * Best-effort shutdown used by {@link #close()}.
     * It does nothing if another thread holds either lock, or if shutdown was already sent. Any error is logged at
     * debug level and swallowed.
     */
    private void tryShutdown() {
        if (!readLock.tryLock()) {
            return;
        }
        try {
            if (!writeLock.tryLock()) {
                return;
            }
            try {
                if (!shutdownSent) {
                    try {
                        boolean closed = shutdown();
                        if (!closed && waitForCloseConfirmation) {
                            shutdown();
                        }
                    } catch (Throwable e) {
                        if (LOGGER.isDebugEnabled()) {
                            LOGGER.debug("error doing TLS shutdown on close(), continuing: " + e.getMessage());
                        }
                    }
                }
            } finally {
                writeLock.unlock();
            }
        } finally {
            readLock.unlock();
        }
    }

    /**
     * Performs the TLS closing exchange. Internal buffers are freed once both close_notify messages have been
     * exchanged.
     *
     * <ul>
     * <li>The first call sends this side's close_notify. It returns {@code true} if the peer's close_notify was
     * already received, {@code false} otherwise.</li>
     * <li>A later call reads until the peer's close_notify arrives, then returns {@code true}.</li>
     * </ul>
     *
     * @throws ClosedChannelException if the channel is invalid, or end-of-stream is hit while awaiting the peer
     */
    public boolean shutdown() throws IOException {
        readLock.lock();
        try {
            writeLock.lock();
            // Nested try/finally releases writeLock on every path, including exceptions from readAndUnwrap below
            // (e.g. NeedsReadException on a non-blocking channel).
            try {
                if (invalid) {
                    throw new ClosedChannelException();
                }
                if (!shutdownSent) {
                    shutdownSent = true;
                    outEncrypted.prepare();
                    try {
                        writeToChannel(); // flush ciphertext still pending from earlier writes
                        engine.closeOutbound();
                        wrapLoop(dummyOut); // emits this side's close_notify into outEncrypted
                        writeToChannel();
                    } finally {
                        outEncrypted.release();
                    }
                    boolean received = shutdownReceived;
                    if (received) {
                        freeBuffers();
                    }
                    return received;
                }
                // Our close_notify went out on an earlier call; now wait for the peer's (idempotent if received).
                if (!shutdownReceived) {
                    try {
                        readAndUnwrap(Optional.empty());
                        // NOTE(review): if the peer sends application data before its close_notify,
                        // readAndUnwrap returns with bytesProduced > 0 and this assertion does not hold.
                        Util.assertTrue(shutdownReceived);
                    } catch (EofException e) {
                        throw new ClosedChannelException();
                    }
                }
                freeBuffers();
                return true;
            } finally {
                writeLock.unlock();
            }
        } finally {
            readLock.unlock();
        }
    }

    /** Disposes of every internal buffer and clears its reference; calling it again has no further effect. */
    private void freeBuffers() {
        if (inEncrypted != null) {
            inEncrypted.dispose();
            inEncrypted = null;
        }
        if (inPlain != null) {
            inPlain.dispose();
            inPlain = null;
        }
        if (outEncrypted != null) {
            outEncrypted.dispose();
            outEncrypted = null;
        }
    }

    /** @return {@code true} while the channel is valid and both underlying channels are open */
    @Override
    public boolean isOpen() {
        return !invalid && writeChannel.isOpen() && readChannel.isOpen();
    }

    /**
     * Validates a destination buffer set for reading.
     *
     * @throws IllegalArgumentException if {@code dest} is read-only
     */
    public static void checkReadBuffer(ByteBufferSet dest) {
        if (dest.isReadOnly()) {
            throw new IllegalArgumentException();
        }
    }

    /** @return the underlying SSL engine */
    public SSLEngine engine() {
        return engine;
    }

    /** @return whether delegated engine tasks run on the calling thread */
    public boolean getRunTasks() {
        return runTasks;
    }

    /** Single-buffer form of {@link #read(ByteBufferSet)}. */
    @Override
    public int read(ByteBuffer dest) throws IOException {
        return (int) read(new ByteBufferSet(dest));
    }

    /** Single-buffer form of {@link #write(ByteBufferSet)}. */
    @Override
    public int write(ByteBuffer source) throws IOException {
        return (int) write(new ByteBufferSet(source));
    }

    /** @return whether the peer's close_notify has been received */
    public boolean shutdownReceived() {
        return shutdownReceived;
    }

    /** @return whether this side's close_notify has been sent (or its sending has started) */
    public boolean shutdownSent() {
        return shutdownSent;
    }

    /** @return the underlying channel that encrypted bytes are read from */
    public ReadableByteChannel plainReadableChannel() {
        return readChannel;
    }

    /** @return the underlying channel that encrypted bytes are written to */
    public WritableByteChannel plainWritableChannel() {
        return writeChannel;
    }

    /** Outcome of an unwrap sequence. */
    private static class UnwrapResult {
        public final int bytesProduced;
        public final SSLEngineResult.HandshakeStatus lastHandshakeStatus;
        public final boolean wasClosed;

        public UnwrapResult(int bytesProduced, SSLEngineResult.HandshakeStatus lastHandshakeStatus,
                boolean wasClosed) {
            this.bytesProduced = bytesProduced;
            this.lastHandshakeStatus = lastHandshakeStatus;
            this.wasClosed = wasClosed;
        }
    }

    /**
     * Signals that the underlying channel reached end-of-stream.
     * No stack trace is captured, which keeps it cheap to throw.
     */
    public static class EofException extends Exception {
        private static final long serialVersionUID = -3859156713994602991L;

        @Override
        public Throwable fillInStackTrace() {
            return this;
        }
    }

    /** Outcome of a wrap sequence. */
    private static class WrapResult {
        public final int bytesConsumed;
        public final SSLEngineResult.HandshakeStatus lastHandshakeStatus;

        public WrapResult(int bytesConsumed, SSLEngineResult.HandshakeStatus lastHandshakeStatus) {
            this.bytesConsumed = bytesConsumed;
            this.lastHandshakeStatus = lastHandshakeStatus;
        }
    }
}

