#!/usr/bin/env python3
# server.py – FastAPI CTranslate2 server

import sys
from pathlib import Path

import ctranslate2
import langid
import uvicorn
from fastapi import FastAPI, HTTPException, Response, status
from pydantic import BaseModel
from transformers import AutoTokenizer

# ------------------------------------------------------------------ #
# Accept the model directory as the first command-line argument
# ------------------------------------------------------------------ #
#    Only sys.argv[1] is read – no argparse, no env-var, no fallback.
#    The argument is mandatory: when it is missing the script exits
#    immediately (non-zero status) with an explicit error message on
#    stderr, rather than an opaque IndexError traceback.
#
if len(sys.argv) < 2:
    sys.exit(
        "error: missing required argument MODEL_DIR "
        "(path to the CTranslate2 model directory)"
    )
MODEL_DIR = Path(sys.argv[1]).expanduser().resolve()

# ----------------------------
# FastAPI setup
# ----------------------------

# The ASGI application; the route decorators in this module register on it.
app = FastAPI(
    title="CTranslate2 Translation Server",
)

# ----------------------------
# Models
# ----------------------------

class TranslationRequest(BaseModel):
    """Request body for ``POST /translate``."""

    # Text to translate.
    text: str
    # Source language code; None, blank, or "auto" (any case) makes the
    # server detect the language with langid instead.
    sourceLanguage: str | None = None
    # Language code to translate into; looked up in the tokenizer's
    # lang_code_to_token table.
    targetLanguage: str

class TranslationResponse(BaseModel):
    """Response body for ``POST /translate``."""

    # The translation of the request text.
    translatedText: str
    # Source language used for translation: the caller-supplied code, or the
    # code reported by langid when detection was requested.
    detectedLanguage: str

class BatchTranslationRequest(BaseModel):
    """Request body for ``POST /translate/batch``: one text, many targets."""

    # Text to translate.
    text: str
    # Source language code; None, blank, or "auto" (any case) makes the
    # server detect the language with langid instead.
    sourceLanguage: str | None = None
    # Language codes to translate into; an empty list yields an empty
    # response without running detection or translation.
    targetLanguages: list[str]

class BatchTranslationResponse(BaseModel):
    """Response body for ``POST /translate/batch``."""

    # Source language used for translation; "" when no targets were requested.
    detectedLanguage: str
    # Maps each requested target language code to its translation.
    translations: dict[str, str]

# ----------------------------
# Load translation model & tokenizer
# ----------------------------

# Hugging Face tokenizer paired with the CTranslate2 model. It is hard-coded
# and assumed to match the vocabulary of the model in MODEL_DIR — confirm
# this when serving a different conversion.
_TOKENIZER_ID = "facebook/m2m100_418M"

print(f"Loading CTranslate2 model from: {MODEL_DIR}")
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_ID)
translator = ctranslate2.Translator(f"{MODEL_DIR}")
print("Model and tokenizer ready.")

# ----------------------------
# Health endpoint
# ----------------------------

@app.get("/health", status_code=status.HTTP_200_OK)
def health():
    """Liveness probe that answers with an empty 200 response.

    The model and tokenizer are loaded at import time, before the server
    starts. Reaching this handler therefore implies loading succeeded.
    """
    empty_ok = Response(status_code=status.HTTP_200_OK)
    return empty_ok

# ----------------------------
# Single translation endpoint
# ----------------------------

@app.post("/translate", response_model=TranslationResponse)
async def translate(req: TranslationRequest):
    """Translate ``req.text`` into ``req.targetLanguage``.

    The source language comes from ``req.sourceLanguage``. When that is
    missing, blank, or ``"auto"`` (any case), it is detected with
    ``langid``. Both the source and target codes must be keys of the
    tokenizer's ``lang_code_to_token`` table.

    Returns:
        A dict with ``translatedText`` and ``detectedLanguage`` (the source
        language actually used for translation).

    Raises:
        HTTPException: 400 if the (given or detected) source language or
            the target language is not supported by the tokenizer.
    """
    text = req.text
    target_language = req.targetLanguage
    source_language = (req.sourceLanguage or "").strip()

    if not source_language or source_language.lower() == "auto":
        detected_language, _ = langid.classify(text)
    else:
        detected_language = source_language

    lang_tokens = tokenizer.lang_code_to_token
    if detected_language not in lang_tokens:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                f"Source language {detected_language!r} is not supported by "
                "the model; pass a supported sourceLanguage explicitly."
            ),
        )
    if target_language not in lang_tokens:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Target language {target_language!r} is not supported by the model.",
        )

    # The tokenizer prepends its src_lang token to the input. Set it on every
    # request, including auto-detected ones, so a previous request's source
    # language is never reused. This mutates the shared tokenizer. Because
    # this handler never awaits, requests on the event loop cannot interleave
    # here.
    # NOTE(review): the blocking translate_batch call below also stalls the
    # event loop for its duration.
    tokenizer.src_lang = detected_language

    tokens = tokenizer.convert_ids_to_tokens(tokenizer(text)["input_ids"])

    results = translator.translate_batch(
        [tokens],
        target_prefix=[[lang_tokens[target_language]]],
        beam_size=5,
    )

    # Drop the leading target-language token forced via target_prefix.
    translated_text = tokenizer.convert_tokens_to_string(results[0].hypotheses[0][1:])

    return {
        "translatedText": translated_text,
        "detectedLanguage": detected_language,
    }

# ----------------------------
# Batch translation endpoint
# ----------------------------

@app.post("/translate/batch", response_model=BatchTranslationResponse)
async def translate_batch(req: BatchTranslationRequest):
    """Translate ``req.text`` into every language in ``req.targetLanguages``.

    An empty target list short-circuits to an empty response with
    ``detectedLanguage == ""``, without running detection. Otherwise the
    source language comes from ``req.sourceLanguage``. When that is missing,
    blank, or ``"auto"`` (any case), it is detected with ``langid``. All
    codes must be keys of the tokenizer's ``lang_code_to_token`` table.
    Duplicate target codes are translated once.

    Returns:
        A dict with ``translations`` (target code -> translated text, in
        first-occurrence order) and ``detectedLanguage``.

    Raises:
        HTTPException: 400 if the (given or detected) source language or any
            target language is not supported by the tokenizer.
    """
    text = req.text
    target_languages = req.targetLanguages or []

    if not target_languages:
        return {
            "detectedLanguage": "",
            "translations": {}
        }

    source_language = (req.sourceLanguage or "").strip()
    if not source_language or source_language.lower() == "auto":
        detected_language, _ = langid.classify(text)
    else:
        detected_language = source_language

    lang_tokens = tokenizer.lang_code_to_token
    if detected_language not in lang_tokens:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                f"Source language {detected_language!r} is not supported by "
                "the model; pass a supported sourceLanguage explicitly."
            ),
        )

    # Each distinct target is translated once. Duplicates would only map to
    # the same dict key, so translating them again is wasted work.
    # dict.fromkeys keeps first-occurrence order.
    unique_targets = list(dict.fromkeys(target_languages))
    unsupported = [lang for lang in unique_targets if lang not in lang_tokens]
    if unsupported:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                "Target language(s) not supported by the model: "
                + ", ".join(repr(lang) for lang in unsupported)
            ),
        )

    # Set src_lang on every request, including auto-detected ones, so the
    # input never carries a stale source-language token from an earlier call.
    # This handler never awaits, so the shared tokenizer is not mutated
    # concurrently on the event loop.
    tokenizer.src_lang = detected_language

    tokens = tokenizer.convert_ids_to_tokens(tokenizer(text)["input_ids"])

    # One input sentence, repeated once per target, each with its own prefix.
    results = translator.translate_batch(
        [tokens] * len(unique_targets),
        target_prefix=[[lang_tokens[lang]] for lang in unique_targets],
        beam_size=5,
    )

    # [1:] drops the forced target-language token from each hypothesis.
    translations: dict[str, str] = {
        lang: tokenizer.convert_tokens_to_string(result.hypotheses[0][1:])
        for lang, result in zip(unique_targets, results)
    }

    return {
        "translations": translations,
        "detectedLanguage": detected_language
    }

# ----------------------------
# Run server
# ----------------------------

if __name__ == "__main__":
    # Serve on the loopback interface only, with auto-reload switched off.
    uvicorn.run(app, reload=False, port=5005, host="127.0.0.1")
