/*
 * Decompiled with CFR 0.152.
 */
package com.phono.srtplight;

import com.phono.srtplight.Log;
import com.phono.srtplight.RTPDataSink;
import com.phono.srtplight.RTPPacketException;
import com.phono.srtplight.RTPProtocolImpl;
import com.phono.srtplight.SRTPSecContext;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.DatagramSocket;
import java.net.InetSocketAddress;
import java.net.SocketException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Properties;
import javax.crypto.Mac;

public class SRTPProtocolImpl
extends RTPProtocolImpl {
    static final int SRTPWINDOWSIZE = 64;
    long[] _replay = new long[64];
    long _windowLeadingEdge = 0L;
    private SRTPSecContext _scIn;
    private SRTPSecContext _scOut;
    private boolean _doCrypt = true;
    private boolean _doAuth = true;
    Character oseq = null;
    int roc = 0;
    static final int wrapdiff = 32768;

    private void init(Properties lcryptoProps, Properties rcryptoProps) {
        this._srtp = true;
        this._scOut = null;
        this._scIn = null;
        try {
            if (this._doAuth || this._doCrypt) {
                this._scIn = new SRTPSecContext(true);
                this._scIn.parseCryptoProps(rcryptoProps);
                this._tailIn = this._scIn.getAuthTail();
                this._scOut = new SRTPSecContext(false);
                this._scOut.parseCryptoProps(lcryptoProps);
                this._tailOut = this._scOut.getAuthTail();
            }
        }
        catch (GeneralSecurityException ex) {
            Log.error(" error in constructor " + ex.getMessage());
            ex.printStackTrace();
        }
    }

    public SRTPProtocolImpl(int id, DatagramSocket ds, InetSocketAddress far, int type, Properties lcryptoProps, Properties rcryptoProps) {
        super(id, ds, far, type);
        this.init(lcryptoProps, rcryptoProps);
    }

    public SRTPProtocolImpl(int id, String local_media_address, int local_audio_port, String remote_media_address, int remote_audio_port, int type, Properties lcryptoProps, Properties rcryptoProps) throws SocketException {
        super(id, local_media_address, local_audio_port, remote_media_address, remote_audio_port, type);
        this.init(lcryptoProps, rcryptoProps);
    }

    void checkForReplay() throws RTPPacketException {
        if (this._index < this._windowLeadingEdge) {
            if (this._windowLeadingEdge - this._index > 64L) {
                throw new RTPPacketException(" out of window, packet too old");
            }
            int tidx = (int)(this._index % 64L);
            if (this._replay[tidx] == this._index) {
                throw new RTPPacketException(" Seen that packet before - replay attack ? " + this._index);
            }
        }
    }

    @Override
    void checkAuth(byte[] packet, int plen) throws RTPPacketException {
        if (Log.getLevel() > 4) {
            Log.verb("auth on packet " + SRTPProtocolImpl.getHex(packet, plen));
            Log.verb("Packet index " + Long.toHexString(this._index));
        }
        try {
            this._scIn.deriveKeys(0L);
            if (this._doAuth) {
                Mac hmac = this._scIn.getAuthMac();
                int alen = this._tailIn;
                int offs = plen - alen;
                ByteBuffer m = ByteBuffer.allocate(offs + 4);
                m.put(packet, 0, offs);
                m.putInt((int)this._roc);
                byte[] auth = new byte[alen];
                System.arraycopy(packet, offs, auth, 0, alen);
                int mlen = plen - 12 - alen;
                Log.verb("mess length =" + mlen);
                if (Log.getLevel() > 4) {
                    Log.verb("auth body " + SRTPProtocolImpl.getHex(m.array()));
                }
                ((Buffer)m).position(0);
                hmac.update(m);
                byte[] mac = hmac.doFinal();
                if (Log.getLevel() > 4) {
                    Log.verb("auth in   " + SRTPProtocolImpl.getHex(auth));
                }
                if (Log.getLevel() > 4) {
                    Log.verb("auth out  " + SRTPProtocolImpl.getHex(mac, 10));
                }
                for (int i = 0; i < alen; ++i) {
                    if (auth[i] == mac[i]) continue;
                    throw new RTPPacketException("not authorized byte " + i + " does not match ");
                }
            }
        }
        catch (GeneralSecurityException ex) {
            throw new RTPPacketException("Problem checking  packet " + ex.getMessage());
        }
    }

    @Override
    protected void deliverPayload(byte[] payload, long stamp, int ssrc, char seqno) {
        try {
            if (this._doCrypt) {
                this.decrypt(payload, ssrc);
            }
            super.deliverPayload(payload, stamp, ssrc, seqno);
        }
        catch (GeneralSecurityException ex) {
            Log.error("problem with decryption " + ex.getMessage());
        }
    }

    @Override
    void updateCounters(char seqno) {
        int tidx = (int)(this._index % 64L);
        this._replay[tidx] = this._index;
        if (this._index > this._windowLeadingEdge) {
            this._windowLeadingEdge = this._index;
        }
        super.updateCounters(seqno);
    }

    @Override
    void appendAuth(byte[] packet) throws RTPPacketException {
        if (this._doAuth) {
            try {
                Mac mac = this._scOut.getAuthMac();
                int offs = packet.length - this._tailOut;
                ByteBuffer m = ByteBuffer.allocate(offs + 4);
                m.put(packet, 0, offs);
                int oroc = (int)(this._seqno >>> 16);
                if ((this._seqno & 0xFFFFL) == 0L) {
                    Log.debug("seqno = 0 outgoing roc =" + oroc);
                }
                m.putInt(oroc);
                if (Log.getLevel() > 4) {
                    Log.verb("auth body " + SRTPProtocolImpl.getHex(m.array()));
                }
                ((Buffer)m).position(0);
                mac.update(m);
                byte[] auth = mac.doFinal();
                int len = this._tailOut;
                for (int i = 0; i < len; ++i) {
                    packet[offs + i] = auth[i];
                }
            }
            catch (GeneralSecurityException ex) {
                throw new RTPPacketException("Problem sending  packet " + ex.getMessage());
            }
        }
        if (Log.getLevel() > 4) {
            Log.verb("Sending packet " + SRTPProtocolImpl.getHex(packet));
        }
    }

    public void reSendEncryptedPacket(byte[] data, long stamp, long seqno, int ptype, boolean marker) throws SocketException, IOException {
        super.sendPacket(data, stamp, (char)seqno, ptype, marker);
    }

    public void reSendUnEncryptedPacket(byte[] data, long stamp, long seq, int ptype, boolean marker) throws SocketException, IOException {
        if (this._doCrypt) {
            try {
                this._scOut.deriveKeys(seq);
                this.encrypt(data, (int)this._csrcid, seq);
                super.sendPacket(data, stamp, (char)seq, ptype, marker);
            }
            catch (GeneralSecurityException ex) {
                Log.error("problem encrypting packet" + ex.getMessage());
                ex.printStackTrace();
            }
        }
    }

    @Override
    public void sendPacket(byte[] data, long stamp, char seqno, int ptype, boolean marker) throws SocketException, IOException {
        int diff;
        int n = this.roc;
        if (this.oseq == null) {
            this.oseq = Character.valueOf(seqno);
        }
        if ((diff = seqno - this.oseq.charValue()) < Short.MIN_VALUE) {
            ++this.roc;
            n = this.roc;
            Log.debug(" wrapped seqno " + seqno + " oseq " + this.oseq.charValue() + " diff =" + diff + " outgoing roc =" + n);
        }
        if (diff > 32768) {
            n = this.roc - 1;
            Log.debug(" unwrapped seqno " + seqno + " oseq " + this.oseq.charValue() + " diff =" + diff + " outgoing roc =" + n);
        }
        this.oseq = Character.valueOf(seqno);
        long low = seqno;
        long high = (long)n << 16;
        this._seqno = low | high;
        if (this._doCrypt) {
            try {
                this._scOut.deriveKeys(this._seqno);
                this.encrypt(data, (int)this._csrcid, this._seqno);
                super.sendPacket(data, stamp, (char)this._seqno, ptype, marker);
            }
            catch (GeneralSecurityException ex) {
                Log.error("problem encrypting packet" + ex.getMessage());
                ex.printStackTrace();
            }
        }
    }

    @Override
    public void sendPacket(byte[] data, long stamp, int ptype, boolean marker) throws SocketException, IOException {
        try {
            if (this._doCrypt) {
                this._scOut.deriveKeys(stamp);
                this.encrypt(data, (int)this._csrcid, this._seqno);
            }
            super.sendPacket(data, stamp, ptype, marker);
        }
        catch (GeneralSecurityException ex) {
            Log.error("problem encrypting packet" + ex.getMessage());
            ex.printStackTrace();
        }
    }

    static ByteBuffer getPepper(int ssrc, long idx) {
        ByteBuffer pepper = ByteBuffer.allocate(16);
        pepper.putInt(4, ssrc);
        long sindex = idx << 16;
        pepper.putLong(8, sindex);
        return pepper;
    }

    private void decrypt(byte[] payload, int ssrc) throws GeneralSecurityException {
        ByteBuffer in = ByteBuffer.wrap(payload);
        int pl = (payload.length / 32 + 2) * 32;
        ByteBuffer out = ByteBuffer.allocate(pl);
        ByteBuffer pepper = SRTPProtocolImpl.getPepper(ssrc, this._index);
        this._scIn.decipher(in, out, pepper);
        System.arraycopy(out.array(), 0, payload, 0, payload.length);
    }

    private void encrypt(byte[] payload, int ssrc, long idx) throws GeneralSecurityException {
        ByteBuffer in = ByteBuffer.wrap(payload);
        int pl = (payload.length / 32 + 2) * 32;
        ByteBuffer out = ByteBuffer.allocate(pl);
        ByteBuffer pepper = SRTPProtocolImpl.getPepper(ssrc, idx);
        this._scOut.decipher(in, out, pepper);
        System.arraycopy(out.array(), 0, payload, 0, payload.length);
    }

    public static void main(String[] args) {
        System.out.println("testing STRP ");
        int id = 99;
        String local_media_address = "127.0.0.1";
        int local_audio_port = 19000;
        String remote_media_address = "127.0.0.1";
        int remote_audio_port = 19001;
        int type = 1;
        Log.setLevel(5);
        String sdp = "required='1' \ncrypto-suite='AES_CM_128_HMAC_SHA1_80' \nkey-params='inline:d0RmdmcmVCspeEc3QGZiNWpVLFJhQX1cfHAwJSoj' \nsession-params='KDR=0' \ntag='1' \n";
        ByteArrayInputStream sr = new ByteArrayInputStream(sdp.getBytes());
        Properties cryptoProps = new Properties();
        try {
            cryptoProps.load(sr);
        }
        catch (IOException ex) {
            Log.error("invalid sdp props.");
        }
        SRTPProtocolImpl s2 = null;
        try {
            s2 = new SRTPProtocolImpl(id, local_media_address, local_audio_port, remote_media_address, remote_audio_port, type, cryptoProps, cryptoProps);
            s2.testSendSRTP();
        }
        catch (IOException ex) {
            Log.error(ex.getMessage());
        }
        try {
            Thread.sleep(60000L);
        }
        catch (InterruptedException interruptedException) {
            // empty catch block
        }
        if (s2 != null) {
            s2.terminate();
        }
    }

    private void testRcvSRTP() {
        this._ds.close();
        try {
            this._ds = new DatagramSocket(19000);
            this._ds.setSoTimeout(1000);
        }
        catch (SocketException ex) {
            ex.printStackTrace();
        }
        Log.debug("to test rcv of SRTP packets run");
        Log.debug("./rtpw -a -e -k " + SRTPProtocolImpl.getHex(this._scIn._masterKey) + SRTPProtocolImpl.getHex(this._scIn._masterSalt.array()) + " -s 127.0.0.1 " + this._ds.getLocalPort());
        Log.debug("test srtp recv starting in 10 secs");
        try {
            Thread.sleep(10000L);
        }
        catch (InterruptedException ex) {
            // empty catch block
        }
        RTPDataSink sink = new RTPDataSink(){

            @Override
            public void dataPacketReceived(byte[] data, long stamp, long idx) {
                Log.debug("got " + data.length + " bytes");
                Log.debug("data =" + SRTPProtocolImpl.getHex(data));
                Log.debug("Message is " + new String(data));
            }
        };
        this.setRTPDataSink(sink);
        this.startrecv();
    }

    private void testSeqs() {
        try {
            char seq = '\u0000';
            System.out.println("Seq test ");
            long top = Integer.MAX_VALUE;
            for (long j = 0L; j < top; ++j) {
                long i = this.getIndex(seq);
                if (i != j) {
                    throw new RTPPacketException("sequence test failed " + i + " != " + j + " seq " + (short)seq);
                }
                this._index = i;
                this.checkForReplay();
                this.updateCounters(seq);
                seq = (char)(seq + '\u0001');
            }
        }
        catch (RTPPacketException ex) {
            Log.debug(ex.getMessage());
        }
    }

    public static String getHex(byte[] in) {
        return SRTPProtocolImpl.getHex(in, in.length);
    }

    public static String getHex(byte[] in, int len) {
        char[] cmap = new char[]{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};
        StringBuffer ret = new StringBuffer();
        int top = Math.min(in.length, len);
        for (int i = 0; i < top; ++i) {
            ret.append(cmap[0xF & in[i] >>> 4]);
            ret.append(cmap[in[i] & 0xF]);
        }
        return ret.toString();
    }

    private void testSendSRTP() throws IOException {
        this._ds.close();
        try {
            this._ds = new DatagramSocket(19000);
            this._ds.setSoTimeout(60000);
        }
        catch (SocketException ex) {
            ex.printStackTrace();
        }
        Log.debug("to test rcv of SRTP packets run");
        Log.debug("./rtpw -d  -a -e -k " + SRTPProtocolImpl.getHex(this._scIn._masterKey) + SRTPProtocolImpl.getHex(this._scIn._masterSalt.array(), 14) + " -r 127.0.0.1 19002");
        Log.debug("test srtp send starting in 10 secs");
        try {
            Thread.sleep(10000L);
        }
        catch (InterruptedException ex) {
            // empty catch block
        }
        RTPDataSink sink = new RTPDataSink(){

            @Override
            public void dataPacketReceived(byte[] data, long stamp, long idx) {
                Log.debug("got " + data.length + " bytes");
                Log.debug("data =" + SRTPProtocolImpl.getHex(data));
                Log.debug("Message is " + new String(data));
            }
        };
        this.setRTPDataSink(sink);
        this._csrcid = -559038737L;
        this.startrecv();
        byte[] data = new byte[33];
        long stamp = 0L;
        String[] messages = new String[]{"A", "a", "aa", "aal", "aalii", "aam", "Aani", "aardvark", "aardwolf", "Aaron"};
        for (int i = 0; i < messages.length; ++i) {
            byte[] mess = new byte[messages[i].length() + 2];
            System.arraycopy(messages[i].getBytes(), 0, mess, 0, messages[i].length());
            mess[messages[i].length()] = 10;
            Log.debug("Sending " + messages[i]);
            this.sendPacket(mess, stamp += 8000L, 1);
            try {
                Thread.sleep(1000L);
                continue;
            }
            catch (InterruptedException interruptedException) {
                // empty catch block
            }
        }
    }
}

