/*
 * Decompiled with CFR 0.152.
 */
package li.cil.oc2.jcodec.codecs.h264.encode;

import li.cil.oc2.jcodec.codecs.h264.encode.H264EncoderUtils;
import li.cil.oc2.jcodec.codecs.h264.encode.MBEncoderHelper;
import li.cil.oc2.jcodec.codecs.h264.io.model.SeqParameterSet;
import li.cil.oc2.jcodec.common.model.Picture;
import li.cil.oc2.jcodec.common.tools.MathUtil;

public final class MotionEstimator {
    /**
     * Per-macroblock motion estimator against a single reference picture (plane 0 only).
     *
     * <p>The search runs in two stages. First, a greedy full-pel search runs in two windows: one around the
     * co-located position and one around the predicted position. Second, the result is refined to quarter-pel
     * precision in {@link #estimateQPix}.
     *
     * <p>Neighbour vectors are carried between calls through {@link #mvSave}. An instance is therefore stateful
     * and not safe for concurrent use.
     */

    /** Full-pel radius of each search window; also the maximum number of greedy search steps. */
    private final int maxSearchRange;
    // Vector (x, y and the third "R" component given to mvSave) last saved for each macroblock column;
    // while a row is being processed these hold the row above.
    private final int[] mvTopX;
    private final int[] mvTopY;
    private final int[] mvTopR;
    // Vector saved by the most recent mvSave: the left neighbour of the next macroblock in the row.
    private int mvLeftX;
    private int mvLeftY;
    private int mvLeftR;
    // Value that mvTop*[mbX] held before the most recent mvSave(mbX, ...): the top-left neighbour of mbX + 1.
    private int mvTopLeftX;
    private int mvTopLeftY;
    private int mvTopLeftR;
    private final SeqParameterSet sps;
    private final Picture ref;
    // Quarter-pel (x, y) offset applied to the full-pel vector for each of the 25 candidates scored by
    // estimateQPix, indexed by candidate number.
    private static final int[] SUB_X_OFF = new int[]{0, -2, 2, 0, 0, -2, -2, 2, 2, -1, 1, 0, 0, -1, -2, -1, -2, 1, 2, 1, 2, -1, 1, -1, 1};
    private static final int[] SUB_Y_OFF = new int[]{0, 0, 0, -2, 2, -2, 2, -2, 2, 0, 0, -1, 1, -2, -1, 2, 1, -2, -1, 2, 1, -1, -1, 1, 1};

    /**
     * Creates an estimator for one reference picture.
     *
     * @param ref            reference picture searched for matches; only plane 0 is read
     * @param sps            sequence parameters; only {@code picWidthInMbsMinus1} is used
     * @param maxSearchRange full-pel search radius and maximum number of greedy steps per window
     */
    public MotionEstimator(Picture ref, SeqParameterSet sps, int maxSearchRange) {
        this.sps = sps;
        this.ref = ref;
        this.mvTopX = new int[sps.picWidthInMbsMinus1 + 1];
        this.mvTopY = new int[sps.picWidthInMbsMinus1 + 1];
        this.mvTopR = new int[sps.picWidthInMbsMinus1 + 1];
        this.maxSearchRange = maxSearchRange;
    }

    /**
     * Estimates the motion vector of the 16x16 macroblock at ({@code mbX}, {@code mbY}) of {@code pic}.
     *
     * <p>A predictor is derived by {@code H264EncoderUtils.median} from the left, top, top-right and top-left
     * neighbour vectors recorded by {@link #mvSave}. The predictor seeds the second full-pel search window, and
     * the full-pel result is then refined to quarter-pel precision.
     *
     * @param pic picture being encoded; only plane 0 is read
     * @param mbX macroblock column
     * @param mbY macroblock row
     * @return {@code {mvX, mvY}} in quarter-pel units, relative to the macroblock position
     */
    public int[] mvEstimate(Picture pic, int mbX, int mbY) {
        boolean leftAvb = mbX > 0;
        boolean topAvb = mbY > 0;
        boolean trAvb = topAvb && mbX < this.sps.picWidthInMbsMinus1;
        boolean tlAvb = leftAvb && topAvb;
        int ax = this.mvLeftX;
        int ay = this.mvLeftY;
        boolean ar = this.mvLeftR == 1;
        int bx = this.mvTopX[mbX];
        int by = this.mvTopY[mbX];
        boolean br = this.mvTopR[mbX] == 1;
        // The top-right and top-left entries are only read when those neighbours exist.
        int cx = trAvb ? this.mvTopX[mbX + 1] : 0;
        int cy = trAvb ? this.mvTopY[mbX + 1] : 0;
        boolean cr = trAvb && this.mvTopR[mbX + 1] == 1;
        int dx = tlAvb ? this.mvTopLeftX : 0;
        int dy = tlAvb ? this.mvTopLeftY : 0;
        boolean dr = tlAvb && this.mvTopLeftR == 1;
        int mvpx = H264EncoderUtils.median(ax, ar, bx, br, cx, cr, dx, dr, leftAvb, topAvb, trAvb, tlAvb);
        int mvpy = H264EncoderUtils.median(ay, ar, by, br, cy, cr, dy, dr, leftAvb, topAvb, trAvb, tlAvb);
        byte[] patch = new byte[256];
        MBEncoderHelper.take(pic.getPlaneData(0), pic.getPlaneWidth(0), pic.getPlaneHeight(0), mbX << 4, mbY << 4, patch, 16, 16);
        int[] fullPix = this.estimateFullPix(this.ref, patch, mbX, mbY, mvpx, mvpy);
        return MotionEstimator.estimateQPix(this.ref, patch, fullPix, mbX, mbY);
    }

    /**
     * Refines a full-pel motion vector to quarter-pel precision.
     *
     * <p>Scores 25 candidates by SAD against {@code patch} and returns the best one. The candidates are the
     * full-pel position, the 8 surrounding half-pel positions and the 16 quarter-pel positions. Ties go to the
     * lowest candidate index, so the full-pel position wins a tie.
     *
     * <p>Half-pel samples use the (1, -5, 20, 20, -5, 1) six-tap filter with (+16) >> 5 rounding, or
     * (+512) >> 10 for the centre positions that are filtered twice. Quarter-pel samples are rounded averages
     * of two neighbouring samples. Interpolated values are clamped to [-128, 127]; this presumably means
     * samples are stored as signed bytes (TODO confirm).
     *
     * @param ref     reference picture; only plane 0 is read
     * @param patch   16x16 source block, row-major with stride 16
     * @param fullPix full-pel vector {@code {x, y}}, expressed in quarter-pel units
     * @param mbX     macroblock column
     * @param mbY     macroblock row
     * @return the refined {@code {x, y}} in quarter-pel units; {@code fullPix} itself when the candidate lies
     *         within 3 pixels of the left or top plane edge (no room for the filter margin)
     */
    public static int[] estimateQPix(Picture ref, byte[] patch, int[] fullPix, int mbX, int mbY) {
        int fullX = (mbX << 4) + (fullPix[0] >> 2);
        int fullY = (mbY << 4) + (fullPix[1] >> 2);
        if (fullX < 3 || fullY < 3) {
            return fullPix;
        }
        // 22x22 reference window: the 16x16 candidate plus a 3-pixel filter margin on every side.
        byte[] sp = new byte[484];
        MBEncoderHelper.take(ref.getPlaneData(0), ref.getPlaneWidth(0), ref.getPlaneHeight(0), fullX - 3, fullY - 3, sp, 22, 22);

        // Unnormalised horizontal six-tap sums for all 22 window rows (16 columns each, stride 16):
        // pn = half-pel just left of each block column, pp = half-pel just right of it.
        int[] pn = new int[352];
        int[] pp = new int[352];
        int dOff = 0;
        int sOff = 0;
        for (int row = 0; row < 22; ++row) {
            for (int col = 0; col < 16; ++col, ++dOff, ++sOff) {
                pn[dOff] = tap6(sp[sOff], sp[sOff + 1], sp[sOff + 2], sp[sOff + 3], sp[sOff + 4], sp[sOff + 5]);
                pp[dOff] = tap6(sp[sOff + 1], sp[sOff + 2], sp[sOff + 3], sp[sOff + 4], sp[sOff + 5], sp[sOff + 6]);
            }
            sOff += 6;
        }

        int[] scores = new int[25];
        int off = 0; // index into patch (stride 16); the co-located pn/pp row is at off + 48
        int sof = 0; // index into sp (stride 22); the co-located full-pel sample is sp[sof + 69]
        for (int row = 0; row < 16; ++row) {
            for (int col = 0; col < 16; ++col, ++off, ++sof) {
                int px = patch[off];
                int full = sp[sof + 69];
                int base = sof + 3; // top of the 22-row window in this column

                // Full-pel, horizontal half-pels and the quarter-pels between them and the full-pel sample.
                int horN20 = MathUtil.clip(pn[off + 48] + 16 >> 5, -128, 127);
                int horP20 = MathUtil.clip(pp[off + 48] + 16 >> 5, -128, 127);
                scores[0] += MathUtil.abs(px - full);
                scores[1] += MathUtil.abs(px - horN20);
                scores[2] += MathUtil.abs(px - horP20);
                scores[9] += MathUtil.abs(px - (horN20 + full + 1 >> 1));
                scores[10] += MathUtil.abs(px - (horP20 + full + 1 >> 1));

                // Vertical half-pels above/below, their quarter-pels, and the horizontal/vertical averages.
                int verN20 = MathUtil.clip(tap6(sp[base], sp[base + 22], sp[base + 44], sp[base + 66], sp[base + 88], sp[base + 110]) + 16 >> 5, -128, 127);
                int verP20 = MathUtil.clip(tap6(sp[base + 22], sp[base + 44], sp[base + 66], sp[base + 88], sp[base + 110], sp[base + 132]) + 16 >> 5, -128, 127);
                scores[3] += MathUtil.abs(px - verN20);
                scores[4] += MathUtil.abs(px - verP20);
                scores[11] += MathUtil.abs(px - (verN20 + full + 1 >> 1));
                scores[12] += MathUtil.abs(px - (verP20 + full + 1 >> 1));
                scores[21] += MathUtil.abs(px - (verN20 + horN20 + 1 >> 1));
                scores[22] += MathUtil.abs(px - (verN20 + horP20 + 1 >> 1));
                scores[23] += MathUtil.abs(px - (verP20 + horN20 + 1 >> 1));
                scores[24] += MathUtil.abs(px - (verP20 + horP20 + 1 >> 1));

                // Centre half-pels: a vertical six-tap over the unnormalised horizontal sums, plus the
                // quarter-pels averaging each centre sample with its nearest vertical/horizontal half-pel.
                int diagNN = MathUtil.clip(tap6(pn[off], pn[off + 16], pn[off + 32], pn[off + 48], pn[off + 64], pn[off + 80]) + 512 >> 10, -128, 127);
                int diagNP = MathUtil.clip(tap6(pn[off + 16], pn[off + 32], pn[off + 48], pn[off + 64], pn[off + 80], pn[off + 96]) + 512 >> 10, -128, 127);
                int diagPN = MathUtil.clip(tap6(pp[off], pp[off + 16], pp[off + 32], pp[off + 48], pp[off + 64], pp[off + 80]) + 512 >> 10, -128, 127);
                int diagPP = MathUtil.clip(tap6(pp[off + 16], pp[off + 32], pp[off + 48], pp[off + 64], pp[off + 80], pp[off + 96]) + 512 >> 10, -128, 127);
                scores[5] += MathUtil.abs(px - diagNN);
                scores[13] += MathUtil.abs(px - (diagNN + verN20 + 1 >> 1));
                scores[14] += MathUtil.abs(px - (diagNN + horN20 + 1 >> 1));
                scores[6] += MathUtil.abs(px - diagNP);
                scores[15] += MathUtil.abs(px - (diagNP + verP20 + 1 >> 1));
                scores[16] += MathUtil.abs(px - (diagNP + horN20 + 1 >> 1));
                scores[7] += MathUtil.abs(px - diagPN);
                scores[17] += MathUtil.abs(px - (diagPN + verN20 + 1 >> 1));
                scores[18] += MathUtil.abs(px - (diagPN + horP20 + 1 >> 1));
                scores[8] += MathUtil.abs(px - diagPP);
                scores[19] += MathUtil.abs(px - (diagPP + verP20 + 1 >> 1));
                scores[20] += MathUtil.abs(px - (diagPP + horP20 + 1 >> 1));
            }
            sof += 6;
        }

        // First (lowest-index) candidate with the minimum score.
        int sel = 0;
        for (int k = 1; k < scores.length; ++k) {
            if (scores[k] < scores[sel]) {
                sel = k;
            }
        }
        return new int[]{fullPix[0] + SUB_X_OFF[sel], fullPix[1] + SUB_Y_OFF[sel]};
    }

    /**
     * Records the vector chosen for macroblock column {@code mbX} so that later {@link #mvEstimate} calls can
     * use it as a neighbour.
     *
     * <p>Expects to be called once per macroblock, left to right within a row and rows top to bottom. The
     * left and top-left slots simply hold the most recently saved or overwritten values.
     *
     * @param mbX macroblock column
     * @param mv  {@code {x, y, r}} with x and y in quarter-pel units. {@code r} is stored as the neighbour's
     *            third component; {@code mvEstimate} passes {@code r == 1} to the median predictor. Presumably
     *            {@code r} is a reference index (TODO confirm).
     */
    public void mvSave(int mbX, int[] mv) {
        this.mvTopLeftX = this.mvTopX[mbX];
        this.mvTopLeftY = this.mvTopY[mbX];
        this.mvTopLeftR = this.mvTopR[mbX];
        this.mvTopX[mbX] = mv[0];
        this.mvTopY[mbX] = mv[1];
        this.mvTopR[mbX] = mv[2];
        this.mvLeftX = mv[0];
        this.mvLeftY = mv[1];
        this.mvLeftR = mv[2];
    }

    /**
     * Full-pel motion search.
     *
     * <p>Runs {@link #searchWindow} twice: once around the co-located position and once around the position
     * the predictor points to. It keeps the result with the lower SAD; ties go to the predictor search.
     *
     * @param mvpx predictor x in quarter-pel units
     * @param mvpy predictor y in quarter-pel units
     * @return {@code {mvX, mvY}} in quarter-pel units (multiples of 4); {@code {0, 0}} when neither window
     *         could hold a 16x16 candidate
     */
    private int[] estimateFullPix(Picture ref, byte[] patch, int mbX, int mbY, int mvpx, int mvpy) {
        // Scratch buffer sized for the largest possible window; reused by both searches.
        byte[] window = new byte[(this.maxSearchRange * 2 + 16) * (this.maxSearchRange * 2 + 16)];
        int mbPixX = mbX << 4;
        int mbPixY = mbY << 4;
        int[] fromZero = this.searchWindow(ref, patch, window, mbPixX, mbPixY);
        int[] fromPred = this.searchWindow(ref, patch, window, mbPixX + (mvpx >> 2), mbPixY + (mvpy >> 2));
        int[] best;
        if (fromZero == null) {
            best = fromPred;
        } else if (fromPred == null) {
            best = fromZero;
        } else {
            best = fromZero[2] < fromPred[2] ? fromZero : fromPred;
        }
        if (best == null) {
            return new int[]{0, 0};
        }
        return new int[]{(best[0] - mbPixX) << 2, (best[1] - mbPixY) << 2};
    }

    /**
     * Greedy full-pel search in a window extending up to {@code maxSearchRange} pixels around
     * ({@code startX}, {@code startY}), clipped to the reference plane.
     *
     * <p>Each step moves one pixel left, right, up or down to the neighbour with the lowest SAD. The search
     * stops when every neighbour is strictly worse or after {@code maxSearchRange} steps. Every evaluated
     * 16x16 candidate lies entirely inside the window copied from the reference plane.
     *
     * @param window scratch buffer of at least (2 * maxSearchRange + 16)^2 bytes; overwritten
     * @return {@code {x, y, sad}}, where (x, y) is the best candidate's top-left corner in plane coordinates;
     *         {@code null} if the 16x16 block at the start position does not lie inside the plane
     */
    private int[] searchWindow(Picture ref, byte[] patch, byte[] window, int startX, int startY) {
        int planeW = ref.getPlaneWidth(0);
        int planeH = ref.getPlaneHeight(0);
        int winX = Math.max(startX - this.maxSearchRange, 0);
        int winY = Math.max(startY - this.maxSearchRange, 0);
        int winW = Math.min(startX + this.maxSearchRange + 16, planeW) - winX;
        int winH = Math.min(startY + this.maxSearchRange + 16, planeH) - winY;
        int x = startX - winX;
        int y = startY - winY;
        // A start left of/above the plane, or one whose block overhangs the right/bottom edge, cannot be
        // evaluated inside the window; checked before copying so no negative sizes reach takeSafe.
        if (x < 0 || y < 0 || x + 16 > winW || y + 16 > winH) {
            return null;
        }
        MBEncoderHelper.takeSafe(ref.getPlaneData(0), planeW, winX, winY, window, winW, winH);
        // Largest top-left offsets at which a full 16x16 candidate still fits inside the window.
        int maxX = winW - 16;
        int maxY = winH - 16;
        int bestScore = sad(window, winW, patch, x, y);
        for (int step = 0; step < this.maxSearchRange; ++step) {
            int left = x > 0 ? sad(window, winW, patch, x - 1, y) : Integer.MAX_VALUE;
            int right = x < maxX ? sad(window, winW, patch, x + 1, y) : Integer.MAX_VALUE;
            int up = y > 0 ? sad(window, winW, patch, x, y - 1) : Integer.MAX_VALUE;
            int down = y < maxY ? sad(window, winW, patch, x, y + 1) : Integer.MAX_VALUE;
            int min = Math.min(Math.min(left, right), Math.min(up, down));
            // Ties keep moving; only a strictly worse neighbourhood ends the search.
            if (min > bestScore) {
                break;
            }
            bestScore = min;
            if (left == min) {
                --x;
            } else if (right == min) {
                ++x;
            } else if (up == min) {
                --y;
            } else {
                ++y;
            }
        }
        return new int[]{winX + x, winY + y, bestScore};
    }

    /**
     * Unnormalised six-tap (1, -5, 20, 20, -5, 1) filter over six consecutive samples.
     */
    private static int tap6(int p0, int p1, int p2, int p3, int p4, int p5) {
        return (p0 + p5) + 5 * (((p2 + p3) << 2) - (p1 + p4));
    }

    /**
     * Sum of absolute differences between the 16x16 block {@code small} (stride 16) and the 16x16 region of
     * {@code big} whose top-left corner is ({@code offX}, {@code offY}).
     */
    private static int sad(byte[] big, int bigStride, byte[] small, int offX, int offY) {
        int score = 0;
        int bigOff = offY * bigStride + offX;
        int smallOff = 0;
        for (int row = 0; row < 16; ++row) {
            for (int col = 0; col < 16; ++col) {
                score += MathUtil.abs(big[bigOff + col] - small[smallOff++]);
            }
            bigOff += bigStride;
        }
        return score;
    }
}

