"""Model management — loads and runs local GGUF models via llama.cpp."""
from __future__ import annotations

import logging
import threading
from pathlib import Path
from typing import Generator

from hivemind.config import ModelConfig

log = logging.getLogger(__name__)


class Model:
    """Wrapper around llama-cpp-python for local inference.

    A single ``Llama`` instance is shared by all callers. Every call that
    touches it for inference (generation, embeddings, LoRA loading) is
    serialized through ``self._lock``.
    """

    def __init__(self, config: ModelConfig):
        """Store *config*; the model itself is loaded lazily by :meth:`load`."""
        self.config = config
        self._llm = None
        # llama.cpp is NOT thread-safe: only one inference call may run at a
        # time. This lock serializes concurrent calls from the web chat, the
        # Telegram worker and other sources at OS-thread level.
        self._lock = threading.Lock()
        # May be set from outside to abort the running generation.
        # It is cleared at the start of every generation (inside the lock).
        self._cancel_event = threading.Event()

    def request_cancel(self) -> None:
        """Abort the running generation after the next received token."""
        self._cancel_event.set()

    @staticmethod
    def _register_cuda_dll_dirs() -> None:
        """Make DLLs shipped inside site-packages discoverable (Windows only).

        Packages such as ``nvidia-cuda-runtime-cu12`` place their DLLs under
        ``nvidia/<component>/bin`` inside site-packages instead of System32;
        ``llama_cpp/lib`` is registered as well. Best-effort: a directory that
        cannot be added is logged at debug level and skipped.
        """
        import sys

        if sys.platform != "win32":
            return

        import os
        import site
        import sysconfig

        sp_dirs: list[str] = []
        try:
            sp_dirs += site.getsitepackages()
        except Exception:
            # getsitepackages() is missing/broken in some virtualenv setups.
            pass
        sp_dirs.append(sysconfig.get_path("purelib"))

        seen: set[str] = set()
        for sp in sp_dirs:
            if not sp:
                continue
            # "nvidia/*/bin" already covers nvidia/cuda_runtime/bin.
            for pattern in ("nvidia/*/bin", "llama_cpp/lib"):
                for dll_dir in Path(sp).glob(pattern):
                    key = str(dll_dir)
                    # site-packages entries often repeat (purelib vs getsitepackages).
                    if key in seen or not dll_dir.is_dir():
                        continue
                    seen.add(key)
                    try:
                        os.add_dll_directory(key)
                    except OSError as exc:
                        log.debug("add_dll_directory(%s) failed: %s", key, exc)

    def load(self) -> None:
        """Load the GGUF model at ``config.path`` into memory.

        Raises:
            FileNotFoundError: if the model file does not exist.
            ImportError: if ``llama_cpp`` cannot be imported.
        """
        try:
            self._register_cuda_dll_dirs()
        except Exception:
            # Best-effort only: missing DLLs surface as an import/load error below.
            log.debug("Could not register CUDA DLL directories", exc_info=True)
        from llama_cpp import Llama

        path = Path(self.config.path)
        if not path.exists():
            raise FileNotFoundError(f"Model not found: {path}")

        log.info("Loading model: %s", path.name)
        self._llm = Llama(
            model_path=str(path),
            n_ctx=self.config.n_ctx,
            n_gpu_layers=self.config.n_gpu_layers,
            n_threads=self.config.n_threads or None,  # 0/None -> library default
            verbose=False,
        )
        log.info("Model loaded: %s (ctx=%d)", path.name, self.config.n_ctx)

    @property
    def loaded(self) -> bool:
        """True once :meth:`load` has created the llama.cpp instance."""
        return self._llm is not None

    def generate(
        self,
        messages: list[dict],
        max_tokens: int = 1024,
        temperature: float = 0.7,
        stream: bool = False,
    ) -> str | Generator[str, None, None]:
        """Generate a response from chat messages.

        Args:
            messages: List of {"role": "system"|"user"|"assistant", "content": "..."}
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            stream: If True, return a generator yielding tokens as they are
                produced. The lock is acquired on the first ``next()`` and
                held until the generator finishes or is closed.

        Returns:
            The full response text, or a token generator when *stream* is True.
            A cancelled generation returns/yields only what was produced so far.

        Raises:
            RuntimeError: if the model has not been loaded.
        """
        if self._llm is None:
            raise RuntimeError("Model not loaded. Call load() first.")

        if stream:
            # _stream() trims the history itself; trimming here as well would
            # tokenize every message twice.
            return self._stream(messages, max_tokens, temperature)

        messages = self._fit_to_context(messages, max_tokens)

        # Run internally as a stream so _cancel_event can be checked between
        # tokens, allowing a new request to abort the old one.
        parts: list[str] = []
        with self._lock:
            self._cancel_event.clear()  # reset for this generation
            for chunk in self._llm.create_chat_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            ):
                if self._cancel_event.is_set():
                    log.debug("generate() cancelled after %d token(s)", len(parts))
                    break
                delta = chunk["choices"][0].get("delta", {})
                token = delta.get("content", "")
                if token:
                    parts.append(token)
        return "".join(parts)

    def _count_tokens(self, messages: list[dict]) -> int:
        """Estimate the prompt token count of *messages*, including overhead.

        Uses the model tokenizer plus ~4 tokens per message (role and
        separators) and 3 tokens for the reply prologue. If tokenizing fails,
        falls back to a rough estimate of ~0.75 tokens per character plus 5
        per message. Messages whose ``content`` is missing or None count as
        empty text in both paths.
        """
        try:
            total = 0
            for m in messages:
                total += 4  # role + separators
                text = m.get("content") or ""
                total += len(self._llm.tokenize(text.encode("utf-8", errors="replace")))
            total += 3  # reply prologue
            return total
        except Exception:
            # `or ""` guards against content=None, which would make len() raise here.
            return sum(len(m.get("content") or "") * 3 // 4 + 5 for m in messages)

    def _fit_to_context(self, messages: list[dict], max_tokens: int) -> list[dict]:
        """Trim the chat history so that prompt tokens + *max_tokens* <= n_ctx.

        Strategy:
        - The system message (index 0, if present) is ALWAYS kept.
        - The last user message is ALWAYS kept.
        - The oldest history messages are removed first.
        - If even system + last user message are too long, the system prompt
          is replaced by a short stub.

        The caller's list is never mutated. The result can still exceed the
        budget when the last user message alone is too long.
        """
        budget = self.config.n_ctx - max_tokens - 64  # 64-token safety margin

        if self._count_tokens(messages) <= budget:
            return messages

        # Split into system prompt, history and last user message.
        system = [messages[0]] if messages and messages[0]["role"] == "system" else []
        rest = messages[len(system):]  # slice copy: pops below don't touch the caller's list

        # Always keep the last user message.
        last_user: list[dict] = []
        for i in range(len(rest) - 1, -1, -1):
            if rest[i]["role"] == "user":
                last_user = [rest.pop(i)]
                break

        # Drop the oldest messages until everything fits.
        while rest and self._count_tokens(system + rest + last_user) > budget:
            rest.pop(0)

        trimmed = system + rest + last_user
        if self._count_tokens(trimmed) <= budget:
            if len(rest) < len(messages) - len(system) - len(last_user):
                log.debug(
                    "Context trim: %d → %d messages (budget=%d tokens)",
                    len(messages), len(trimmed), budget,
                )
            return trimmed

        # Emergency: replace the system prompt with a short stub.
        if system:
            short_system = [{"role": "system", "content": "Du bist ein hilfreicher KI-Assistent."}]
            trimmed = short_system + last_user
            log.warning(
                "Context overflow: System-Prompt gekürzt, nur letzte Frage behalten."
            )
        return trimmed

    def _stream(
        self, messages: list[dict], max_tokens: int, temperature: float
    ) -> Generator[str, None, None]:
        """Yield tokens one by one; can also be aborted via ``_cancel_event``.

        The lock is held for the whole streaming duration. If a consumer
        abandons the generator without exhausting or closing it, the lock is
        only released when the generator is finalized.
        """
        messages = self._fit_to_context(messages, max_tokens)
        with self._lock:
            self._cancel_event.clear()
            for chunk in self._llm.create_chat_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            ):
                if self._cancel_event.is_set():
                    log.debug("_stream() cancelled")
                    return
                delta = chunk["choices"][0].get("delta", {})
                token = delta.get("content", "")
                if token:
                    yield token

    def load_lora(self, lora_path: str) -> None:
        """Load a LoRA adapter on top of the base model.

        Runs under the inference lock so the adapter is never swapped in
        while a generation is in progress.

        Raises:
            RuntimeError: if the base model is not loaded.
            FileNotFoundError: if *lora_path* does not exist.
            Exception: any error raised by ``load_lora()`` other than
                AttributeError is logged and re-raised.
        """
        if self._llm is None:
            raise RuntimeError("Base model must be loaded first.")
        path = Path(lora_path)
        if not path.exists():
            raise FileNotFoundError(f"LoRA adapter not found: {path}")
        with self._lock:
            try:
                self._llm.load_lora(str(path))
                log.info("LoRA adapter loaded: %s", path.name)
            except AttributeError:
                # Older/newer llama-cpp-python builds have no load_lora() method.
                log.warning("llama-cpp-python version doesn't support load_lora()")
            except Exception as e:
                log.error("Failed to load LoRA: %s", e)
                raise

    def generate_quick(self, messages: list[dict]) -> str:
        """Quick short answer (max 256 tokens, temperature 0.6), e.g. for Telegram.

        Uses the same model instance as generate(), just with a smaller token
        budget, so there is no competition for VRAM.
        """
        return self.generate(messages, max_tokens=256, temperature=0.6)

    def embed(self, text: str) -> list[float]:
        """Return the embedding vector for *text* (used for cache similarity).

        Returns an empty list if the model/build cannot compute embeddings.
        Runs under the inference lock because it uses the same llama.cpp
        instance as generation.

        Raises:
            RuntimeError: if the model is not loaded.
        """
        if self._llm is None:
            raise RuntimeError("Model not loaded.")
        try:
            with self._lock:
                return self._llm.embed(text)
        except Exception:
            # Not every model/configuration supports embeddings.
            log.warning("Model doesn't support embeddings, cache similarity disabled")
            return []
