package me.sshcrack.mc_talking.manager;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.minecolonies.api.entity.citizen.AbstractEntityCitizen;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Timer;
import java.util.TimerTask;
import java.util.UUID;
import me.sshcrack.mc_talking.McTalking;
import me.sshcrack.mc_talking.ModAttachmentTypes;
import me.sshcrack.mc_talking.config.McTalkingConfig;
import me.sshcrack.mc_talking.gson.BidiGenerateContentSetup;
import me.sshcrack.mc_talking.gson.BidiGenerateContentToolResponse;
import me.sshcrack.mc_talking.gson.ClientMessages;
import me.sshcrack.mc_talking.gson.RealtimeInput;
import me.sshcrack.mc_talking.manager.tools.AITools;
import me.sshcrack.mc_talking.manager.tools.FunctionAction;
import me.sshcrack.mc_talking.network.AiStatus;
import me.sshcrack.mc_talking.network.AiStatusPayload;
import net.minecraft.network.protocol.common.custom.CustomPacketPayload;
import net.minecraft.server.level.ServerPlayer;
import net.neoforged.neoforge.network.PacketDistributor;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.extensions.ExtensionRequestData;
import org.java_websocket.handshake.ServerHandshake;

/* loaded from: input_file:me/sshcrack/mc_talking/manager/GeminiWsClient.class */
/**
 * WebSocket client that connects one citizen (via its {@link TalkingManager}) to the Gemini
 * bidirectional "BidiGenerateContent" endpoint.
 *
 * <p>Outgoing traffic: a setup message on open, then realtime text and 48 kHz PCM audio input.
 * Input that arrives before the session is ready is queued and replayed once the server
 * acknowledges the setup. Incoming traffic: session resumption handles, tool calls and PCM audio,
 * which is forwarded to the {@link GeminiStream}.
 *
 * <p>Methods of this class are invoked from several threads (the websocket callbacks, the audio
 * batch timer and the callers of {@link #addPromptAudio}/{@link #batchAudio}), so the connection
 * flags are volatile, the pending queues are synchronized lists and the audio batch is guarded by
 * {@code batchLock}.
 */
public class GeminiWsClient extends WebSocketClient {
    /** Milliseconds a partially filled audio batch waits before it is flushed. */
    private static final long BATCH_TIMEOUT = 100;
    /** Number of audio chunks that triggers an immediate flush of the batch. */
    private static final int MAX_BATCH_SIZE = 5;
    /** Minimum time between two automatic reconnect attempts. */
    private static final long RECONNECT_COOLDOWN_MS = 5000L;
    /** Sample rate assumed when the server's mime type carries no usable {@code rate=} parameter. */
    private static final int DEFAULT_OUTPUT_SAMPLE_RATE = 24000;
    private static final String RATE_PARAMETER = "rate=";
    /** Value stored as session token to mark "no resumable session" (blank tokens are ignored in onOpen). */
    private static final String NO_SESSION_TOKEN = "";

    public static volatile boolean quotaExceeded = false;

    /** True only between the server's "setupComplete" message and the next open/close. */
    volatile boolean setupComplete;
    /** True while a connect()/reconnect() has been started and neither onOpen nor onClose has fired. */
    volatile boolean isInitiatingConnection;
    /** False until the first connect(); afterwards reconnect() has to be used instead. */
    volatile boolean shouldReconnect;
    GeminiStream stream;
    ServerPlayer initialPlayer;
    TalkingManager manager;
    boolean sentGeneratingStatus;
    volatile long lastReconnectTime;

    /** Audio queued while the session is not ready; replayed after "setupComplete". */
    private final List<short[]> pending_prompt;
    /** Text queued while the session is not ready; replayed after "setupComplete". */
    private final List<String> pendingSystemText;
    /** Chunks collected by {@link #batchAudio}; guarded by {@code batchLock}. */
    private final List<short[]> audioBatch;
    private final Object batchLock;
    /** Guards the decision whether a new connection attempt may be started. */
    private final Object connectionLock;
    /** Guarded by {@code batchLock}. */
    private Timer batchTimer;
    /** Guarded by {@code batchLock}. */
    private TimerTask currentBatchTask;

    private static String getUrl() {
        return "wss://generativelanguage.googleapis.com/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent?key=" + McTalkingConfig.geminiApiKey;
    }

    /**
     * Creates the client for the citizen owned by {@code talkingManager}. The socket is not
     * connected here; the first queued input triggers the connection.
     *
     * @param talkingManager manager of the citizen this client speaks for
     * @param serverPlayer   player passed to the roleplay prompt generator when the session is set up
     */
    public GeminiWsClient(TalkingManager talkingManager, ServerPlayer serverPlayer) {
        super(URI.create(getUrl()));
        this.setupComplete = false;
        this.isInitiatingConnection = false;
        this.shouldReconnect = false;
        this.pending_prompt = Collections.synchronizedList(new ArrayList<>());
        this.pendingSystemText = Collections.synchronizedList(new ArrayList<>());
        this.audioBatch = new ArrayList<>();
        this.batchLock = new Object();
        this.connectionLock = new Object();
        this.sentGeneratingStatus = false;
        this.lastReconnectTime = 0L;
        this.manager = talkingManager;
        this.stream = new GeminiStream(talkingManager.channel);
        boolean isFemale = talkingManager.entity.getCitizenData().isFemale();
        if (talkingManager.entity.getCitizenData().isChild() && !isFemale) {
            this.stream.setPitch(0.8f);
        }
        this.initialPlayer = serverPlayer;
        broadcastStatus(AiStatus.LISTENING);
    }

    /** Sends the given AI status for this client's citizen to all players. */
    private void broadcastStatus(AiStatus status) {
        PacketDistributor.sendToAllPlayers(new AiStatusPayload(this.manager.entity.getUUID(), status), new CustomPacketPayload[0]);
    }

    /**
     * Sends the session setup (model, voice, language, system prompt, tools and, if stored on the
     * citizen, a session resumption token) as soon as the socket is open.
     */
    @Override // org.java_websocket.client.WebSocketClient
    public void onOpen(ServerHandshake serverHandshake) {
        this.isInitiatingConnection = false;
        // A fresh socket has to be set up again before any input may be sent on it.
        this.setupComplete = false;

        AbstractEntityCitizen citizen = this.manager.entity;
        boolean isFemale = citizen.getCitizenData().isFemale();
        UUID citizenId = citizen.getUUID();

        BidiGenerateContentSetup setup = new BidiGenerateContentSetup("models/" + McTalkingConfig.currentAIModel.getName());
        setup.generationConfig.responseModalities = List.of("AUDIO");
        setup.generationConfig.speechConfig = new BidiGenerateContentSetup.GenerationConfig.SpeechConfig();
        setup.generationConfig.speechConfig.language_code = McTalkingConfig.language;

        setup.sessionResumption = new BidiGenerateContentSetup.SessionResumptionConfig();
        if (citizen.hasData(ModAttachmentTypes.SESSION_TOKEN)) {
            String token = (String) citizen.getData(ModAttachmentTypes.SESSION_TOKEN);
            if (token != null && !token.isBlank()) {
                setup.sessionResumption = new BidiGenerateContentSetup.SessionResumptionConfig(token);
            }
        }

        setup.generationConfig.speechConfig.voice_config = new BidiGenerateContentSetup.GenerationConfig.SpeechConfig.VoiceConfig();
        setup.generationConfig.speechConfig.voice_config.prebuiltVoiceConfig = new BidiGenerateContentSetup.GenerationConfig.SpeechConfig.PrebuiltVoiceConfig();
        setup.generationConfig.speechConfig.voice_config.prebuiltVoiceConfig.voice_name = McTalkingConfig.currentAIModel.getRandomVoice(citizenId, isFemale);
        setup.realtimeInputConfig = new BidiGenerateContentSetup.RealtimeInputConfig();

        BidiGenerateContentSetup.SystemInstruction systemInstruction = new BidiGenerateContentSetup.SystemInstruction();
        systemInstruction.parts.add(new BidiGenerateContentSetup.SystemInstruction.Part(CitizenContextUtils.generateCitizenRoleplayPrompt(citizen.getCitizenData(), this.initialPlayer)));
        setup.systemInstruction = systemInstruction;
        setup.tools.addAll(AITools.getAllTools());
        send(ClientMessages.setup(setup));
    }

    /**
     * Dispatches one JSON message from the server. Everything except "setupComplete" is ignored
     * until the setup has been acknowledged.
     *
     * @param str raw JSON text of the message
     */
    @Override // org.java_websocket.client.WebSocketClient
    public void onMessage(String str) {
        JsonElement parsed = JsonParser.parseString(str);
        if (!parsed.isJsonObject()) {
            return;
        }
        JsonObject message = parsed.getAsJsonObject();
        if (message.has("setupComplete")) {
            handleSetupComplete();
            return;
        }
        if (!this.setupComplete) {
            return;
        }
        if (message.has("usageMetadata")) {
            McTalking.LOGGER.info("Gemini usage metadata: {}", message.get("usageMetadata").toString());
        }
        if (message.has("sessionResumptionUpdate")) {
            handleSessionResumptionUpdate(message.get("sessionResumptionUpdate"));
            return;
        }
        if (message.has("toolCall")) {
            McTalking.LOGGER.info("Tool call: {}", str);
            handleToolCall(message.get("toolCall"));
            return;
        }
        if (message.has("serverContent") && message.get("serverContent").isJsonObject()) {
            handleServerContent(message.getAsJsonObject("serverContent"), str);
        }
    }

    /** Marks the session as ready and replays everything that was queued while it was not. */
    private void handleSetupComplete() {
        McTalking.LOGGER.info("Gemini setup complete");
        this.setupComplete = true;
        // Replay from snapshots: addSystemText/addPromptAudio may re-queue into the same lists
        // if the connection drops again, which must not happen while those lists are iterated.
        for (String text : drain(this.pendingSystemText)) {
            addSystemText(text);
        }
        for (short[] audio : drain(this.pending_prompt)) {
            addPromptAudio(audio);
        }
    }

    /** Atomically removes and returns the current content of a synchronized list. */
    private static <T> List<T> drain(List<T> queue) {
        synchronized (queue) {
            List<T> snapshot = new ArrayList<>(queue);
            queue.clear();
            return snapshot;
        }
    }

    /** Stores a new resumption handle on the citizen when the server marks the session resumable. */
    private void handleSessionResumptionUpdate(JsonElement element) {
        if (!element.isJsonObject()) {
            return;
        }
        JsonObject update = element.getAsJsonObject();
        JsonElement handle = update.get("newHandle");
        if (handle != null && handle.isJsonPrimitive() && isTrue(update, "resumable")) {
            this.manager.entity.setData(ModAttachmentTypes.SESSION_TOKEN, handle.getAsString());
        }
    }

    /** Executes every well-formed function call of a "toolCall" message and answers each one. */
    private void handleToolCall(JsonElement element) {
        if (!element.isJsonObject()) {
            return;
        }
        JsonElement calls = element.getAsJsonObject().get("functionCalls");
        if (calls == null || !calls.isJsonArray()) {
            return;
        }
        for (JsonElement call : calls.getAsJsonArray()) {
            if (call.isJsonObject()) {
                handleFunctionCall(call.getAsJsonObject());
            }
        }
    }

    /** Runs a single function call and sends its result (or an error object) back to the server. */
    private void handleFunctionCall(JsonObject call) {
        JsonElement nameElement = call.get("name");
        if (nameElement == null || !nameElement.isJsonPrimitive()) {
            return;
        }
        String name = nameElement.getAsString();
        FunctionAction action = AITools.registeredFunctions.get(name);
        if (action == null) {
            McTalking.LOGGER.warn("Unknown function call: {}", name);
            return;
        }
        JsonElement idElement = call.get("id");
        if (idElement == null || !idElement.isJsonPrimitive()) {
            // Without an id the response could not be matched to the call.
            McTalking.LOGGER.warn("Ignoring function call without id: {}", name);
            return;
        }
        JsonElement argsElement = call.get("args");
        JsonObject args = argsElement != null && argsElement.isJsonObject() ? argsElement.getAsJsonObject() : null;

        JsonObject result;
        try {
            result = action.execute(this.manager.entity, this.manager.entity.getCitizenColonyHandler().getColony(), args);
        } catch (RuntimeException e) {
            // One failing tool must neither abort the remaining calls nor leave the model
            // waiting for a response that never arrives.
            McTalking.LOGGER.error("Function call '" + name + "' failed", e);
            result = new JsonObject();
            result.addProperty("error", String.valueOf(e.getMessage()));
        }
        BidiGenerateContentToolResponse response = new BidiGenerateContentToolResponse();
        response.functionResponses.add(new BidiGenerateContentToolResponse.FunctionResponse(idElement.getAsString(), name, result));
        send(ClientMessages.response(response));
    }

    /**
     * Handles a "serverContent" object: turn/generation state changes or model audio.
     *
     * @param content the "serverContent" object
     * @param raw     the complete raw message, used for logging unknown content
     */
    private void handleServerContent(JsonObject content, String raw) {
        if (isTrue(content, "turnComplete")) {
            McTalking.LOGGER.info("Gemini turn complete");
            broadcastStatus(AiStatus.LISTENING);
            return;
        }
        if (isTrue(content, "interrupted")) {
            McTalking.LOGGER.info("Gemini generation interrupted");
            this.stream.stop();
            return;
        }
        if (isTrue(content, "generationComplete")) {
            McTalking.LOGGER.info("Gemini generation complete");
            broadcastStatus(AiStatus.TALKING);
            this.stream.flushAudio();
            return;
        }
        if (!content.has("modelTurn")) {
            McTalking.LOGGER.warn("Unknown message: {}", raw);
            return;
        }
        JsonElement modelTurn = content.get("modelTurn");
        if (!modelTurn.isJsonObject()) {
            return;
        }
        JsonElement parts = modelTurn.getAsJsonObject().get("parts");
        if (parts == null || !parts.isJsonArray()) {
            return;
        }
        for (JsonElement part : parts.getAsJsonArray()) {
            handleModelTurnPart(part);
        }
    }

    /** Forwards the inline PCM audio of one model-turn part to the output stream. */
    private void handleModelTurnPart(JsonElement part) {
        if (!part.isJsonObject()) {
            return;
        }
        JsonElement inline = part.getAsJsonObject().get("inlineData");
        if (inline == null || !inline.isJsonObject()) {
            return;
        }
        JsonObject inlineData = inline.getAsJsonObject();
        JsonElement data = inlineData.get("data");
        if (data == null || !data.isJsonPrimitive()) {
            return;
        }
        JsonElement mime = inlineData.get("mimeType");
        String mimeType = mime != null && mime.isJsonPrimitive() ? mime.getAsString() : "";
        if (!mimeType.contains("audio/pcm")) {
            McTalking.LOGGER.warn("Invalid mime type: {}", mimeType);
            return;
        }
        byte[] pcm;
        try {
            pcm = Base64.getDecoder().decode(data.getAsString());
        } catch (IllegalArgumentException e) {
            McTalking.LOGGER.warn("Dropping audio part with malformed base64 data: {}", e.getMessage());
            return;
        }
        this.stream.addGeminiPcmWithPitch(pcm, parseSampleRate(mimeType));
    }

    /**
     * Extracts the number following {@code rate=} from a mime type such as
     * {@code audio/pcm;rate=24000}. Trailing parameters after the digits are tolerated.
     *
     * @return the parsed rate, or {@link #DEFAULT_OUTPUT_SAMPLE_RATE} if none can be parsed
     */
    private static int parseSampleRate(String mimeType) {
        int marker = mimeType.indexOf(RATE_PARAMETER);
        if (marker < 0) {
            return DEFAULT_OUTPUT_SAMPLE_RATE;
        }
        int start = marker + RATE_PARAMETER.length();
        int end = start;
        while (end < mimeType.length() && Character.isDigit(mimeType.charAt(end))) {
            end++;
        }
        if (end == start) {
            return DEFAULT_OUTPUT_SAMPLE_RATE;
        }
        try {
            return Integer.parseInt(mimeType.substring(start, end));
        } catch (NumberFormatException e) {
            // Only reachable for values that overflow an int.
            return DEFAULT_OUTPUT_SAMPLE_RATE;
        }
    }

    /** Returns true if {@code object} has a primitive member {@code key} that reads as boolean true. */
    private static boolean isTrue(JsonObject object, String key) {
        JsonElement value = object.get(key);
        return value != null && value.isJsonPrimitive() && value.getAsBoolean();
    }

    /** Decodes a binary frame as UTF-8 text and handles it like a text message. */
    @Override // org.java_websocket.client.WebSocketClient
    public void onMessage(ByteBuffer byteBuffer) {
        // decode() honours position/limit and also works for buffers without an accessible array.
        onMessage(StandardCharsets.UTF_8.decode(byteBuffer).toString());
    }

    /**
     * Resets the connection state and reacts to the close reason: flags an exceeded quota,
     * drops an unknown session token and reconnects, or reports an error for abnormal codes.
     *
     * @param i   close code
     * @param str close reason, may be null
     * @param z   whether the close was initiated by the remote side
     */
    @Override // org.java_websocket.client.WebSocketClient
    public void onClose(int i, String str, boolean z) {
        this.isInitiatingConnection = false;
        // The next connection needs its own setup handshake before input may be sent.
        this.setupComplete = false;
        String reason = str == null ? "" : str;

        if (reason.contains("You exceeded your current quota, please")) {
            quotaExceeded = true;
            McTalking.LOGGER.warn("Quota exceeded for Gemini API, please check your API key and usage limits.");
            broadcastStatus(AiStatus.QUOTA_EXCEEDED);
        }
        if (reason.contains("BidiGenerateContent session not found")) {
            // The stored handle is no longer valid; start over with a fresh session.
            this.manager.entity.setData(ModAttachmentTypes.SESSION_TOKEN, NO_SESSION_TOKEN);
            // Reconnecting is done on a separate thread, not on the thread delivering this callback.
            Thread reconnectThread = new Thread(this::reconnectAfterSessionLoss, "GeminiWsReconnect");
            reconnectThread.setDaemon(true);
            reconnectThread.start();
            return;
        }
        if (i != 1000 && i != 1001) {
            broadcastStatus(AiStatus.ERROR);
        }
        McTalking.LOGGER.info("GeminiWsClient closed: {} and code {}", reason, Integer.valueOf(i));
    }

    /** Reconnects immediately (no cooldown) unless a connection is already open or being opened. */
    private void reconnectAfterSessionLoss() {
        if (!claimConnectionAttempt(false)) {
            return;
        }
        this.shouldReconnect = true;
        try {
            reconnect();
        } catch (RuntimeException e) {
            this.isInitiatingConnection = false;
            McTalking.LOGGER.error("Failed to reconnect to Gemini", e);
        }
    }

    /** Reports the error state to all players and logs the exception. */
    @Override // org.java_websocket.client.WebSocketClient
    public void onError(Exception exc) {
        broadcastStatus(AiStatus.ERROR);
        McTalking.LOGGER.error("GeminiWsClient error", exc);
    }

    /** True when the socket is open and the server has acknowledged the setup of this connection. */
    private boolean isReadyToSend() {
        return this.setupComplete && isOpen();
    }

    /**
     * Decides atomically whether the caller may start a connection attempt and, if so, marks the
     * attempt as in progress.
     *
     * @param enforceCooldown whether reconnects are limited to one per {@link #RECONNECT_COOLDOWN_MS}
     * @return true if the caller now owns the attempt and has to call connect()/reconnect()
     */
    private boolean claimConnectionAttempt(boolean enforceCooldown) {
        synchronized (this.connectionLock) {
            if (isOpen() || this.isInitiatingConnection) {
                return false;
            }
            if (enforceCooldown && this.shouldReconnect) {
                long now = System.currentTimeMillis();
                if (now - this.lastReconnectTime < RECONNECT_COOLDOWN_MS) {
                    return false;
                }
                this.lastReconnectTime = now;
            }
            // Set before connecting so that onOpen/onClose, which clear the flag, cannot be
            // overtaken by a late write of "true".
            this.isInitiatingConnection = true;
            return true;
        }
    }

    /** Starts the first connection or a rate-limited reconnect if none is open or in progress. */
    private void ensureConnection() {
        if (!claimConnectionAttempt(true)) {
            return;
        }
        boolean firstAttempt = !this.shouldReconnect;
        this.shouldReconnect = true;
        try {
            if (firstAttempt) {
                connect();
            } else {
                McTalking.LOGGER.warn("Connection lost, attempting to reconnect...");
                reconnect();
            }
        } catch (RuntimeException e) {
            this.isInitiatingConnection = false;
            throw e;
        }
    }

    /**
     * Sends text as realtime input, or queues it and makes sure a connection is being
     * established if the session is not ready.
     *
     * @param str text to send
     */
    public void addSystemText(String str) {
        if (isReadyToSend()) {
            RealtimeInput input = new RealtimeInput();
            input.text = str;
            send(ClientMessages.input(input));
            return;
        }
        this.pendingSystemText.add(str);
        ensureConnection();
    }

    /**
     * Sends PCM samples (declared to the server as {@code audio/pcm;rate=48000}) as realtime
     * input, or queues them and makes sure a connection is being established if the session is
     * not ready.
     *
     * @param sArr PCM samples
     */
    public void addPromptAudio(short[] sArr) {
        if (this.sentGeneratingStatus) {
            broadcastStatus(AiStatus.LISTENING);
        }
        if (isReadyToSend()) {
            RealtimeInput input = new RealtimeInput();
            input.audio = new RealtimeInput.Blob("audio/pcm;rate=48000", sArr);
            send(ClientMessages.input(input));
            return;
        }
        this.pending_prompt.add(sArr);
        ensureConnection();
    }

    /**
     * Stops the batch timer, closes the audio stream and closes the socket.
     *
     * <p>NOTE(review): Java-WebSocket's reconnect() appears to go through close() internally; if
     * so, the audio stream is also closed on every reconnect. Confirm that GeminiStream
     * tolerates being used after close().
     */
    @Override // org.java_websocket.client.WebSocketClient, org.java_websocket.WebSocket
    public void close() {
        synchronized (this.batchLock) {
            if (this.batchTimer != null) {
                this.batchTimer.cancel();
                this.batchTimer = null;
            }
            if (this.currentBatchTask != null) {
                this.currentBatchTask.cancel();
                this.currentBatchTask = null;
            }
        }
        this.stream.close();
        super.close();
    }

    /**
     * Collects audio chunks and forwards them merged: immediately once {@link #MAX_BATCH_SIZE}
     * chunks are collected, otherwise {@link #BATCH_TIMEOUT} ms after the first chunk of a batch.
     *
     * @param sArr PCM samples to add to the current batch
     */
    public void batchAudio(short[] sArr) {
        synchronized (this.batchLock) {
            this.audioBatch.add(sArr);
            int size = this.audioBatch.size();
            if (size >= MAX_BATCH_SIZE) {
                sendCurrentBatch();
            } else if (size == 1) {
                scheduleFlushTimer();
            }
        }
    }

    /** Replaces the pending flush task with a new one that fires after {@link #BATCH_TIMEOUT} ms. */
    private void scheduleFlushTimer() {
        // Locking keeps close() from cancelling the timer between the null check and schedule().
        synchronized (this.batchLock) {
            if (this.currentBatchTask != null) {
                this.currentBatchTask.cancel();
            }
            this.currentBatchTask = new TimerTask() { // from class: me.sshcrack.mc_talking.manager.GeminiWsClient.1
                @Override // java.util.TimerTask, java.lang.Runnable
                public void run() {
                    try {
                        GeminiWsClient.this.sendCurrentBatch();
                    } catch (RuntimeException e) {
                        // An exception escaping a TimerTask terminates the timer thread, after
                        // which every further schedule() on that timer fails.
                        McTalking.LOGGER.error("Failed to flush audio batch", e);
                    }
                }
            };
            if (this.batchTimer == null) {
                this.batchTimer = new Timer("AudioBatchTimer", true);
            }
            this.batchTimer.schedule(this.currentBatchTask, BATCH_TIMEOUT);
        }
    }

    /** Concatenates all collected chunks into one array and hands it to {@link #addPromptAudio}. */
    private void sendCurrentBatch() {
        // The hand-off stays inside the lock so that concurrent flushes cannot reorder audio.
        synchronized (this.batchLock) {
            if (this.audioBatch.isEmpty()) {
                return;
            }
            int totalSamples = 0;
            for (short[] chunk : this.audioBatch) {
                totalSamples += chunk.length;
            }
            short[] merged = new short[totalSamples];
            int offset = 0;
            for (short[] chunk : this.audioBatch) {
                System.arraycopy(chunk, 0, merged, offset, chunk.length);
                offset += chunk.length;
            }
            this.audioBatch.clear();
            addPromptAudio(merged);
        }
    }
}
