/*
 * Decompiled with CFR 0.152.
 */
package com.mongodb.internal.connection;

import com.mongodb.AuthenticationMechanism;
import com.mongodb.MongoCredential;
import com.mongodb.ServerAddress;
import com.mongodb.ServerApi;
import com.mongodb.assertions.Assertions;
import com.mongodb.connection.ClusterConnectionMode;
import com.mongodb.internal.authentication.NativeAuthenticationHelper;
import com.mongodb.internal.authentication.SaslPrep;
import com.mongodb.internal.connection.InternalConnection;
import com.mongodb.internal.connection.MongoCredentialWithCache;
import com.mongodb.internal.connection.SaslAuthenticator;
import com.mongodb.lang.Nullable;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.spec.InvalidKeySpecException;
import java.util.Base64;
import java.util.HashMap;
import javax.crypto.Mac;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import org.bson.BsonBoolean;
import org.bson.BsonDocument;
import org.bson.BsonString;

/**
 * Authenticator for the SCRAM-SHA-1 and SCRAM-SHA-256 SASL mechanisms.
 *
 * <p>The mechanism is taken from the credential. SCRAM-SHA-1 uses SHA-1 primitives and, as the password input,
 * the hash produced by {@code NativeAuthenticationHelper.createAuthenticationHash(userName, password)}. Every
 * other mechanism uses SHA-256 primitives and the raw password (SASLprep'd when the mechanism is SCRAM-SHA-256).</p>
 *
 * <p>Speculative authentication is supported: {@link #createSpeculativeAuthenticateCommand} creates the SASL
 * client and its first message up front, and {@link #createSaslClient} returns that same client so the
 * conversation resumes where the speculative exchange left off. Passing {@code null} to
 * {@link #setSpeculativeAuthenticateResponse} discards the speculative client.</p>
 */
class ScramShaAuthenticator extends SaslAuthenticator {
    private final RandomStringGenerator randomStringGenerator;
    private final AuthenticationHashGenerator authenticationHashGenerator;
    /** Client created for the speculative first message; {@code null} when none is pending. May be null. */
    private SaslClient speculativeSaslClient;
    /** The server's reply to the speculative authenticate command; {@code null} until one is recorded. */
    private BsonDocument speculativeAuthenticateResponse;

    /** Smallest iteration count accepted from the server; lower values are rejected. */
    private static final int MINIMUM_ITERATION_COUNT = 4096;
    /** GS2 header sent by this client: "n" (no channel binding) and an empty authorization identity. */
    private static final String GS2_HEADER = "n,,";
    /** Number of characters in the client nonce. */
    private static final int RANDOM_LENGTH = 24;

    /** Password input for every mechanism other than SCRAM-SHA-1: the raw password. */
    private static final AuthenticationHashGenerator DEFAULT_AUTHENTICATION_HASH_GENERATOR = credential -> {
        char[] password = credential.getPassword();
        if (password == null) {
            throw new IllegalArgumentException("Password must not be null");
        }
        return new String(password);
    };

    /** Password input for SCRAM-SHA-1: the hash of user name and password computed by NativeAuthenticationHelper. */
    private static final AuthenticationHashGenerator LEGACY_AUTHENTICATION_HASH_GENERATOR = credential -> {
        String userName = credential.getUserName();
        char[] password = credential.getPassword();
        if (userName == null || password == null) {
            throw new IllegalArgumentException("Username and password must not be null");
        }
        return NativeAuthenticationHelper.createAuthenticationHash(userName, password);
    };

    /**
     * Creates an authenticator using a secure random nonce generator and the password-hash strategy matching the
     * credential's mechanism.
     *
     * @param credential            the credential (with key cache) to authenticate; its mechanism must be set
     * @param clusterConnectionMode the cluster connection mode, passed through to {@link SaslAuthenticator}
     * @param serverApi             the server API, passed through to {@link SaslAuthenticator}; may be null
     */
    ScramShaAuthenticator(final MongoCredentialWithCache credential, final ClusterConnectionMode clusterConnectionMode,
                          @Nullable final ServerApi serverApi) {
        this(credential, new DefaultRandomStringGenerator(),
                getAuthenticationHashGenerator(Assertions.assertNotNull(credential.getAuthenticationMechanism())),
                clusterConnectionMode, serverApi);
    }

    /**
     * Creates an authenticator with explicit nonce and password-hash strategies (useful for deterministic tests).
     *
     * @param credential                  the credential (with key cache) to authenticate
     * @param randomStringGenerator       source of client nonces
     * @param authenticationHashGenerator derives the password input for key derivation from the credential
     * @param clusterConnectionMode       the cluster connection mode, passed through to {@link SaslAuthenticator}
     * @param serverApi                   the server API, passed through to {@link SaslAuthenticator}; may be null
     */
    ScramShaAuthenticator(final MongoCredentialWithCache credential, final RandomStringGenerator randomStringGenerator,
                          final AuthenticationHashGenerator authenticationHashGenerator,
                          final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi) {
        super(credential, clusterConnectionMode, serverApi);
        this.randomStringGenerator = randomStringGenerator;
        this.authenticationHashGenerator = authenticationHashGenerator;
    }

    /**
     * Returns the SASL mechanism name of the credential's authentication mechanism.
     *
     * @throws IllegalArgumentException if the credential has no authentication mechanism
     */
    @Override
    public String getMechanismName() {
        AuthenticationMechanism mechanism = getMongoCredential().getAuthenticationMechanism();
        if (mechanism == null) {
            throw new IllegalArgumentException("Authentication mechanism cannot be null");
        }
        return mechanism.getMechanismName();
    }

    /** Adds {@code options: {skipEmptyExchange: true}} to a saslStart command document. */
    @Override
    protected void appendSaslStartOptions(final BsonDocument saslStartCommand) {
        saslStartCommand.append("options", new BsonDocument("skipEmptyExchange", new BsonBoolean(true)));
    }

    /**
     * Returns the pending speculative client if there is one, otherwise a new SCRAM client for the credential.
     * The server address is not used.
     */
    @Override
    protected SaslClient createSaslClient(final ServerAddress serverAddress) {
        if (speculativeSaslClient != null) {
            return speculativeSaslClient;
        }
        return new ScramShaSaslClient(getMongoCredentialWithCache().getCredential(), randomStringGenerator,
                authenticationHashGenerator);
    }

    /**
     * Builds the saslStart command used for speculative authentication: the client-first message, the
     * credential's source database under {@code db}, and the saslStart options.
     * Any failure is converted by {@code wrapException}.
     */
    @Override
    public BsonDocument createSpeculativeAuthenticateCommand(final InternalConnection connection) {
        try {
            // NOTE(review): if a speculative client already exists, createSaslClient returns it rather than a
            // fresh one, so this method is only meaningful once per authenticator.
            speculativeSaslClient = createSaslClient(connection.getDescription().getServerAddress());
            byte[] clientFirstMessage = speculativeSaslClient.evaluateChallenge(new byte[0]);
            BsonDocument saslStartCommand = createSaslStartCommandDocument(clientFirstMessage)
                    .append("db", new BsonString(getMongoCredential().getSource()));
            appendSaslStartOptions(saslStartCommand);
            return saslStartCommand;
        } catch (Exception e) {
            throw wrapException(e);
        }
    }

    /** Returns the recorded speculative authenticate response, or {@code null} if none has been recorded. */
    @Override
    public BsonDocument getSpeculativeAuthenticateResponse() {
        return speculativeAuthenticateResponse;
    }

    /**
     * Records the server's speculative authenticate response. A {@code null} response means the server did not
     * accept speculative authentication, so the pending speculative client is discarded instead.
     */
    @Override
    public void setSpeculativeAuthenticateResponse(@Nullable final BsonDocument response) {
        if (response == null) {
            speculativeSaslClient = null;
        } else {
            speculativeAuthenticateResponse = response;
        }
    }

    /** Selects the legacy password hash for SCRAM-SHA-1 and the raw password for every other mechanism. */
    private static AuthenticationHashGenerator getAuthenticationHashGenerator(final AuthenticationMechanism mechanism) {
        return mechanism == AuthenticationMechanism.SCRAM_SHA_1
                ? LEGACY_AUTHENTICATION_HASH_GENERATOR
                : DEFAULT_AUTHENTICATION_HASH_GENERATOR;
    }

    /** Generates nonces from printable ASCII ('!' through '~') excluding ',' (the SCRAM attribute separator). */
    private static class DefaultRandomStringGenerator implements RandomStringGenerator {
        private static final int LOWEST_CHAR = '!';   // 0x21
        private static final int HIGHEST_CHAR = '~';  // 0x7E, inclusive
        private static final int COMMA = ',';         // 0x2C, never allowed inside an attribute value

        private final SecureRandom random = new SecureRandom();

        private DefaultRandomStringGenerator() {
        }

        @Override
        public String generate(final int length) {
            char[] text = new char[length];
            for (int i = 0; i < length; i++) {
                int next;
                do {
                    // +1 makes HIGHEST_CHAR reachable; nextInt's bound is exclusive.
                    next = LOWEST_CHAR + random.nextInt(HIGHEST_CHAR - LOWEST_CHAR + 1);
                } while (next == COMMA);
                text[i] = (char) next;
            }
            return new String(text);
        }
    }

    /** Derives the password input used for SCRAM key derivation from a credential. */
    @FunctionalInterface
    public static interface AuthenticationHashGenerator {
        String generate(MongoCredential credential);
    }

    /** Produces a random string of the requested length, used as the client nonce. */
    @FunctionalInterface
    public static interface RandomStringGenerator {
        String generate(int length);
    }

    /**
     * Client side of one SCRAM conversation. {@link #evaluateChallenge} is called three times: step 0 produces
     * the client-first message, step 1 consumes the server-first message and produces the client-final message,
     * and step 2 verifies the server-final message. Any further call fails.
     */
    class ScramShaSaslClient extends SaslAuthenticator.SaslClientImpl {
        private final RandomStringGenerator randomStringGenerator;
        private final AuthenticationHashGenerator authenticationHashGenerator;
        private final String hAlgorithm;
        private final String hmacAlgorithm;
        private final String pbeAlgorithm;
        private final int keyLength;
        private String clientFirstMessageBare;
        private String clientNonce;
        /** Expected server signature, computed while building the client-final message. */
        private byte[] serverSignature;
        /** Index of the last evaluated step; -1 before the first call. */
        private int step = -1;

        ScramShaSaslClient(final MongoCredential credential, final RandomStringGenerator randomStringGenerator,
                           final AuthenticationHashGenerator authenticationHashGenerator) {
            super(credential);
            this.randomStringGenerator = randomStringGenerator;
            this.authenticationHashGenerator = authenticationHashGenerator;
            if (Assertions.assertNotNull(credential.getAuthenticationMechanism()) == AuthenticationMechanism.SCRAM_SHA_1) {
                hAlgorithm = "SHA-1";
                hmacAlgorithm = "HmacSHA1";
                pbeAlgorithm = "PBKDF2WithHmacSHA1";
                keyLength = 160;
            } else {
                hAlgorithm = "SHA-256";
                hmacAlgorithm = "HmacSHA256";
                pbeAlgorithm = "PBKDF2WithHmacSHA256";
                keyLength = 256;
            }
        }

        @Override
        public byte[] evaluateChallenge(final byte[] challenge) throws SaslException {
            step++;
            switch (step) {
                case 0:
                    return computeClientFirstMessage();
                case 1:
                    return computeClientFinalMessage(challenge);
                case 2:
                    return validateServerSignature(challenge);
                default:
                    throw new SaslException(String.format("Too many steps involved in the %s negotiation.",
                            super.getMechanismName()));
            }
        }

        /** Verifies the server-final message; a server-reported error ({@code e=}) is surfaced as a failure. */
        private byte[] validateServerSignature(final byte[] challenge) throws SaslException {
            HashMap<String, String> attributes = parseServerResponse(new String(challenge, StandardCharsets.UTF_8));
            String serverError = attributes.get("e");
            if (serverError != null) {
                throw new SaslException("Server reported an error: " + serverError);
            }
            byte[] receivedSignature = base64Decode(requireAttribute(attributes, "v"));
            // Constant-time comparison to avoid leaking how many leading bytes matched.
            if (!MessageDigest.isEqual(receivedSignature, serverSignature)) {
                throw new SaslException("Server signature was invalid.");
            }
            return new byte[0];
        }

        @Override
        public boolean isComplete() {
            return step == 2;
        }

        /** Builds the client-first message: GS2 header + {@code n=<escaped user>,r=<client nonce>}. */
        private byte[] computeClientFirstMessage() {
            clientNonce = randomStringGenerator.generate(ScramShaAuthenticator.RANDOM_LENGTH);
            clientFirstMessageBare = "n=" + getUserName() + ",r=" + clientNonce;
            return (ScramShaAuthenticator.GS2_HEADER + clientFirstMessageBare).getBytes(StandardCharsets.UTF_8);
        }

        /**
         * Validates the server-first message (nonce prefix, iteration count) and builds the client-final message
         * {@code c=<base64 GS2 header>,r=<combined nonce>,p=<client proof>}.
         */
        private byte[] computeClientFinalMessage(final byte[] challenge) throws SaslException {
            String serverFirstMessage = new String(challenge, StandardCharsets.UTF_8);
            HashMap<String, String> attributes = parseServerResponse(serverFirstMessage);

            String combinedNonce = requireAttribute(attributes, "r");
            if (!combinedNonce.startsWith(clientNonce)) {
                throw new SaslException("Server sent an invalid nonce.");
            }
            String salt = requireAttribute(attributes, "s");
            int iterationCount;
            try {
                iterationCount = Integer.parseInt(requireAttribute(attributes, "i"));
            } catch (NumberFormatException e) {
                throw new SaslException("Invalid iteration count.", e);
            }
            if (iterationCount < ScramShaAuthenticator.MINIMUM_ITERATION_COUNT) {
                throw new SaslException("Invalid iteration count.");
            }

            String channelBinding = Base64.getEncoder().encodeToString(
                    ScramShaAuthenticator.GS2_HEADER.getBytes(StandardCharsets.UTF_8));
            String clientFinalMessageWithoutProof = "c=" + channelBinding + ",r=" + combinedNonce;
            String authMessage = clientFirstMessageBare + "," + serverFirstMessage + "," + clientFinalMessageWithoutProof;
            String clientFinalMessage = clientFinalMessageWithoutProof + ",p="
                    + getClientProof(getAuthenticationHash(), salt, iterationCount, authMessage);
            return clientFinalMessage.getBytes(StandardCharsets.UTF_8);
        }

        /**
         * Computes the Base64 client proof and, as a side effect, stores the expected server signature.
         * Client/server keys are cached on the credential, keyed by a hash of password and salt plus the salt and
         * iteration count, so the expensive key derivation runs only once per distinct combination.
         *
         * @param password       the password input (already hashed/SASLprep'd as required by the mechanism)
         * @param salt           the Base64 salt sent by the server
         * @param iterationCount the key-derivation iteration count sent by the server
         * @param authMessage    the SCRAM AuthMessage
         * @return the Base64-encoded client proof
         * @throws SaslException if a crypto primitive is unavailable or the salt is not valid Base64
         */
        String getClientProof(final String password, final String salt, final int iterationCount,
                              final String authMessage) throws SaslException {
            // Base64 keeps every digest byte; decoding the raw digest as UTF-8 would replace malformed sequences
            // with U+FFFD and let distinct passwords share a cache key.
            String hashedPasswordAndSalt = Base64.getEncoder().encodeToString(
                    h((password + salt).getBytes(StandardCharsets.UTF_8)));
            CacheKey cacheKey = new CacheKey(hashedPasswordAndSalt, salt, iterationCount);
            CacheValue cachedKeys = ScramShaAuthenticator.this.getMongoCredentialWithCache()
                    .getFromCache(cacheKey, CacheValue.class);
            if (cachedKeys == null) {
                byte[] saltedPassword = hi(password, base64Decode(salt), iterationCount);
                byte[] clientKey = hmac(saltedPassword, "Client Key");
                byte[] serverKey = hmac(saltedPassword, "Server Key");
                cachedKeys = new CacheValue(clientKey, serverKey);
                ScramShaAuthenticator.this.getMongoCredentialWithCache().putInCache(cacheKey, cachedKeys);
            }
            serverSignature = hmac(cachedKeys.serverKey, authMessage);

            byte[] storedKey = h(cachedKeys.clientKey);
            byte[] clientSignature = hmac(storedKey, authMessage);
            return Base64.getEncoder().encodeToString(xor(cachedKeys.clientKey, clientSignature));
        }

        /** H(): message digest with the mechanism's hash algorithm. */
        private byte[] h(final byte[] data) throws SaslException {
            try {
                return MessageDigest.getInstance(hAlgorithm).digest(data);
            } catch (NoSuchAlgorithmException e) {
                throw new SaslException(String.format("Algorithm for '%s' could not be found.", hAlgorithm), e);
            }
        }

        /** Hi(): PBKDF2 key derivation with the mechanism's HMAC; the spec's password copy is cleared afterwards. */
        private byte[] hi(final String password, final byte[] salt, final int iterationCount) throws SaslException {
            PBEKeySpec spec = null;
            try {
                SecretKeyFactory factory = SecretKeyFactory.getInstance(pbeAlgorithm);
                spec = new PBEKeySpec(password.toCharArray(), salt, iterationCount, keyLength);
                return factory.generateSecret(spec).getEncoded();
            } catch (NoSuchAlgorithmException e) {
                throw new SaslException(String.format("Algorithm for '%s' could not be found.", pbeAlgorithm), e);
            } catch (InvalidKeySpecException e) {
                throw new SaslException(String.format("Invalid key specification for '%s'", pbeAlgorithm), e);
            } finally {
                if (spec != null) {
                    spec.clearPassword();
                }
            }
        }

        /** HMAC(key, text) with the mechanism's HMAC algorithm; {@code text} is encoded as UTF-8. */
        private byte[] hmac(final byte[] key, final String text) throws SaslException {
            try {
                Mac mac = Mac.getInstance(hmacAlgorithm);
                mac.init(new SecretKeySpec(key, hmacAlgorithm));
                return mac.doFinal(text.getBytes(StandardCharsets.UTF_8));
            } catch (NoSuchAlgorithmException e) {
                throw new SaslException(String.format("Algorithm for '%s' could not be found.", hmacAlgorithm), e);
            } catch (InvalidKeyException e) {
                throw new SaslException("Could not initialize mac.", e);
            }
        }

        /**
         * Splits a server message into {@code name=value} attributes. Values keep any further '=' characters.
         *
         * @throws SaslException if an attribute has no '=' separator
         */
        private HashMap<String, String> parseServerResponse(final String response) throws SaslException {
            HashMap<String, String> attributes = new HashMap<>();
            for (String attribute : response.split(",")) {
                int separator = attribute.indexOf('=');
                if (separator < 0) {
                    throw new SaslException("Malformed attribute in server response.");
                }
                attributes.put(attribute.substring(0, separator), attribute.substring(separator + 1));
            }
            return attributes;
        }

        /** Returns the named attribute, failing with a SaslException when the server omitted it. */
        private String requireAttribute(final HashMap<String, String> attributes, final String name)
                throws SaslException {
            String value = attributes.get(name);
            if (value == null) {
                throw new SaslException(String.format("Server response is missing the '%s' attribute.", name));
            }
            return value;
        }

        /** Base64-decodes a server-supplied value, failing with a SaslException on malformed input. */
        private byte[] base64Decode(final String value) throws SaslException {
            try {
                return Base64.getDecoder().decode(value);
            } catch (IllegalArgumentException e) {
                throw new SaslException("Server response contains invalid Base64 data.", e);
            }
        }

        /** Returns the user name with '=' and ',' escaped as "=3D" and "=2C" (order matters: '=' first). */
        private String getUserName() {
            String userName = getCredential().getUserName();
            if (userName == null) {
                throw new IllegalArgumentException("Username can not be null");
            }
            return userName.replace("=", "=3D").replace(",", "=2C");
        }

        /** Returns the password input for key derivation, SASLprep'd for SCRAM-SHA-256. */
        private String getAuthenticationHash() {
            String hash = authenticationHashGenerator.generate(getCredential());
            if (getCredential().getAuthenticationMechanism() == AuthenticationMechanism.SCRAM_SHA_256) {
                hash = SaslPrep.saslPrepStored(hash);
            }
            return hash;
        }

        /** Returns a new array holding {@code a[i] ^ b[i]}; {@code b} must be at least as long as {@code a}. */
        private byte[] xor(final byte[] a, final byte[] b) {
            byte[] result = new byte[a.length];
            for (int i = 0; i < a.length; i++) {
                result[i] = (byte) (a[i] ^ b[i]);
            }
            return result;
        }
    }

    /** Cached SCRAM client and server keys derived from one password/salt/iteration-count combination. */
    private static final class CacheValue {
        private final byte[] clientKey;
        private final byte[] serverKey;

        CacheValue(final byte[] clientKey, final byte[] serverKey) {
            this.clientKey = clientKey;
            this.serverKey = serverKey;
        }
    }

    /** Cache key identifying a password/salt/iteration-count combination without storing the password itself. */
    private static final class CacheKey {
        private final String hashedPasswordAndSalt;
        private final String salt;
        private final int iterationCount;

        CacheKey(final String hashedPasswordAndSalt, final String salt, final int iterationCount) {
            this.hashedPasswordAndSalt = hashedPasswordAndSalt;
            this.salt = salt;
            this.iterationCount = iterationCount;
        }

        @Override
        public boolean equals(final Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            CacheKey other = (CacheKey) o;
            return iterationCount == other.iterationCount
                    && hashedPasswordAndSalt.equals(other.hashedPasswordAndSalt)
                    && salt.equals(other.salt);
        }

        @Override
        public int hashCode() {
            int result = hashedPasswordAndSalt.hashCode();
            result = 31 * result + salt.hashCode();
            result = 31 * result + iterationCount;
            return result;
        }
    }
}

