/*
 * Decompiled with CFR 0.152.
 */
package me.sshcrack.mc_talking.manager;

import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.minecolonies.api.colony.IColony;
import com.minecolonies.api.entity.citizen.AbstractEntityCitizen;
import de.maxhenkel.voicechat.api.audiochannel.AudioChannel;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Timer;
import java.util.TimerTask;
import java.util.UUID;
import me.sshcrack.mc_talking.ConversationManager;
import me.sshcrack.mc_talking.McTalking;
import me.sshcrack.mc_talking.ModAttachmentTypes;
import me.sshcrack.mc_talking.config.AvailableAI;
import me.sshcrack.mc_talking.config.McTalkingConfig;
import me.sshcrack.mc_talking.config.ModalityModes;
import me.sshcrack.mc_talking.gson.BidiGenerateContentSetup;
import me.sshcrack.mc_talking.gson.BidiGenerateContentToolResponse;
import me.sshcrack.mc_talking.gson.ClientMessages;
import me.sshcrack.mc_talking.gson.RealtimeInput;
import me.sshcrack.mc_talking.manager.CitizenContextUtils;
import me.sshcrack.mc_talking.manager.GeminiStream;
import me.sshcrack.mc_talking.manager.TalkingManager;
import me.sshcrack.mc_talking.manager.tools.AITools;
import me.sshcrack.mc_talking.manager.tools.FunctionAction;
import me.sshcrack.mc_talking.network.AiStatus;
import me.sshcrack.mc_talking.network.AiStatusPayload;
import net.minecraft.network.chat.Component;
import net.minecraft.network.protocol.common.custom.CustomPacketPayload;
import net.minecraft.server.level.ServerPlayer;
import net.neoforged.neoforge.network.PacketDistributor;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ServerHandshake;

public class GeminiWsClient
extends WebSocketClient {
    boolean setupComplete;
    boolean isInitiatingConnection = false;
    boolean shouldReconnect = false;
    public static boolean quotaExceeded = false;
    GeminiStream stream;
    ServerPlayer initialPlayer;
    TalkingManager manager;
    private final List<short[]> pending_prompt = new ArrayList<short[]>();
    private final List<String> pendingSystemText = new ArrayList<String>();
    private final List<short[]> audioBatch = Collections.synchronizedList(new ArrayList());
    private static final long BATCH_TIMEOUT = 100L;
    private static final int MAX_BATCH_SIZE = 5;
    private volatile Timer batchTimer;
    private volatile TimerTask currentBatchTask;
    private final Object batchLock = new Object();
    private String currMsg = "";
    boolean sentGeneratingStatus = false;
    long lastReconnectTime = 0L;

    private static String getUrl() {
        return "wss://generativelanguage.googleapis.com/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent?key=" + (String)McTalkingConfig.CONFIG.geminiApiKey.get();
    }

    public GeminiWsClient(TalkingManager manager, ServerPlayer player) {
        super(URI.create(GeminiWsClient.getUrl()));
        this.manager = manager;
        this.stream = new GeminiStream((AudioChannel)manager.channel);
        boolean isFemale = manager.entity.getCitizenData().isFemale();
        boolean isChild = manager.entity.getCitizenData().isChild();
        if (isChild && !isFemale) {
            this.stream.setPitch(0.8f);
        }
        this.initialPlayer = player;
        PacketDistributor.sendToAllPlayers((CustomPacketPayload)new AiStatusPayload(manager.entity.getUUID(), AiStatus.LISTENING), (CustomPacketPayload[])new CustomPacketPayload[0]);
    }

    public void onOpen(ServerHandshake handshakeData) {
        this.isInitiatingConnection = false;
        BidiGenerateContentSetup setup = new BidiGenerateContentSetup("models/" + ((AvailableAI)((Object)McTalkingConfig.CONFIG.currentAiModel.get())).getName());
        ModalityModes modality = (ModalityModes)((Object)McTalkingConfig.CONFIG.modality.get());
        setup.generationConfig.responseModalities = modality.getModes();
        if (McTalkingConfig.CONFIG.currentAiModel.get() == AvailableAI.Flash2_5) {
            setup.generationConfig.responseModalities = List.of("AUDIO");
        }
        if (modality != ModalityModes.TEXT) {
            String sessionToken;
            setup.generationConfig.speechConfig = new BidiGenerateContentSetup.GenerationConfig.SpeechConfig();
            setup.generationConfig.speechConfig.language_code = (String)McTalkingConfig.CONFIG.language.get();
            AbstractEntityCitizen entity = this.manager.entity;
            boolean female = entity.getCitizenData().isFemale();
            UUID uuid = entity.getUUID();
            setup.sessionResumption = new BidiGenerateContentSetup.SessionResumptionConfig();
            if (entity.hasData(ModAttachmentTypes.SESSION_TOKEN) && !(sessionToken = (String)entity.getData(ModAttachmentTypes.SESSION_TOKEN)).isBlank()) {
                setup.sessionResumption = new BidiGenerateContentSetup.SessionResumptionConfig(sessionToken);
            }
            setup.generationConfig.speechConfig.voice_config = new BidiGenerateContentSetup.GenerationConfig.SpeechConfig.VoiceConfig();
            setup.generationConfig.speechConfig.voice_config.prebuiltVoiceConfig = new BidiGenerateContentSetup.GenerationConfig.SpeechConfig.PrebuiltVoiceConfig();
            setup.generationConfig.speechConfig.voice_config.prebuiltVoiceConfig.voice_name = ((AvailableAI)((Object)McTalkingConfig.CONFIG.currentAiModel.get())).getRandomVoice(uuid, female);
        }
        setup.realtimeInputConfig = new BidiGenerateContentSetup.RealtimeInputConfig();
        BidiGenerateContentSetup.SystemInstruction sys = new BidiGenerateContentSetup.SystemInstruction();
        String prompt = CitizenContextUtils.generateCitizenRoleplayPrompt(this.manager.entity.getCitizenData(), this.initialPlayer);
        BidiGenerateContentSetup.SystemInstruction.Part p = new BidiGenerateContentSetup.SystemInstruction.Part(prompt);
        sys.parts.add(p);
        setup.systemInstruction = sys;
        setup.tools.addAll(AITools.getAllTools());
        if (McTalkingConfig.CONFIG.currentAiModel.get() == AvailableAI.Flash2_5 && ((Boolean)McTalkingConfig.CONFIG.enableFunctionWorkaround.get()).booleanValue()) {
            setup.tools.add(BidiGenerateContentSetup.Tool.googleSearch());
        }
        this.send(ClientMessages.setup(setup));
    }

    public void onMessage(String message) {
        JsonElement p = JsonParser.parseString((String)message);
        if (!p.isJsonObject()) {
            return;
        }
        JsonObject outer = p.getAsJsonObject();
        if (outer.has("setupComplete")) {
            McTalking.LOGGER.info("Gemini setup complete");
            this.setupComplete = true;
            if (!this.pendingSystemText.isEmpty()) {
                for (String text : this.pendingSystemText) {
                    this.addSystemText(text);
                }
                this.pendingSystemText.clear();
            }
            if (!this.pending_prompt.isEmpty()) {
                for (short[] data : this.pending_prompt) {
                    this.addPromptAudio(data);
                }
                this.pending_prompt.clear();
            }
            return;
        }
        if (!this.setupComplete) {
            return;
        }
        if (outer.has("usageMetadata")) {
            McTalking.LOGGER.info("Gemini usage metadata: {}", (Object)outer.get("usageMetadata").toString());
        }
        if (outer.has("sessionResumptionUpdate")) {
            JsonObject obj = outer.get("sessionResumptionUpdate").getAsJsonObject();
            if (!obj.has("newHandle") || !obj.get("newHandle").isJsonPrimitive()) {
                return;
            }
            if (!obj.has("resumable") || !obj.get("resumable").getAsBoolean()) {
                return;
            }
            String handle = obj.get("newHandle").getAsString();
            this.manager.entity.setData(ModAttachmentTypes.SESSION_TOKEN, (Object)handle);
            return;
        }
        if (outer.has("toolCall")) {
            System.out.println("Tool call: " + message);
            JsonObject obj = outer.getAsJsonObject("toolCall");
            if (!obj.has("functionCalls") || !obj.get("functionCalls").isJsonArray()) {
                return;
            }
            JsonArray functionCalls = obj.getAsJsonArray("functionCalls");
            for (JsonElement fnCall : functionCalls) {
                JsonObject objFnCall;
                if (!fnCall.isJsonObject() || !(objFnCall = fnCall.getAsJsonObject()).has("name") || !objFnCall.get("name").isJsonPrimitive()) continue;
                String name = objFnCall.get("name").getAsString();
                FunctionAction action = AITools.registeredFunctions.get(name);
                if (action == null) {
                    McTalking.LOGGER.warn("Unknown function call: {}", (Object)name);
                    continue;
                }
                JsonObject args = null;
                if (objFnCall.has("args")) {
                    args = objFnCall.getAsJsonObject("args");
                }
                IColony colony = this.manager.entity.getCitizenColonyHandler().getColony();
                JsonObject output = action.execute(this.manager.entity, colony, args);
                BidiGenerateContentToolResponse res = new BidiGenerateContentToolResponse();
                res.functionResponses.add(new BidiGenerateContentToolResponse.FunctionResponse(objFnCall.get("id").getAsString(), name, output));
                this.send(ClientMessages.response(res));
            }
            return;
        }
        if (outer.has("serverContent") && outer.get("serverContent").isJsonObject()) {
            JsonObject obj = outer.getAsJsonObject("serverContent");
            if (obj.has("turnComplete") && obj.get("turnComplete").getAsBoolean()) {
                McTalking.LOGGER.info("Gemini turn complete");
                PacketDistributor.sendToAllPlayers((CustomPacketPayload)new AiStatusPayload(this.manager.entity.getUUID(), AiStatus.LISTENING), (CustomPacketPayload[])new CustomPacketPayload[0]);
                return;
            }
            if (obj.has("interrupted") && obj.get("interrupted").getAsBoolean()) {
                McTalking.LOGGER.info("Gemini generation interrupted");
                this.stream.stop();
                UUID player = ConversationManager.getPlayerForEntity(this.manager.entity.getUUID());
                if (player == null) {
                    return;
                }
                ServerPlayer sPlayer = this.initialPlayer.server.getPlayerList().getPlayer(player);
                if (sPlayer == null || this.currMsg.isBlank()) {
                    return;
                }
                sPlayer.sendSystemMessage((Component)this.manager.entity.getDisplayName().copy().append(": ").append((Component)Component.literal((String)this.currMsg)));
                return;
            }
            if (obj.has("generationComplete") && obj.get("generationComplete").getAsBoolean()) {
                McTalking.LOGGER.info("Gemini generation complete");
                PacketDistributor.sendToAllPlayers((CustomPacketPayload)new AiStatusPayload(this.manager.entity.getUUID(), AiStatus.TALKING), (CustomPacketPayload[])new CustomPacketPayload[0]);
                this.stream.flushAudio();
                UUID player = ConversationManager.getPlayerForEntity(this.manager.entity.getUUID());
                if (player == null) {
                    return;
                }
                ServerPlayer sPlayer = this.initialPlayer.server.getPlayerList().getPlayer(player);
                if (sPlayer == null || this.currMsg.isBlank()) {
                    return;
                }
                sPlayer.sendSystemMessage((Component)this.manager.entity.getDisplayName().copy().append(": ").append((Component)Component.literal((String)this.currMsg)));
                return;
            }
            if (obj.has("modelTurn")) {
                JsonObject modelTurn = obj.getAsJsonObject("modelTurn");
                if (modelTurn.has("parts")) {
                    JsonArray parts = modelTurn.getAsJsonArray("parts");
                    for (JsonElement part : parts) {
                        JsonObject inlineData;
                        if (!part.isJsonObject()) continue;
                        JsonObject pObj = part.getAsJsonObject();
                        if (pObj.has("text") && pObj.get("text").isJsonPrimitive()) {
                            String text = pObj.get("text").getAsString();
                            this.currMsg = this.currMsg + text;
                        }
                        if (!pObj.has("inlineData") || !pObj.get("inlineData").isJsonObject() || !(inlineData = pObj.getAsJsonObject("inlineData")).has("data") || !inlineData.get("data").isJsonPrimitive()) continue;
                        String mimeType = inlineData.get("mimeType").getAsString();
                        if (!mimeType.contains("audio/pcm")) {
                            McTalking.LOGGER.warn("Invalid mime type: {}", (Object)inlineData.get("mimeType").getAsString());
                            continue;
                        }
                        String sampleRateStr = mimeType.split("rate=")[1];
                        int sampleRate = Integer.parseInt(sampleRateStr);
                        byte[] data = Base64.getDecoder().decode(inlineData.get("data").getAsString());
                        this.stream.addGeminiPcmWithPitch(data, sampleRate);
                    }
                }
            } else {
                McTalking.LOGGER.warn("Unknown message: {}", (Object)message);
            }
        }
    }

    public void onMessage(ByteBuffer bytes) {
        String newContent = new String(bytes.array(), StandardCharsets.UTF_8);
        this.onMessage(newContent);
    }

    public void onClose(int code, String reason, boolean remote) {
        this.isInitiatingConnection = false;
        if (reason.contains("You exceeded your current quota, please")) {
            quotaExceeded = true;
            McTalking.LOGGER.warn("Quota exceeded for Gemini API, please check your API key and usage limits.");
            PacketDistributor.sendToAllPlayers((CustomPacketPayload)new AiStatusPayload(this.manager.entity.getUUID(), AiStatus.QUOTA_EXCEEDED), (CustomPacketPayload[])new CustomPacketPayload[0]);
        }
        if (reason.contains("BidiGenerateContent session not found")) {
            this.manager.entity.setData(ModAttachmentTypes.SESSION_TOKEN, (Object)"");
            new Thread(() -> {
                if (!this.isOpen() || !this.isInitiatingConnection) {
                    this.reconnect();
                    this.isInitiatingConnection = true;
                }
            }).start();
            return;
        }
        if (code != 1000 && code != 1001) {
            PacketDistributor.sendToAllPlayers((CustomPacketPayload)new AiStatusPayload(this.manager.entity.getUUID(), AiStatus.ERROR), (CustomPacketPayload[])new CustomPacketPayload[0]);
        }
        McTalking.LOGGER.info("GeminiWsClient closed: {} and code {}", (Object)reason, (Object)code);
    }

    public void onError(Exception ex) {
        PacketDistributor.sendToAllPlayers((CustomPacketPayload)new AiStatusPayload(this.manager.entity.getUUID(), AiStatus.ERROR), (CustomPacketPayload[])new CustomPacketPayload[0]);
        ex.printStackTrace();
    }

    public void addSystemText(String newStatusPrompt) {
        if (!this.setupComplete || this.isClosed()) {
            this.pendingSystemText.add(newStatusPrompt);
            if (!this.isOpen() && !this.isInitiatingConnection) {
                if (this.shouldReconnect) {
                    if (System.currentTimeMillis() - this.lastReconnectTime < 5000L) {
                        return;
                    }
                    McTalking.LOGGER.warn("Connection lost, attempting to reconnect...");
                    this.lastReconnectTime = System.currentTimeMillis();
                    this.reconnect();
                } else {
                    this.connect();
                    this.shouldReconnect = true;
                }
                this.isInitiatingConnection = true;
            }
            return;
        }
        RealtimeInput input = new RealtimeInput();
        input.text = newStatusPrompt;
        this.send(ClientMessages.input(input));
    }

    public void addPromptAudio(short[] audio) {
        RealtimeInput input = new RealtimeInput();
        input.audio = new RealtimeInput.Blob("audio/pcm;rate=48000", audio);
        if (this.sentGeneratingStatus) {
            PacketDistributor.sendToAllPlayers((CustomPacketPayload)new AiStatusPayload(this.manager.entity.getUUID(), AiStatus.LISTENING), (CustomPacketPayload[])new CustomPacketPayload[0]);
        }
        if (!this.setupComplete || this.isClosed()) {
            this.pending_prompt.add(audio);
            if (!this.isOpen() && !this.isInitiatingConnection) {
                if (this.shouldReconnect) {
                    if (System.currentTimeMillis() - this.lastReconnectTime < 5000L) {
                        return;
                    }
                    McTalking.LOGGER.warn("Connection lost, attempting to reconnect...");
                    this.lastReconnectTime = System.currentTimeMillis();
                    this.reconnect();
                } else {
                    this.connect();
                    this.shouldReconnect = true;
                }
                this.isInitiatingConnection = true;
            }
            return;
        }
        this.send(ClientMessages.input(input));
    }

    public void close() {
        if (this.batchTimer != null) {
            this.batchTimer.cancel();
            this.batchTimer = null;
        }
        if (this.currentBatchTask != null) {
            this.currentBatchTask.cancel();
            this.currentBatchTask = null;
        }
        this.stream.close();
        super.close();
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void batchAudio(short[] audio) {
        boolean isFirstElement;
        boolean batchFull;
        Object object = this.batchLock;
        synchronized (object) {
            this.audioBatch.add(audio);
            batchFull = this.audioBatch.size() >= 5;
            isFirstElement = this.audioBatch.size() == 1;
        }
        if (batchFull) {
            this.sendCurrentBatch();
        } else if (isFirstElement) {
            this.scheduleFlushTimer();
        }
    }

    private void scheduleFlushTimer() {
        if (this.currentBatchTask != null) {
            this.currentBatchTask.cancel();
        }
        this.currentBatchTask = new TimerTask(){

            @Override
            public void run() {
                if (!GeminiWsClient.this.audioBatch.isEmpty()) {
                    GeminiWsClient.this.sendCurrentBatch();
                }
            }
        };
        if (this.batchTimer == null) {
            this.batchTimer = new Timer("AudioBatchTimer", true);
        }
        this.batchTimer.schedule(this.currentBatchTask, 100L);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void sendCurrentBatch() {
        ArrayList<short[]> batchCopy;
        Object object = this.batchLock;
        synchronized (object) {
            if (this.audioBatch.isEmpty()) {
                return;
            }
            batchCopy = new ArrayList<short[]>(this.audioBatch);
            this.audioBatch.clear();
        }
        int totalLength = 0;
        for (short[] audioData : batchCopy) {
            totalLength += audioData.length;
        }
        short[] combinedAudio = new short[totalLength];
        int position = 0;
        for (short[] audioData : batchCopy) {
            System.arraycopy(audioData, 0, combinedAudio, position, audioData.length);
            position += audioData.length;
        }
        this.addPromptAudio(combinedAudio);
    }
}

