package me.xemor.chatguardian;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import java.io.IOException;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;

@Singleton
/**
 * Computes sentence embeddings with a HuggingFace tokenizer and an ONNX model.
 *
 * <p>Both the tokenizer files and {@code model.onnx} are loaded from the {@code model}
 * sub-directory of {@link ChatGuardianCommon#dataDirectory()}. The embedding is the mean of
 * the model's first output over all tokens whose attention-mask value is 1.
 */
public class SentenceEmbeddingModel {
    private final ChatGuardianCommon chatGuardianCommon;
    private final HuggingFaceTokenizer tokenizer;
    private final OrtSession session;
    private final OrtEnvironment env;

    /**
     * Loads the tokenizer and ONNX session from {@code <dataDirectory>/model}.
     *
     * <p>The thread context class loader is temporarily swapped to this class's loader while
     * loading, because the native libraries are resolved through it. The previous loader is
     * always restored, including when loading fails.
     *
     * @param chatGuardianCommon supplies the data directory containing the model files
     * @throws IOException if the tokenizer files cannot be read
     * @throws RuntimeException wrapping an {@link OrtException} if the ONNX session cannot be created
     */
    @Inject
    public SentenceEmbeddingModel(ChatGuardianCommon chatGuardianCommon) throws IOException {
        this.chatGuardianCommon = chatGuardianCommon;
        Thread current = Thread.currentThread();
        ClassLoader previousLoader = current.getContextClassLoader();
        try {
            current.setContextClassLoader(getClass().getClassLoader());
            Path modelDir = chatGuardianCommon.dataDirectory().resolve("model");
            this.tokenizer = HuggingFaceTokenizer.newInstance(modelDir);
            this.env = OrtEnvironment.getEnvironment();
            // Session options are copied into the session, so they can be released right away.
            try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
                this.session = this.env.createSession(
                        modelDir.resolve("model.onnx").toAbsolutePath().toString(), options);
            }
        } catch (OrtException e) {
            throw new RuntimeException("Failed to create ONNX session for sentence embedding model", e);
        } finally {
            current.setContextClassLoader(previousLoader);
        }
    }

    /**
     * Tokenizes {@code str}, runs the model, and mean-pools the token embeddings.
     *
     * <p>All native ONNX tensors and results created here are closed before returning.
     *
     * @param str text to embed
     * @return the mean of the token embeddings selected by the attention mask; an all-zero
     *     vector if the mask selects no tokens
     * @throws OrtException if tensor creation or inference fails
     */
    public FloatVector calculateEmbedding(String str) throws OrtException {
        Encoding encoding = this.tokenizer.encode(str);
        long[] inputIds = encoding.getIds();
        long[] attentionMask = encoding.getAttentionMask();
        // Batch of one: the model expects shape [batch, sequence].
        try (OnnxTensor idsTensor = OnnxTensor.createTensor(this.env, new long[][] {inputIds});
             OnnxTensor maskTensor = OnnxTensor.createTensor(this.env, new long[][] {attentionMask})) {
            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input_ids", idsTensor);
            inputs.put("attention_mask", maskTensor);
            try (OrtSession.Result result = this.session.run(inputs)) {
                // getValue() copies into Java arrays, so the data outlives closing the result.
                // Output 0 is read as [batch, sequence, hidden]; take the single batch entry.
                float[][] tokenEmbeddings = ((float[][][]) result.get(0).getValue())[0];
                return new FloatVector(meanPool(tokenEmbeddings, attentionMask));
            }
        }
    }

    /**
     * Averages the rows of {@code tokenEmbeddings} whose attention-mask value is 1.
     * Returns an all-zero vector (rather than NaNs) when no token is selected.
     */
    private static float[] meanPool(float[][] tokenEmbeddings, long[] attentionMask) {
        float[] pooled = new float[tokenEmbeddings[0].length];
        int count = 0;
        for (int token = 0; token < attentionMask.length; token++) {
            if (attentionMask[token] != 1) {
                continue;
            }
            float[] row = tokenEmbeddings[token];
            for (int d = 0; d < pooled.length; d++) {
                pooled[d] += row[d];
            }
            count++;
        }
        // Guard against 0/0 = NaN when the mask selects nothing.
        int divisor = Math.max(count, 1);
        for (int d = 0; d < pooled.length; d++) {
            pooled[d] /= divisor;
        }
        return pooled;
    }
}
