package dev.langchain4j.rag.query.router;

import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link QueryRouter} that asks a {@link ChatLanguageModel} which of the configured
 * {@link ContentRetriever}s a {@link Query} should be routed to.
 * <p>
 * Each retriever is assigned a 1-based numeric id, in the iteration order of the map passed to the
 * constructor, and is presented to the model as a line {@code "<id>: <description>"}. The model is
 * asked to answer with a single id or several comma-separated ids. If prompting the model or parsing
 * its answer throws, the configured {@link FallbackStrategy} decides the result.
 */
public class LanguageModelQueryRouter implements QueryRouter {

    private static final Logger log = LoggerFactory.getLogger(LanguageModelQueryRouter.class);

    /**
     * Prompt used when no template is supplied. It references the {@code {{options}}} and
     * {@code {{query}}} variables that {@link #createPrompt(Query)} fills in.
     */
    public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from("Based on the user query, determine the most suitable data source(s) to retrieve relevant information from the following options:\n{{options}}\nIt is very important that your answer consists of either a single number or multiple numbers separated by commas and nothing else!\nUser query: {{query}}");

    protected final ChatLanguageModel chatLanguageModel;
    protected final PromptTemplate promptTemplate;
    /** Newline-separated {@code "<id>: <description>"} lines substituted into {@code {{options}}}. */
    protected final String options;
    /** Maps each 1-based id shown to the model back to its retriever. */
    protected final Map<Integer, ContentRetriever> idToRetriever;
    protected final FallbackStrategy fallbackStrategy;

    /**
     * What {@link #route(Query)} does when the model call or the parsing of its answer fails.
     */
    public enum FallbackStrategy {
        /** Return an empty collection, i.e. the query is not routed anywhere. */
        DO_NOT_ROUTE,
        /** Return every configured retriever. */
        ROUTE_TO_ALL,
        /** Rethrow the failure wrapped in a {@link RuntimeException}. */
        FAIL
    }

    /**
     * Builder for {@link LanguageModelQueryRouter}. Unset optional values (prompt template, fallback
     * strategy) fall back to the constructor defaults.
     */
    public static class LanguageModelQueryRouterBuilder {
        private ChatLanguageModel chatLanguageModel;
        private Map<ContentRetriever, String> retrieverToDescription;
        private PromptTemplate promptTemplate;
        private FallbackStrategy fallbackStrategy;

        LanguageModelQueryRouterBuilder() {
        }

        public LanguageModelQueryRouterBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
            this.chatLanguageModel = chatLanguageModel;
            return this;
        }

        public LanguageModelQueryRouterBuilder retrieverToDescription(Map<ContentRetriever, String> map) {
            this.retrieverToDescription = map;
            return this;
        }

        public LanguageModelQueryRouterBuilder promptTemplate(PromptTemplate promptTemplate) {
            this.promptTemplate = promptTemplate;
            return this;
        }

        public LanguageModelQueryRouterBuilder fallbackStrategy(FallbackStrategy fallbackStrategy) {
            this.fallbackStrategy = fallbackStrategy;
            return this;
        }

        public LanguageModelQueryRouter build() {
            return new LanguageModelQueryRouter(this.chatLanguageModel, this.retrieverToDescription, this.promptTemplate, this.fallbackStrategy);
        }

        @Override
        public String toString() {
            return "LanguageModelQueryRouter.LanguageModelQueryRouterBuilder(chatLanguageModel=" + this.chatLanguageModel + ", retrieverToDescription=" + this.retrieverToDescription + ", promptTemplate=" + this.promptTemplate + ", fallbackStrategy=" + this.fallbackStrategy + ")";
        }
    }

    /**
     * Creates a router with {@link #DEFAULT_PROMPT_TEMPLATE} and {@link FallbackStrategy#DO_NOT_ROUTE}.
     *
     * @param chatLanguageModel      model asked to pick the retriever(s); must not be null
     * @param retrieverToDescription retrievers and their human-readable descriptions; must not be empty
     */
    public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel, Map<ContentRetriever, String> retrieverToDescription) {
        this(chatLanguageModel, retrieverToDescription, DEFAULT_PROMPT_TEMPLATE, FallbackStrategy.DO_NOT_ROUTE);
    }

    /**
     * Creates a router.
     *
     * @param chatLanguageModel      model asked to pick the retriever(s); must not be null
     * @param retrieverToDescription retrievers and their descriptions; must not be empty, keys must be
     *                               non-null and descriptions non-blank. Ids are assigned 1, 2, ... in
     *                               the map's iteration order.
     * @param promptTemplate         template to use; {@code null} means {@link #DEFAULT_PROMPT_TEMPLATE}
     * @param fallbackStrategy       strategy on failure; {@code null} means {@link FallbackStrategy#DO_NOT_ROUTE}
     */
    public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel,
                                    Map<ContentRetriever, String> retrieverToDescription,
                                    PromptTemplate promptTemplate,
                                    FallbackStrategy fallbackStrategy) {
        this.chatLanguageModel = ValidationUtils.ensureNotNull(chatLanguageModel, "chatLanguageModel");
        ValidationUtils.ensureNotEmpty(retrieverToDescription, "retrieverToDescription");
        this.promptTemplate = Utils.getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);

        Map<Integer, ContentRetriever> retrieversById = new HashMap<>();
        StringBuilder optionsText = new StringBuilder();
        int id = 1;
        for (Map.Entry<ContentRetriever, String> entry : retrieverToDescription.entrySet()) {
            retrieversById.put(id, ValidationUtils.ensureNotNull(entry.getKey(), "ContentRetriever"));
            if (id > 1) {
                optionsText.append("\n");
            }
            optionsText.append(id)
                    .append(": ")
                    .append(ValidationUtils.ensureNotBlank(entry.getValue(), "ContentRetriever description"));
            id++;
        }
        this.idToRetriever = retrieversById;
        this.options = optionsText.toString();
        this.fallbackStrategy = Utils.getOrDefault(fallbackStrategy, FallbackStrategy.DO_NOT_ROUTE);
    }

    /**
     * Asks the model which retriever(s) suit the query and returns them.
     * <p>
     * Any exception thrown while building the prompt, calling the model or parsing its answer is logged
     * at WARN level and handed to {@link #fallback(Query, Exception)}.
     *
     * @param query the user query
     * @return the selected retrievers, or whatever the fallback strategy yields on failure
     */
    @Override
    public Collection<ContentRetriever> route(Query query) {
        try {
            Prompt prompt = createPrompt(query);
            String answer = chatLanguageModel.generate(prompt.text());
            return parse(answer);
        } catch (Exception e) {
            log.warn("Failed to route query '{}'", query.text(), e);
            return fallback(query, e);
        }
    }

    /**
     * Applies {@link #fallbackStrategy} after a routing failure.
     *
     * @param query the query that could not be routed
     * @param exc   the failure that triggered the fallback
     * @return an empty collection for {@code DO_NOT_ROUTE}, a new list of all retrievers for {@code ROUTE_TO_ALL}
     * @throws RuntimeException wrapping {@code exc} for {@code FAIL}
     */
    protected Collection<ContentRetriever> fallback(Query query, Exception exc) {
        switch (this.fallbackStrategy) {
            case DO_NOT_ROUTE:
                log.debug("Fallback: query '{}' will not be routed", query.text());
                return Collections.emptyList();
            case ROUTE_TO_ALL:
                log.debug("Fallback: query '{}' will be routed to all available content retrievers", query.text());
                return new ArrayList<>(this.idToRetriever.values());
            case FAIL:
            default:
                throw new RuntimeException(exc);
        }
    }

    /**
     * Builds the routing prompt by filling {@code {{query}}} with the query text and
     * {@code {{options}}} with the numbered retriever descriptions.
     */
    protected Prompt createPrompt(Query query) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("query", query.text());
        variables.put("options", this.options);
        return this.promptTemplate.apply(variables);
    }

    /**
     * Parses the model's answer, e.g. {@code "1"} or {@code "2, 1"}, into retrievers.
     * <p>
     * Ids are trimmed and must be integers; duplicates are collapsed while keeping first-seen order.
     * An id with no configured retriever is rejected instead of producing a {@code null} element, so
     * that {@link #route(Query)} applies the fallback strategy just as it does for non-numeric answers.
     *
     * @param choices the raw model answer
     * @return the selected retrievers, never containing {@code null}
     * @throws NumberFormatException    if a comma-separated token is not an integer
     * @throws IllegalArgumentException if an id does not correspond to a configured retriever
     */
    protected Collection<ContentRetriever> parse(String choices) {
        return Arrays.stream(choices.split(","))
                .map(String::trim)
                .map(Integer::parseInt)
                .distinct()
                .map(this::retrieverById)
                .collect(Collectors.toList());
    }

    /** Looks up the retriever for an id chosen by the model, rejecting ids that were never offered. */
    private ContentRetriever retrieverById(Integer id) {
        ContentRetriever retriever = this.idToRetriever.get(id);
        if (retriever == null) {
            throw new IllegalArgumentException("Unknown content retriever id: " + id
                    + " (valid ids: " + this.idToRetriever.keySet() + ")");
        }
        return retriever;
    }

    public static LanguageModelQueryRouterBuilder builder() {
        return new LanguageModelQueryRouterBuilder();
    }
}
