package io.github.amithkoujalgi.ollama4j.core;

import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolInvocationException;
import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolNotFoundException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.ListModelsResponse;
import io.github.amithkoujalgi.ollama4j.core.models.Model;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
import io.github.amithkoujalgi.ollama4j.core.models.ModelPullResponse;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultStreamer;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaChatEndpointCaller;
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaGenerateEndpointCaller;
import io.github.amithkoujalgi.ollama4j.core.tools.OllamaToolsResult;
import io.github.amithkoujalgi.ollama4j.core.tools.ToolFunction;
import io.github.amithkoujalgi.ollama4j.core.tools.ToolFunctionCallSpec;
import io.github.amithkoujalgi.ollama4j.core.tools.ToolRegistry;
import io.github.amithkoujalgi.ollama4j.core.tools.Tools;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.http.HttpClient;
import java.net.http.HttpConnectTimeoutException;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:META-INF/jars/ollama4j-1.0.77.jar:io/github/amithkoujalgi/ollama4j/core/OllamaAPI.class */
/**
 * Client for the HTTP API of an Ollama server.
 *
 * <p>Each call creates a new {@link HttpClient}. Requests built here carry a JSON
 * {@code Content-Type}, the configured per-request timeout and, once {@link #setBasicAuth} has
 * been called, an HTTP Basic {@code Authorization} header.
 */
public class OllamaAPI {
    private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
    private final String host;
    private BasicAuth basicAuth;
    private long requestTimeoutSeconds = 10;
    private boolean verbose = true;
    private final ToolRegistry toolRegistry = new ToolRegistry();

    /**
     * Creates a client for the given server base URL.
     *
     * @param host base URL of the Ollama server; a single trailing {@code /} is stripped so that
     *     endpoint paths such as {@code /api/tags} can be appended directly
     */
    public OllamaAPI(String host) {
        this.host = host.endsWith("/") ? host.substring(0, host.length() - 1) : host;
    }

    /**
     * Sets the credentials used for HTTP Basic authentication on subsequent requests.
     *
     * @param username the user name
     * @param password the password
     */
    public void setBasicAuth(String username, String password) {
        this.basicAuth = new BasicAuth(username, password);
    }

    /**
     * Checks whether the server answers {@code GET /api/tags} with HTTP 200.
     *
     * @return {@code true} on HTTP 200, {@code false} on any other status or a connect timeout
     * @throws RuntimeException wrapping any other I/O failure, an interruption (the thread's
     *     interrupt flag is restored first) or a malformed host URL
     */
    public boolean ping() {
        URI uri;
        try {
            uri = new URI(this.host + "/api/tags");
        } catch (URISyntaxException e) {
            throw new RuntimeException(e);
        }
        HttpRequest request = getRequestBuilderDefault(uri)
                .header("Accept", "application/json")
                .GET()
                .build();
        try {
            return sendForString(request).statusCode() == 200;
        } catch (HttpConnectTimeoutException e) {
            // Must precede the IOException handler: it is a subclass and would otherwise be unreachable.
            return false;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Lists the models reported by {@code GET /api/tags}.
     *
     * @return the models from the parsed {@link ListModelsResponse}
     * @throws OllamaBaseException if the server answers with a status other than 200
     */
    public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
        HttpRequest request = getRequestBuilderDefault(new URI(this.host + "/api/tags"))
                .header("Accept", "application/json")
                .GET()
                .build();
        HttpResponse<String> response = sendForString(request);
        String body = response.body();
        if (response.statusCode() != 200) {
            throw new OllamaBaseException(response.statusCode() + " - " + body);
        }
        return Utils.getObjectMapper().readValue(body, ListModelsResponse.class).getModels();
    }

    /**
     * Pulls a model via {@code POST /api/pull}, consuming the line-delimited progress stream.
     *
     * <p>When verbose mode is on, each progress line's status is logged at INFO level.
     *
     * @param modelName the model to pull
     * @throws OllamaBaseException if the server answers with a status other than 200; the message
     *     contains the status code and the response body
     */
    public void pullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
        HttpRequest request = getRequestBuilderDefault(new URI(this.host + "/api/pull"))
                .POST(HttpRequest.BodyPublishers.ofString(new ModelRequest(modelName).toString()))
                .header("Accept", "application/json")
                .build();
        HttpResponse<InputStream> response =
                HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofInputStream());
        int statusCode = response.statusCode();
        try (InputStream in = response.body();
                BufferedReader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
            if (statusCode != 200) {
                // An error body is not a progress stream; report it verbatim instead of parsing it.
                String errorBody = new String(in.readAllBytes(), StandardCharsets.UTF_8);
                throw new OllamaBaseException(statusCode + " - " + errorBody);
            }
            String line;
            while ((line = reader.readLine()) != null) {
                ModelPullResponse progress = Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
                if (this.verbose) {
                    logger.info(progress.getStatus());
                }
            }
        }
    }

    /**
     * Fetches model details via {@code POST /api/show}.
     *
     * @param modelName the model to describe
     * @return the parsed {@link ModelDetail}
     * @throws OllamaBaseException if the server answers with a status other than 200
     */
    public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
        HttpRequest request = getRequestBuilderDefault(new URI(this.host + "/api/show"))
                .header("Accept", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(new ModelRequest(modelName).toString()))
                .build();
        HttpResponse<String> response = sendForString(request);
        String body = response.body();
        if (response.statusCode() != 200) {
            throw new OllamaBaseException(response.statusCode() + " - " + body);
        }
        return Utils.getObjectMapper().readValue(body, ModelDetail.class);
    }

    /**
     * Creates a model via {@code POST /api/create} from a model file path.
     *
     * @param modelName name of the model to create
     * @param modelFilePath path of the model file, sent as part of the request
     * @throws OllamaBaseException if the status is not 200 or the body contains {@code "error"}
     */
    public void createModelWithFilePath(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
        sendCreateRequest(new CustomModelFilePathRequest(modelName, modelFilePath).toString());
    }

    /**
     * Creates a model via {@code POST /api/create} from inline model file contents.
     *
     * @param modelName name of the model to create
     * @param modelFileContents the model file contents, sent as part of the request
     * @throws OllamaBaseException if the status is not 200 or the body contains {@code "error"}
     */
    public void createModelWithModelFileContents(String modelName, String modelFileContents) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
        sendCreateRequest(new CustomModelFileContentsRequest(modelName, modelFileContents).toString());
    }

    /**
     * Deletes a model via {@code DELETE /api/delete}.
     *
     * @param modelName the model to delete
     * @param ignoreIfNotPresent when {@code true}, a 404 response whose body mentions "model" and
     *     "not found" is treated as success; when {@code false} it is reported as an error
     * @throws OllamaBaseException if the status is not 200 (and not an ignored missing-model 404)
     */
    public void deleteModel(String modelName, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
        HttpRequest request = getRequestBuilderDefault(new URI(this.host + "/api/delete"))
                .method("DELETE", HttpRequest.BodyPublishers.ofString(new ModelRequest(modelName).toString(), StandardCharsets.UTF_8))
                .header("Accept", "application/json")
                .build();
        HttpResponse<String> response = sendForString(request);
        int statusCode = response.statusCode();
        String body = response.body();
        if (statusCode == 200) {
            return;
        }
        boolean modelMissing = statusCode == 404 && body.contains("model") && body.contains("not found");
        if (modelMissing && ignoreIfNotPresent) {
            return;
        }
        throw new OllamaBaseException(statusCode + " - " + body);
    }

    /**
     * Generates an embedding for {@code prompt} using {@code model}.
     *
     * @see #generateEmbeddings(OllamaEmbeddingsRequestModel)
     */
    public List<Double> generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException {
        return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
    }

    /**
     * Generates an embedding via {@code POST /api/embeddings}.
     *
     * @param modelRequest the embedding request to send
     * @return the embedding vector from the parsed response
     * @throws OllamaBaseException if the server answers with a status other than 200
     */
    public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException {
        HttpRequest request = getRequestBuilderDefault(URI.create(this.host + "/api/embeddings"))
                .header("Accept", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(modelRequest.toString()))
                .build();
        HttpResponse<String> response = sendForString(request);
        String body = response.body();
        if (response.statusCode() != 200) {
            throw new OllamaBaseException(response.statusCode() + " - " + body);
        }
        return Utils.getObjectMapper().readValue(body, OllamaEmbeddingResponseModel.class).getEmbedding();
    }

    /**
     * Runs a synchronous generate request, optionally streaming partial output.
     *
     * @param model the model name
     * @param prompt the prompt text
     * @param raw value passed to {@link OllamaGenerateRequestModel#setRaw}
     * @param options generation options; must not be {@code null}
     * @param streamHandler when non-null, streaming is enabled and partial output is delivered to it
     * @return the generation result
     */
    public OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
        OllamaGenerateRequestModel request = new OllamaGenerateRequestModel(model, prompt);
        request.setRaw(raw);
        request.setOptions(options.getOptionsMap());
        return generateSyncForOllamaRequestModel(request, streamHandler);
    }

    /** Same as {@link #generate(String, String, boolean, Options, OllamaStreamHandler)} without streaming. */
    public OllamaResult generate(String model, String prompt, boolean raw, Options options) throws OllamaBaseException, IOException, InterruptedException {
        return generate(model, prompt, raw, options, null);
    }

    /**
     * Runs a raw-mode generate request and invokes the registered tools the model asks for.
     *
     * <p>The response text, with any {@code [TOOL_CALLS]} marker removed, is parsed as a JSON list
     * of {@link ToolFunctionCallSpec}; each spec is dispatched to the tool registered under its name.
     *
     * @return the model result plus a map from each call spec to the value its tool returned
     * @throws ToolInvocationException if a requested tool is not registered or throws
     */
    public OllamaToolsResult generateWithTools(String model, String prompt, Options options) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
        OllamaToolsResult toolsResult = new OllamaToolsResult();
        Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
        OllamaResult modelResult = generate(model, prompt, true, options, null);
        toolsResult.setModelResult(modelResult);
        String toolCallsJson = modelResult.getResponse().replace("[TOOL_CALLS]", "");
        List<ToolFunctionCallSpec> callSpecs = Utils.getObjectMapper().readValue(
                toolCallsJson,
                Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
        for (ToolFunctionCallSpec callSpec : callSpecs) {
            toolResults.put(callSpec, invokeTool(callSpec));
        }
        toolsResult.setToolResults(toolResults);
        return toolsResult;
    }

    /**
     * Starts an asynchronous generate request against {@code /api/generate}.
     *
     * @return the already-started streamer
     */
    public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw) {
        OllamaGenerateRequestModel request = new OllamaGenerateRequestModel(model, prompt);
        request.setRaw(raw);
        OllamaAsyncResultStreamer streamer = new OllamaAsyncResultStreamer(
                getRequestBuilderDefault(URI.create(this.host + "/api/generate")), request, this.requestTimeoutSeconds);
        streamer.start();
        return streamer;
    }

    /**
     * Runs a generate request with images read from local files and Base64-encoded.
     *
     * @param streamHandler when non-null, streaming is enabled and partial output is delivered to it
     */
    public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
        List<String> images = new ArrayList<>(imageFiles.size());
        for (File imageFile : imageFiles) {
            images.add(encodeFileToBase64(imageFile));
        }
        OllamaGenerateRequestModel request = new OllamaGenerateRequestModel(model, prompt, images);
        request.setOptions(options.getOptionsMap());
        return generateSyncForOllamaRequestModel(request, streamHandler);
    }

    /** Same as {@link #generateWithImageFiles(String, String, List, Options, OllamaStreamHandler)} without streaming. */
    public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options) throws OllamaBaseException, IOException, InterruptedException {
        return generateWithImageFiles(model, prompt, imageFiles, options, null);
    }

    /**
     * Runs a generate request with images downloaded from URLs and Base64-encoded.
     *
     * @param streamHandler when non-null, streaming is enabled and partial output is delivered to it
     */
    public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
        List<String> images = new ArrayList<>(imageURLs.size());
        for (String imageURL : imageURLs) {
            images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
        }
        OllamaGenerateRequestModel request = new OllamaGenerateRequestModel(model, prompt, images);
        request.setOptions(options.getOptionsMap());
        return generateSyncForOllamaRequestModel(request, streamHandler);
    }

    /** Same as {@link #generateWithImageURLs(String, String, List, Options, OllamaStreamHandler)} without streaming. */
    public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
        return generateWithImageURLs(model, prompt, imageURLs, options, null);
    }

    /** Sends a chat request for {@code model} built from {@code messages}, without streaming. */
    public OllamaChatResult chat(String model, List<OllamaChatMessage> messages) throws OllamaBaseException, IOException, InterruptedException {
        return chat(OllamaChatRequestBuilder.getInstance(model).withMessages(messages).build());
    }

    /** Sends the given chat request without streaming. */
    public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException {
        return chat(request, (OllamaStreamHandler) null);
    }

    /**
     * Sends a chat request.
     *
     * @param request the chat request; its stream flag is set to {@code true} when a handler is given
     * @param streamHandler when non-null, partial output is delivered to it
     * @return the chat result, carrying the request's message list
     */
    public OllamaChatResult chat(OllamaChatRequestModel request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
        OllamaChatEndpointCaller caller = new OllamaChatEndpointCaller(this.host, this.basicAuth, this.requestTimeoutSeconds, this.verbose);
        OllamaResult result;
        if (streamHandler != null) {
            request.setStream(true);
            result = caller.call(request, streamHandler);
        } else {
            result = caller.callSync(request);
        }
        return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
    }

    /** Registers a tool so {@link #generateWithTools} can dispatch calls to it by function name. */
    public void registerTool(Tools.ToolSpecification toolSpecification) {
        this.toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
    }

    /** Reads the whole file and returns its bytes Base64-encoded. */
    private static String encodeFileToBase64(File file) throws IOException {
        return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
    }

    /** Returns {@code bytes} Base64-encoded. */
    private static String encodeByteArrayToBase64(byte[] bytes) {
        return Base64.getEncoder().encodeToString(bytes);
    }

    /** Sends {@code request} with a new {@link HttpClient}, buffering the body as a string. */
    private static HttpResponse<String> sendForString(HttpRequest request) throws IOException, InterruptedException {
        return HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString());
    }

    /**
     * Posts {@code requestBody} to {@code /api/create}, shared by both create-model variants.
     *
     * @throws OllamaBaseException if the status is not 200 or the body contains {@code "error"}
     */
    private void sendCreateRequest(String requestBody) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
        HttpRequest request = getRequestBuilderDefault(new URI(this.host + "/api/create"))
                .header("Accept", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(requestBody, StandardCharsets.UTF_8))
                .build();
        HttpResponse<String> response = sendForString(request);
        String body = response.body();
        if (response.statusCode() != 200) {
            throw new OllamaBaseException(response.statusCode() + " - " + body);
        }
        // NOTE(review): substring heuristic; a successful body that mentions "error" is also rejected.
        if (body.contains("error")) {
            throw new OllamaBaseException(body);
        }
        if (this.verbose) {
            logger.info(body);
        }
    }

    /** Runs a generate request, enabling streaming when {@code streamHandler} is non-null. */
    private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
        OllamaGenerateEndpointCaller caller = new OllamaGenerateEndpointCaller(this.host, this.basicAuth, this.requestTimeoutSeconds, this.verbose);
        if (streamHandler != null) {
            request.setStream(true);
            return caller.call(request, streamHandler);
        }
        return caller.callSync(request);
    }

    /**
     * Returns a request builder for {@code uri} preset with a JSON {@code Content-Type}, the
     * configured timeout and, if credentials are set, a Basic {@code Authorization} header.
     * Callers must not add another {@code Content-Type}: {@code header} appends, it does not replace.
     */
    private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
        HttpRequest.Builder builder = HttpRequest.newBuilder(uri)
                .header("Content-Type", "application/json")
                .timeout(Duration.ofSeconds(this.requestTimeoutSeconds));
        if (isBasicAuthCredentialsSet()) {
            builder.header("Authorization", getBasicAuthHeaderValue());
        }
        return builder;
    }

    /** Builds the {@code Basic} header value from the stored credentials, encoded as UTF-8. */
    private String getBasicAuthHeaderValue() {
        String credentials = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
        return "Basic " + Base64.getEncoder().encodeToString(credentials.getBytes(StandardCharsets.UTF_8));
    }

    private boolean isBasicAuthCredentialsSet() {
        return this.basicAuth != null;
    }

    /**
     * Invokes the registered tool named by {@code toolFunctionCallSpec} with its arguments.
     *
     * @throws ToolInvocationException if no tool is registered under that name or the tool throws
     */
    private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
        try {
            String name = toolFunctionCallSpec.getName();
            Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
            ToolFunction function = this.toolRegistry.getFunction(name);
            if (this.verbose) {
                logger.debug("Invoking function {} with arguments {}", name, arguments);
            }
            if (function == null) {
                throw new ToolNotFoundException("No such tool: " + name);
            }
            return function.apply(arguments);
        } catch (Exception e) {
            throw new ToolInvocationException("Failed to invoke tool: " + toolFunctionCallSpec.getName(), e);
        }
    }

    /** Sets the per-request timeout, in seconds, used for requests issued after this call. */
    public void setRequestTimeoutSeconds(long requestTimeoutSeconds) {
        this.requestTimeoutSeconds = requestTimeoutSeconds;
    }

    /** Enables or disables verbose logging (on by default). */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }
}
