package dev.langchain4j.model.embedding.onnx;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.TokenCountEstimator;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/* loaded from: input_file:META-INF/jars/langchain4j-embeddings-1.0.0-beta5.jar:dev/langchain4j/model/embedding/onnx/HuggingFaceTokenCountEstimator.class */
/**
 * {@link TokenCountEstimator} backed by a DJL {@link HuggingFaceTokenizer}.
 *
 * <p>Token counts are obtained by encoding text with the tokenizer (without adding special
 * tokens) and counting the resulting tokens. Chat messages are counted by their text content
 * only.
 */
public class HuggingFaceTokenCountEstimator implements TokenCountEstimator {

    /** Classpath location of the tokenizer definition used by the no-arg constructor. */
    private static final String DEFAULT_TOKENIZER_RESOURCE = "/bert-tokenizer.json";

    private final HuggingFaceTokenizer tokenizer;

    /**
     * Creates an estimator from the bundled {@code /bert-tokenizer.json} tokenizer, with the
     * {@code padding} and {@code truncation} tokenizer options set to {@code "false"}.
     *
     * @throws IllegalStateException if the bundled tokenizer resource is not on the classpath
     * @throws UncheckedIOException if the resource cannot be read
     */
    public HuggingFaceTokenCountEstimator() {
        Map<String, String> options = new HashMap<>();
        options.put("padding", "false");
        options.put("truncation", "false");
        this.tokenizer = createFromResource(DEFAULT_TOKENIZER_RESOURCE, options);
    }

    /**
     * Creates an estimator from a tokenizer JSON file, using the tokenizer's default options.
     *
     * @param path path to the tokenizer definition file
     * @throws UncheckedIOException if the file cannot be read
     */
    public HuggingFaceTokenCountEstimator(Path path) {
        this(path, null);
    }

    /**
     * Creates an estimator from a tokenizer JSON file.
     *
     * @param path path to the tokenizer definition file
     * @param options tokenizer options passed to {@link HuggingFaceTokenizer#newInstance};
     *     may be {@code null}
     * @throws UncheckedIOException if the file cannot be read
     */
    public HuggingFaceTokenCountEstimator(Path path, Map<String, String> options) {
        this.tokenizer = createFromPath(path, options);
    }

    /**
     * Creates an estimator from a tokenizer JSON file, using the tokenizer's default options.
     *
     * @param pathToTokenizer path (as a string) to the tokenizer definition file
     * @throws UncheckedIOException if the file cannot be read
     */
    public HuggingFaceTokenCountEstimator(String pathToTokenizer) {
        this(pathToTokenizer, null);
    }

    /**
     * Creates an estimator from a tokenizer JSON file.
     *
     * @param pathToTokenizer path (as a string) to the tokenizer definition file
     * @param options tokenizer options passed to {@link HuggingFaceTokenizer#newInstance};
     *     may be {@code null}
     * @throws java.nio.file.InvalidPathException if {@code pathToTokenizer} is not a valid path
     * @throws UncheckedIOException if the file cannot be read
     */
    public HuggingFaceTokenCountEstimator(String pathToTokenizer, Map<String, String> options) {
        this(Paths.get(pathToTokenizer), options);
    }

    /** Loads a tokenizer from a classpath resource, failing clearly if the resource is absent. */
    private static HuggingFaceTokenizer createFromResource(String resource, Map<String, String> options) {
        // Resolve against this class (not getClass()) so subclasses loaded by another
        // class loader still find the resource bundled alongside this class.
        InputStream in = HuggingFaceTokenCountEstimator.class.getResourceAsStream(resource);
        if (in == null) {
            throw new IllegalStateException("Tokenizer resource not found on classpath: " + resource);
        }
        return createFrom(in, options);
    }

    /** Loads a tokenizer from a file on disk. */
    private static HuggingFaceTokenizer createFromPath(Path path, Map<String, String> options) {
        try {
            return createFrom(Files.newInputStream(path), options);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to open tokenizer file: " + path, e);
        }
    }

    /** Builds a tokenizer from {@code in}, always closing the stream afterwards. */
    private static HuggingFaceTokenizer createFrom(InputStream in, Map<String, String> options) {
        try (InputStream stream = in) {
            return HuggingFaceTokenizer.newInstance(stream, options);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to load tokenizer", e);
        }
    }

    /**
     * Counts the tokens the tokenizer produces for {@code text}.
     *
     * @param text the text to tokenize
     * @return the number of tokens
     */
    @Override // dev.langchain4j.model.TokenCountEstimator
    public int estimateTokenCountInText(String text) {
        // Arguments per DJL's encode(text, addSpecialTokens, withOverflowingTokens):
        // special tokens are not added, so only the text's own tokens are counted.
        return tokenizer.encode(text, false, true).getTokens().length;
    }

    /**
     * Counts the tokens in the text content of a single chat message.
     *
     * @param chatMessage a system, user, AI, or tool-execution-result message
     * @return the number of tokens in the message text
     * @throws IllegalArgumentException if the message type is not supported
     */
    @Override // dev.langchain4j.model.TokenCountEstimator
    public int estimateTokenCountInMessage(ChatMessage chatMessage) {
        if (chatMessage instanceof SystemMessage) {
            return estimateTokenCountInText(((SystemMessage) chatMessage).text());
        }
        if (chatMessage instanceof UserMessage) {
            return estimateTokenCountInText(((UserMessage) chatMessage).singleText());
        }
        if (chatMessage instanceof AiMessage) {
            // AiMessage text can be null (presumably when the message only carries tool
            // execution requests); count that as zero tokens instead of passing null on.
            String text = ((AiMessage) chatMessage).text();
            return text == null ? 0 : estimateTokenCountInText(text);
        }
        if (chatMessage instanceof ToolExecutionResultMessage) {
            return estimateTokenCountInText(((ToolExecutionResultMessage) chatMessage).text());
        }
        throw new IllegalArgumentException("Unknown message type: " + String.valueOf(chatMessage));
    }

    /**
     * Sums the per-message token counts of all {@code messages}.
     *
     * @param messages the messages to count
     * @return the total number of tokens
     */
    @Override // dev.langchain4j.model.TokenCountEstimator
    public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
        int total = 0;
        for (ChatMessage message : messages) {
            total += estimateTokenCountInMessage(message);
        }
        return total;
    }
}
