/*
 * Decompiled with CFR 0.152.
 */
package com.mongodb.client.internal;

import com.mongodb.ServerAddress;
import com.mongodb.assertions.Assertions;
import com.mongodb.internal.TimeoutContext;
import com.mongodb.internal.connection.SslHelper;
import com.mongodb.internal.diagnostics.logging.Logger;
import com.mongodb.internal.diagnostics.logging.Loggers;
import com.mongodb.internal.time.Timeout;
import com.mongodb.lang.NonNull;
import com.mongodb.lang.Nullable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

/**
 * Opens TLS connections to KMS servers.
 *
 * <p>{@link #stream} connects to a KMS host, writes a request message, and returns an {@link InputStream}
 * over the same socket so the caller can read the reply.
 */
class KeyManagementService {
    private static final Logger LOGGER = Loggers.getLogger("client");

    /** SSL contexts keyed by KMS provider name; providers without an entry use {@link SSLSocketFactory#getDefault()}. */
    private final Map<String, SSLContext> kmsProviderSslContextMap;

    /** Connect timeout and initial socket read timeout ({@code SO_TIMEOUT}), in milliseconds. */
    private final int timeoutMillis;

    /**
     * @param kmsProviderSslContextMap SSL contexts keyed by KMS provider name; must not be null
     * @param timeoutMillis connect and read timeout in milliseconds (as with {@link Socket}, 0 means no timeout)
     */
    KeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap, final int timeoutMillis) {
        this.kmsProviderSslContextMap = Assertions.notNull("kmsProviderSslContextMap", kmsProviderSslContextMap);
        this.timeoutMillis = timeoutMillis;
    }

    /**
     * Connects to a KMS server over TLS, sends {@code message}, and returns a stream for reading the reply.
     *
     * <p>Host name verification is enabled on the socket before connecting. If anything fails after the socket
     * has been created (checked or unchecked), the socket is closed before the exception is rethrown.
     *
     * @param kmsProvider selects the SSL context from the map supplied at construction
     * @param host the server address, parsed with {@link ServerAddress#ServerAddress(String)}
     * @param message the request; all remaining bytes are written and the buffer's position advances to its limit
     * @param operationTimeout optional operation timeout. If {@code null} or infinite, the socket's own input stream
     *                         is returned. Otherwise the returned stream re-applies the remaining time as the socket
     *                         read timeout before each read.
     * @return a stream over the socket's input; closing it closes the socket
     * @throws IOException if the host cannot be resolved, or connecting, writing, or obtaining the input stream fails
     */
    public InputStream stream(final String kmsProvider, final String host, final ByteBuffer message,
            @Nullable final Timeout operationTimeout) throws IOException {
        ServerAddress serverAddress = new ServerAddress(host);
        LOGGER.info("Connecting to KMS server at " + serverAddress);

        SSLContext sslContext = kmsProviderSslContextMap.get(kmsProvider);
        SocketFactory socketFactory = sslContext == null
                ? SSLSocketFactory.getDefault() : sslContext.getSocketFactory();
        SSLSocket socket = (SSLSocket) socketFactory.createSocket();
        // One guarded region for everything after creation, so that neither an IOException nor an unchecked
        // exception (e.g. from setSSLParameters) can leak an open socket.
        try {
            enableHostNameVerification(socket);
            socket.setSoTimeout(timeoutMillis);
            socket.connect(new InetSocketAddress(InetAddress.getByName(serverAddress.getHost()),
                    serverAddress.getPort()), timeoutMillis);

            OutputStream outputStream = socket.getOutputStream();
            byte[] bytes = new byte[message.remaining()];
            message.get(bytes);
            outputStream.write(bytes);

            return OperationTimeoutAwareInputStream.wrapIfNeeded(operationTimeout, socket);
        } catch (IOException | RuntimeException e) {
            closeSocket(socket, e);
            throw e;
        }
    }

    /**
     * Applies {@link SslHelper#enableHostNameVerification} to the socket's current SSL parameters, or to fresh
     * parameters if the socket reports none, and installs the result on the socket.
     */
    private void enableHostNameVerification(final SSLSocket socket) {
        SSLParameters sslParameters = socket.getSSLParameters();
        if (sslParameters == null) {
            sslParameters = new SSLParameters();
        }
        SslHelper.enableHostNameVerification(sslParameters);
        socket.setSSLParameters(sslParameters);
    }

    /**
     * Closes {@code socket} on a failure path without ever throwing.
     *
     * <p>A failure to close is recorded as a suppressed exception on {@code cause}, which is the exception
     * the caller is about to rethrow.
     */
    private static void closeSocket(final Socket socket, final Exception cause) {
        try {
            socket.close();
        } catch (IOException | RuntimeException closeFailure) {
            cause.addSuppressed(closeFailure);
        }
    }

    /**
     * Input stream that bounds each blocking operation by the remaining operation timeout.
     *
     * <p>Before every blocking read or skip, it sets the socket read timeout to the time remaining in the
     * operation timeout. Once that time has run out, it fails via {@link TimeoutContext#throwMongoTimeoutException}.
     */
    private static final class OperationTimeoutAwareInputStream extends InputStream {
        private final Socket socket;
        private final Timeout operationTimeout;
        private final InputStream wrapped;

        private OperationTimeoutAwareInputStream(final Socket socket, final Timeout operationTimeout)
                throws IOException {
            this.socket = socket;
            this.operationTimeout = operationTimeout;
            this.wrapped = socket.getInputStream();
        }

        /**
         * Returns the socket's own input stream when {@code operationTimeout} is {@code null} or infinite, and a
         * timeout-aware wrapper otherwise.
         *
         * <p>The wrapper is also returned when the timeout has already expired; in that case its first
         * read fails.
         */
        public static InputStream wrapIfNeeded(@Nullable final Timeout operationTimeout, final SSLSocket socket)
                throws IOException {
            return Timeout.nullAsInfinite(operationTimeout).checkedCall(TimeUnit.NANOSECONDS,
                    () -> socket.getInputStream(),
                    remainingNanos -> new OperationTimeoutAwareInputStream(socket,
                            Assertions.assertNotNull(operationTimeout)),
                    () -> new OperationTimeoutAwareInputStream(socket, Assertions.assertNotNull(operationTimeout)));
        }

        private void setSocketSoTimeoutToOperationTimeout() throws SocketException {
            operationTimeout.checkedRun(TimeUnit.MILLISECONDS,
                    () -> {
                        throw new AssertionError("operationTimeout cannot be infinite");
                    },
                    // setSoTimeout takes an int: clamp instead of failing with ArithmeticException for remaining
                    // times beyond Integer.MAX_VALUE ms (~24.8 days). The timeout is re-applied before every read,
                    // so clamping loses nothing.
                    remainingMillis -> socket.setSoTimeout((int) Math.min(remainingMillis, Integer.MAX_VALUE)),
                    () -> TimeoutContext.throwMongoTimeoutException(
                            "Reading from KMS server exceeded the timeout limit."));
        }

        @Override
        public int read() throws IOException {
            setSocketSoTimeoutToOperationTimeout();
            return wrapped.read();
        }

        @Override
        public int read(@NonNull final byte[] buffer) throws IOException {
            setSocketSoTimeoutToOperationTimeout();
            return wrapped.read(buffer);
        }

        @Override
        public int read(@NonNull final byte[] buffer, final int offset, final int length) throws IOException {
            setSocketSoTimeoutToOperationTimeout();
            return wrapped.read(buffer, offset, length);
        }

        @Override
        public void close() throws IOException {
            wrapped.close();
        }

        @Override
        public long skip(final long n) throws IOException {
            setSocketSoTimeoutToOperationTimeout();
            return wrapped.skip(n);
        }

        // available() does not block, so it is not bounded by the operation timeout.
        @Override
        public int available() throws IOException {
            return wrapped.available();
        }

        @Override
        public synchronized void mark(final int readLimit) {
            wrapped.mark(readLimit);
        }

        @Override
        public synchronized void reset() throws IOException {
            wrapped.reset();
        }

        @Override
        public boolean markSupported() {
            return wrapped.markSupported();
        }
    }
}

