/*
 * Decompiled with CFR 0.152.
 */
package pl.skidam.automodpack_core.protocol;

import com.github.luben.zstd.Zstd;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import pl.skidam.automodpack_core.GlobalVariables;
import pl.skidam.automodpack_core.auth.Secrets;
import pl.skidam.automodpack_core.callbacks.IntCallback;
import pl.skidam.automodpack_core.protocol.NetUtils;

/**
 * A single client-side connection to a file server.
 *
 * <p>Connection setup has three steps:
 * <ol>
 *   <li>a clear-text magic-number exchange ({@code "AMMC"} sent, {@code "AMOK"} expected);</li>
 *   <li>a TLSv1.3 upgrade over the same TCP socket;</li>
 *   <li>a fingerprint pin check of the server certificate against the shared {@link Secrets.Secret}.</li>
 * </ol>
 *
 * <p>Requests run one at a time on a private single-thread executor. Every request and response
 * is framed by {@link #writeProtocolMessage} / {@link #readProtocolMessageFrame}.
 */
class Connection {
    private static final byte PROTOCOL_VERSION = 1;

    /** Clear-text handshake magic sent by the client: ASCII {@code "AMMC"}. */
    private static final int HANDSHAKE_REQUEST = 0x414D4D43;
    /** Clear-text handshake magic the server must answer with: ASCII {@code "AMOK"}. */
    private static final int HANDSHAKE_RESPONSE = 0x414D4F4B;

    /** Connect timeout and socket read timeout, in milliseconds. */
    private static final int TIMEOUT_MILLIS = 15000;

    /** Largest certificate chain accepted from the server. */
    private static final int MAX_CERT_CHAIN_LENGTH = 3;

    // Message type bytes, as used by the send and read paths below.
    private static final byte MSG_FILE_REQUEST = 1;
    private static final byte MSG_FILE_RESPONSE = 2;
    private static final byte MSG_REFRESH_REQUEST = 3;
    private static final byte MSG_END_OF_TRANSMISSION = 4;
    private static final byte MSG_ERROR = 5;

    /** Whether frames are zstd-compressed; currently always set to {@code true} by the constructor. */
    private final boolean useCompression;
    /** Decoded shared secret, sent with every request. */
    private final byte[] secretBytes;
    private final SSLSocket socket;
    private final DataInputStream in;
    private final DataOutputStream out;
    /** Serialises requests so that only one request/response exchange uses the streams at a time. */
    private final ExecutorService executor = Executors.newSingleThreadExecutor();
    private final AtomicBoolean busy = new AtomicBoolean(false);

    /** Writes the request-specific part of a payload, after version, type and secret. */
    @FunctionalInterface
    private interface PayloadBody {
        void writeTo(DataOutputStream dos) throws IOException;
    }

    /**
     * Returns whether the underlying TLS socket is still open.
     *
     * @return {@code true} until the socket has been closed
     */
    public boolean isActive() {
        return !this.socket.isClosed();
    }

    /**
     * Opens and authenticates a connection to {@code address}.
     *
     * @param address server address to connect to
     * @param secret  shared secret; its {@code secret()} is sent with every request, and its
     *                {@code fingerprint()} pins the server certificate
     * @throws IOException if connecting, the handshake or certificate validation fails; the
     *                     original cause is attached and any opened socket is closed
     */
    public Connection(InetSocketAddress address, Secrets.Secret secret) throws Exception {
        Socket plainSocket = new Socket();
        // Whichever socket currently owns the file descriptor; closed on every failure path.
        Socket openSocket = plainSocket;
        try {
            GlobalVariables.LOGGER.debug("Initializing connection to: {}", address.getHostString());
            plainSocket.connect(address, TIMEOUT_MILLIS);
            plainSocket.setSoTimeout(TIMEOUT_MILLIS);
            performPlainHandshake(plainSocket);

            SSLSocketFactory factory = createSSLContext().getSocketFactory();
            // autoClose=true: closing the TLS socket also closes the wrapped plain socket.
            SSLSocket sslSocket = (SSLSocket) factory.createSocket(plainSocket, address.getHostName(), address.getPort(), true);
            openSocket = sslSocket;
            sslSocket.setEnabledProtocols(new String[]{"TLSv1.3"});
            sslSocket.setEnabledCipherSuites(new String[]{"TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384", "TLS_CHACHA20_POLY1305_SHA256"});
            sslSocket.startHandshake();

            Certificate[] certs = sslSocket.getSession().getPeerCertificates();
            if (certs == null || certs.length == 0 || certs.length > MAX_CERT_CHAIN_LENGTH) {
                throw new IOException("Invalid server certificate chain");
            }
            if (!isPinnedChain(certs, secret)) {
                throw new IOException("Server certificate validation failed");
            }

            this.useCompression = true;
            this.secretBytes = Base64.getUrlDecoder().decode(secret.secret());
            this.socket = sslSocket;
            this.in = new DataInputStream(sslSocket.getInputStream());
            this.out = new DataOutputStream(sslSocket.getOutputStream());
            GlobalVariables.LOGGER.debug("Connection established with: {}", address.getHostString());
        } catch (Exception e) {
            closeQuietly(openSocket);
            throw new IOException("Failed to establish connection", e);
        }
    }

    /**
     * Performs the clear-text magic-number exchange that precedes the TLS upgrade.
     *
     * <p>The streams are deliberately unbuffered, so no byte belonging to the TLS handshake is
     * consumed here.
     */
    private static void performPlainHandshake(Socket plainSocket) throws IOException {
        DataOutputStream plainOut = new DataOutputStream(plainSocket.getOutputStream());
        DataInputStream plainIn = new DataInputStream(plainSocket.getInputStream());
        plainOut.writeInt(HANDSHAKE_REQUEST);
        plainOut.flush();
        int handshakeResponse = plainIn.readInt();
        if (handshakeResponse != HANDSHAKE_RESPONSE) {
            throw new IOException("Invalid handshake response from server: " + handshakeResponse);
        }
    }

    /**
     * Checks that the presented chain is anchored in the pinned certificate.
     *
     * <p>The TLS handshake only proves that the server holds the key of the leaf ({@code chain[0]}),
     * and the trust manager from {@link #createSSLContext()} verifies nothing. A fingerprint match
     * on a non-leaf certificate is therefore accepted only if every link from the leaf up to that
     * certificate carries a valid signature. Otherwise, anyone could append the (public) pinned
     * certificate to their own chain and pass validation.
     *
     * @return {@code true} if some X.509 certificate matches the pinned fingerprint and is
     *         cryptographically linked to the leaf
     */
    private static boolean isPinnedChain(Certificate[] chain, Secrets.Secret secret) throws Exception {
        for (int pinIndex = 0; pinIndex < chain.length; pinIndex++) {
            Certificate candidate = chain[pinIndex];
            if (!(candidate instanceof X509Certificate)) {
                continue;
            }
            String fingerprint = NetUtils.getFingerprint((X509Certificate) candidate, secret.secret());
            if (fingerprint.equals(secret.fingerprint()) && isSignedUpTo(chain, pinIndex)) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether {@code chain[i]} is signed by {@code chain[i + 1]} for every {@code i < anchorIndex}. */
    private static boolean isSignedUpTo(Certificate[] chain, int anchorIndex) {
        for (int i = 0; i < anchorIndex; i++) {
            try {
                chain[i].verify(chain[i + 1].getPublicKey());
            } catch (GeneralSecurityException e) {
                return false;
            }
        }
        return true;
    }

    /** Closes {@code socket}, ignoring errors; used only on paths that are already failing. */
    private static void closeQuietly(Socket socket) {
        try {
            socket.close();
        } catch (IOException ignored) {
            // The caller is already propagating the primary failure.
        }
    }

    /** Returns whether a request is currently in flight (or a failed request left the flag set). */
    public boolean isBusy() {
        return this.busy.get();
    }

    /** Sets the busy flag. */
    public void setBusy(boolean value) {
        this.busy.set(value);
    }

    /**
     * Requests a single file by hash and streams it to {@code destination}.
     *
     * @param fileHash      hash identifying the file on the server
     * @param destination   file to write; created or truncated
     * @param chunkCallback optional; called with the number of bytes written for each data frame
     * @return a future completing with {@code destination}, or failing with a
     *         {@link CompletionException} that wraps the cause
     * @throws IllegalArgumentException if {@code destination} is {@code null}
     */
    public CompletableFuture<Path> sendDownloadFile(byte[] fileHash, Path destination, IntCallback chunkCallback) {
        if (destination == null) {
            throw new IllegalArgumentException("Destination cannot be null");
        }
        return sendFileRequest(MSG_FILE_REQUEST, dos -> {
            dos.writeInt(fileHash.length);
            dos.write(fileHash);
        }, destination, chunkCallback);
    }

    /**
     * Sends a refresh request for {@code fileHashes} and writes the server's response file to
     * {@code destination}.
     *
     * <p>The wire format transmits a single hash length followed by all hashes back to back, so
     * every hash must have the same length. Otherwise the request fails.
     *
     * @param fileHashes  hashes to refresh; may be empty
     * @param destination file to write; created or truncated
     * @return a future completing with {@code destination}, or failing with a
     *         {@link CompletionException} that wraps the cause
     * @throws IllegalArgumentException if {@code destination} is {@code null}
     */
    public CompletableFuture<Path> sendRefreshRequest(byte[][] fileHashes, Path destination) {
        if (destination == null) {
            throw new IllegalArgumentException("Destination cannot be null");
        }
        return sendFileRequest(MSG_REFRESH_REQUEST, dos -> {
            dos.writeInt(fileHashes.length);
            if (fileHashes.length == 0) {
                return;
            }
            int hashLength = fileHashes[0].length;
            dos.writeInt(hashLength);
            for (byte[] hash : fileHashes) {
                if (hash.length != hashLength) {
                    throw new IllegalArgumentException("All file hashes must have the same length (expected " + hashLength + ", got " + hash.length + ")");
                }
                dos.write(hash);
            }
        }, destination, null);
    }

    /**
     * Runs one request/response exchange on the connection's executor.
     *
     * <p>Any failure is rethrown as a {@link CompletionException}. {@link #finalBlock} then drains
     * leftover input and decides whether to clear the busy flag.
     */
    private CompletableFuture<Path> sendFileRequest(byte messageType, PayloadBody body, Path destination, IntCallback chunkCallback) {
        return CompletableFuture.supplyAsync(() -> {
            Exception exception = null;
            try {
                this.writeProtocolMessage(buildRequest(messageType, body));
                return this.readFileResponse(destination, chunkCallback);
            } catch (Exception e) {
                exception = e;
                throw new CompletionException(e);
            } finally {
                this.finalBlock(exception);
            }
        }, this.executor);
    }

    /** Serialises a request: version, type, secret, then the request-specific body. */
    private byte[] buildRequest(byte messageType, PayloadBody body) throws IOException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        DataOutputStream dos = new DataOutputStream(baos);
        dos.writeByte(PROTOCOL_VERSION);
        dos.writeByte(messageType);
        dos.write(this.secretBytes);
        body.writeTo(dos);
        dos.flush();
        return baos.toByteArray();
    }

    /**
     * Post-request cleanup: discards any bytes currently available on the input stream. It clears
     * the busy flag only if both the request and the drain succeeded.
     *
     * <p>After a failure the flag stays set. Presumably this keeps a connection whose stream
     * state is unknown from being reused; confirm against the connection pool.
     *
     * @param exception the request's failure, or {@code null} if it succeeded
     * @throws CompletionException if the request succeeded but draining failed
     */
    private void finalBlock(Exception exception) {
        try {
            while (this.in.available() > 0) {
                this.in.skipBytes(this.in.available());
            }
        } catch (IOException e) {
            if (exception == null) {
                // Report the drain failure; the busy flag stays set.
                throw new CompletionException(e);
            }
            // The request already failed; keep its exception as the primary one.
            return;
        }
        if (exception == null) {
            this.setBusy(false);
        }
    }

    /**
     * Writes one frame.
     *
     * <p>Compressed frames are {@code [int compressedLength][int originalLength][zstd bytes]}.
     * Uncompressed frames are {@code [int length][bytes]}.
     */
    private void writeProtocolMessage(byte[] payload) throws IOException {
        if (this.useCompression) {
            byte[] compressed = Zstd.compress(payload);
            this.out.writeInt(compressed.length);
            this.out.writeInt(payload.length);
            this.out.write(compressed);
        } else {
            this.out.writeInt(payload.length);
            this.out.write(payload);
        }
        this.out.flush();
    }

    /**
     * Reads one frame in the format written by {@link #writeProtocolMessage}.
     *
     * @throws IOException on EOF, on negative lengths supplied by the peer, or if the
     *                     decompressed size differs from the announced size
     */
    private byte[] readProtocolMessageFrame() throws IOException {
        if (!this.useCompression) {
            int length = this.in.readInt();
            if (length < 0) {
                throw new IOException("Invalid frame length: " + length);
            }
            byte[] data = new byte[length];
            this.in.readFully(data);
            return data;
        }
        int compLength = this.in.readInt();
        int origLength = this.in.readInt();
        if (compLength < 0 || origLength < 0) {
            throw new IOException("Invalid frame lengths: compressed=" + compLength + ", original=" + origLength);
        }
        byte[] compData = new byte[compLength];
        this.in.readFully(compData);
        byte[] decompressed = Zstd.decompress(compData, origLength);
        if (decompressed.length != origLength) {
            throw new IOException("Decompressed length does not match original length");
        }
        return decompressed;
    }

    /**
     * Reads a file response and writes its body to {@code destination}.
     *
     * <p>The response starts with a header frame {@code [version][type]...}:
     * <ul>
     *   <li>type 5 (error): followed by {@code [int len][message bytes]}; an {@link IOException} is thrown.</li>
     *   <li>type 4 (end of transmission): no body; {@code destination} becomes an empty file.</li>
     *   <li>type 2 (file): followed by {@code [long size]}. Data frames then follow until
     *       {@code size} bytes have been received, then an end-of-transmission frame carrying the
     *       same version.</li>
     * </ul>
     * Any other type fails without touching {@code destination}.
     */
    private Path readFileResponse(Path destination, IntCallback chunkCallback) throws IOException {
        DataInputStream headerIn = new DataInputStream(new ByteArrayInputStream(this.readProtocolMessageFrame()));
        byte version = headerIn.readByte();
        byte messageType = headerIn.readByte();

        if (messageType == MSG_ERROR) {
            throw new IOException("Server error: " + readErrorMessage(headerIn));
        }
        if (messageType == MSG_END_OF_TRANSMISSION) {
            // No body follows: create (or truncate) the destination as an empty file.
            new FileOutputStream(destination.toFile()).close();
            return destination;
        }
        if (messageType != MSG_FILE_RESPONSE) {
            throw new IOException("Unexpected message type: " + messageType);
        }

        long expectedFileSize = headerIn.readLong();
        this.receiveFileBody(destination, expectedFileSize, chunkCallback);
        this.expectEndOfTransmission(version);
        return destination;
    }

    /** Decodes the {@code [int len][bytes]} error text from an error header (UTF-8). */
    private static String readErrorMessage(DataInputStream headerIn) throws IOException {
        int errLen = headerIn.readInt();
        if (errLen < 0 || errLen > headerIn.available()) {
            throw new IOException("Malformed server error message (length " + errLen + ")");
        }
        byte[] errBytes = new byte[errLen];
        headerIn.readFully(errBytes);
        return new String(errBytes, StandardCharsets.UTF_8);
    }

    /**
     * Streams data frames into {@code destination} until {@code expectedFileSize} bytes are written.
     * Bytes beyond the expected size in the final frame are ignored. The file is always closed,
     * even if a frame read or the callback throws.
     */
    private void receiveFileBody(Path destination, long expectedFileSize, IntCallback chunkCallback) throws IOException {
        try (OutputStream fos = new FileOutputStream(destination.toFile())) {
            long receivedBytes = 0L;
            while (receivedBytes < expectedFileSize) {
                byte[] dataFrame = this.readProtocolMessageFrame();
                // Compute the minimum in long arithmetic; the result is <= dataFrame.length, so the cast is safe.
                int toWrite = (int) Math.min((long) dataFrame.length, expectedFileSize - receivedBytes);
                fos.write(dataFrame, 0, toWrite);
                receivedBytes += toWrite;
                if (chunkCallback != null) {
                    chunkCallback.run(toWrite);
                }
            }
        }
    }

    /** Reads the trailing frame and checks that it is an end-of-transmission frame with {@code version}. */
    private void expectEndOfTransmission(byte version) throws IOException {
        DataInputStream eotIn = new DataInputStream(new ByteArrayInputStream(this.readProtocolMessageFrame()));
        byte ver = eotIn.readByte();
        byte eotType = eotIn.readByte();
        if (ver != version || eotType != MSG_END_OF_TRANSMISSION) {
            throw new IOException("Invalid end-of-transmission marker. Expected version " + version + " and type 4, got version " + ver + " and type " + eotType);
        }
    }

    /** Closes the socket (best effort) and stops the request executor, interrupting any running request. */
    public void close() {
        try {
            this.socket.close();
        } catch (Exception ignored) {
            // Best effort: the connection is being discarded regardless.
        } finally {
            this.executor.shutdownNow();
        }
    }

    /**
     * Builds a TLSv1.3 context whose trust manager accepts any certificate.
     *
     * <p>This is safe only because the constructor authenticates the server afterwards by
     * fingerprint pinning ({@link #isPinnedChain}). Do not use this context anywhere else.
     */
    private static SSLContext createSSLContext() throws Exception {
        SSLContext sslContext = SSLContext.getInstance("TLSv1.3");
        TrustManager[] trustAllCerts = new TrustManager[]{new X509TrustManager() {

            @Override
            public X509Certificate[] getAcceptedIssuers() {
                return new X509Certificate[0];
            }

            @Override
            public void checkClientTrusted(X509Certificate[] certs, String authType) {
            }

            @Override
            public void checkServerTrusted(X509Certificate[] certs, String authType) {
                // Intentionally empty: authentication is done by fingerprint pinning after the handshake.
            }
        }};
        sslContext.init(null, trustAllCerts, new SecureRandom());
        return sslContext;
    }
}

