"""Training data management and LoRA adapter support.

Handles:
- Collecting conversation data for future fine-tuning
- Exporting training data as JSONL
- Loading LoRA adapters into the model
- LoRA metadata for P2P sharing
- Article ingestion: URL → text extraction → Q&A generation → training data
"""
from __future__ import annotations

import asyncio
import json
import logging
import re
import time
from collections import Counter
from pathlib import Path
from typing import Any

log = logging.getLogger(__name__)


class TrainingManager:
    """Manages training data collection and LoRA adapters.

    Directory layout below ``data_dir`` (created on construction)::

        conversations/  one JSONL file per saved conversation
        lora/           uploaded ``*.gguf`` LoRA adapters
        exports/        combined JSONL files produced by export_jsonl()
    """

    def __init__(self, data_dir: str | Path = "data/training",
                 lora_path: str = ""):
        """Create the data directories and count existing training data.

        Args:
            data_dir:  Root directory for all training artefacts.
            lora_path: Path of the currently active LoRA adapter ("" = none).
        """
        self.data_dir = Path(data_dir)
        self.conversations_dir = self.data_dir / "conversations"
        self.lora_dir = self.data_dir / "lora"
        self.export_dir = self.data_dir / "exports"
        self.lora_path = lora_path

        for directory in (self.data_dir, self.conversations_dir,
                          self.lora_dir, self.export_dir):
            directory.mkdir(parents=True, exist_ok=True)

        self._conversation_count = 0
        self._sample_count = 0
        self._refresh_counts()

    # ── internal helpers ────────────────────────────────────────────────────

    def _refresh_counts(self) -> None:
        """Recount conversation files and samples (non-blank JSONL lines).

        Unreadable files are logged and skipped so statistics never fail.
        """
        files = list(self.conversations_dir.glob("*.jsonl"))
        samples = 0
        for path in files:
            try:
                # `with` closes the handle promptly; blank lines are not samples.
                with path.open(encoding="utf-8") as fh:
                    samples += sum(1 for line in fh if line.strip())
            except (OSError, UnicodeDecodeError) as exc:
                log.warning("Could not count samples in %s: %s", path, exc)
        self._conversation_count = len(files)
        self._sample_count = samples

    def _iter_conversation_entries(self):
        """Yield every parsed JSON object from all conversation files.

        Files are visited in sorted filename order. Blank lines are skipped;
        malformed lines and unreadable files are logged and skipped, so a
        single bad line no longer hides the rest of its file.
        """
        for conv_file in sorted(self.conversations_dir.glob("*.jsonl")):
            try:
                with conv_file.open(encoding="utf-8") as fh:
                    for lineno, line in enumerate(fh, start=1):
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            data = json.loads(line)
                        except json.JSONDecodeError as exc:
                            log.warning("Skipping malformed line %d in %s: %s",
                                        lineno, conv_file, exc)
                            continue
                        if isinstance(data, dict):
                            yield data
            except (OSError, UnicodeDecodeError) as exc:
                log.warning("Error processing %s: %s", conv_file, exc)

    def _create_conversation_file(self, timestamp: int) -> tuple[str, Any]:
        """Create and open a conversation file that did not exist before.

        Tries ``conv_<timestamp>.jsonl`` first, then ``conv_<timestamp>_1.jsonl``,
        ``_2`` ... Mode ``"x"`` makes the existence check and the creation one
        atomic step, so several saves within the same second (as done by
        ingest_url) can never overwrite each other.

        Returns:
            ``(filename, open text handle)`` — the caller must close the handle.
        """
        suffix = 0
        while True:
            stem = f"conv_{timestamp}" if suffix == 0 else f"conv_{timestamp}_{suffix}"
            filename = f"{stem}.jsonl"
            try:
                handle = open(self.conversations_dir / filename, "x",
                              encoding="utf-8")
            except FileExistsError:
                suffix += 1
                continue
            return filename, handle

    @staticmethod
    def _format_entry(messages: list, format_type: str) -> dict | None:
        """Convert one conversation into an export record, or None to skip it.

        'chatml' keeps the message list as-is; any other *format_type* is
        treated as Alpaca, which keeps only the first user turn and the first
        assistant reply (both must have non-empty content).
        """
        if not isinstance(messages, list) or not messages:
            return None
        if format_type == "chatml":
            return {"messages": messages}

        def first_content(role: str):
            return next((m["content"] for m in messages
                         if isinstance(m, dict) and m.get("role") == role
                         and m.get("content")), None)

        instruction = first_content("user")
        output = first_content("assistant")
        if instruction is None or output is None:
            return None
        return {"instruction": instruction, "input": "", "output": output}

    @staticmethod
    def _adapter_info(path: Path, active: bool) -> dict:
        """Describe one adapter file (a single stat() call)."""
        st = path.stat()
        return {
            "name": path.stem,
            "path": str(path),
            "size_mb": round(st.st_size / (1024 * 1024), 1),
            "modified": st.st_mtime,
            "active": active,
        }

    # ── public API ──────────────────────────────────────────────────────────

    def save_conversation(self, messages: list[dict],
                          metadata: dict | None = None) -> str:
        """Save a conversation as training data.

        Args:
            messages: List of {role, content} dicts (at least two).
            metadata: Optional metadata (topic, quality, etc.)

        Returns:
            Filename of the saved conversation, or "" if fewer than two
            messages were given (nothing is written then).
        """
        if not messages or len(messages) < 2:
            return ""

        timestamp = int(time.time())
        entry = {
            "messages": messages,
            "timestamp": timestamp,
            "metadata": metadata or {},
        }
        # Serialise before creating the file, so a non-JSON-serialisable
        # message cannot leave an empty conversation file behind.
        line = json.dumps(entry, ensure_ascii=False) + "\n"

        filename, handle = self._create_conversation_file(timestamp)
        with handle:
            handle.write(line)

        self._conversation_count += 1
        self._sample_count += 1
        log.info("Saved conversation: %s (%d messages)", filename, len(messages))
        return filename

    def export_jsonl(self, format_type: str = "chatml") -> Path:
        """Export all conversations as a single JSONL file for training.

        Args:
            format_type: 'chatml' (OpenAI ``{"messages": [...]}`` format) or
                'alpaca' (instruction/input/output). Any value other than
                'chatml' is treated as 'alpaca'.

        Returns:
            Path to the exported file (``exports/training_<unix-ts>.jsonl``).
        """
        timestamp = int(time.time())
        export_path = self.export_dir / f"training_{timestamp}.jsonl"

        count = 0
        with open(export_path, "w", encoding="utf-8") as out:
            for data in self._iter_conversation_entries():
                entry = self._format_entry(data.get("messages") or [], format_type)
                if entry is None:
                    continue
                out.write(json.dumps(entry, ensure_ascii=False) + "\n")
                count += 1

        log.info("Exported %d training samples to %s", count, export_path)
        return export_path

    def get_lora_adapters(self) -> list[dict]:
        """List available LoRA adapters.

        Includes every ``*.gguf`` file in ``lora_dir`` plus the configured
        ``lora_path`` if it exists elsewhere. Each entry has ``name``,
        ``path``, ``size_mb``, ``modified`` (mtime) and ``active`` (True for
        the configured adapter). Paths are compared after ``resolve()`` so
        the active adapter is not listed twice when it is spelled differently
        (relative vs. absolute, symlink).
        """
        active = Path(self.lora_path).resolve() if self.lora_path else None
        adapters: list[dict] = []
        seen: set = set()
        for f in sorted(self.lora_dir.glob("*.gguf")):
            resolved = f.resolve()
            try:
                adapters.append(self._adapter_info(f, active=(resolved == active)))
            except OSError:
                continue  # file vanished between glob() and stat()
            seen.add(resolved)

        if active is not None and active not in seen:
            lp = Path(self.lora_path)
            if lp.exists():
                adapters.append(self._adapter_info(lp, active=True))
        return adapters

    def save_lora(self, name: str, data: bytes) -> Path:
        """Save an uploaded LoRA adapter into ``lora_dir``.

        Args:
            name: Desired file name; ``.gguf`` is appended if missing. Any
                directory components are stripped because the name comes
                from an upload and must not be able to escape ``lora_dir``
                (e.g. ``"../../x"`` or an absolute path).
            data: Raw adapter bytes.

        Returns:
            Path of the written file.

        Raises:
            ValueError: *name* contains no usable file-name component.
        """
        safe_name = Path(name).name
        if safe_name in ("", ".", ".."):
            raise ValueError(f"Invalid LoRA adapter name: {name!r}")
        if not safe_name.endswith(".gguf"):
            safe_name += ".gguf"
        path = self.lora_dir / safe_name
        path.write_bytes(data)
        log.info("Saved LoRA adapter: %s (%.1f MB)", safe_name,
                 len(data) / (1024 * 1024))
        return path

    def topic_analysis(self) -> dict:
        """Count topics across all stored conversations.

        Two sources feed the counter: an explicit string ``metadata["topic"]``
        on an entry, and a keyword heuristic on every user message (plain
        substring match on the lower-cased text; the first matching entry of
        ``_TOPIC_KEYWORDS`` wins).

        Returns:
            ``{"total_samples": <user messages seen>, "topics": {topic: count}}``
            with at most the 20 most common topics.
        """
        topics: Counter = Counter()
        total = 0

        for data in self._iter_conversation_entries():
            meta = data.get("metadata")
            if isinstance(meta, dict) and isinstance(meta.get("topic"), str):
                topics[meta["topic"]] += 1

            for message in data.get("messages") or []:
                if not isinstance(message, dict) or message.get("role") != "user":
                    continue
                total += 1
                content = message.get("content")
                if not isinstance(content, str):
                    continue  # e.g. multi-part content: nothing to scan
                lowered = content.lower()
                topic = next((label for keyword, label in _TOPIC_KEYWORDS.items()
                              if keyword in lowered), None)
                if topic is not None:
                    topics[topic] += 1

        return {
            "total_samples": total,
            "topics": dict(topics.most_common(20)),
        }

    @property
    def stats(self) -> dict:
        """Snapshot of training-data statistics (recounted on every access)."""
        self._refresh_counts()
        return {
            "conversations": self._conversation_count,
            "samples": self._sample_count,
            "lora_adapters": len(self.get_lora_adapters()),
            "active_lora": self.lora_path or None,
            "export_count": sum(1 for _ in self.export_dir.glob("*.jsonl")),
        }

    def lora_metadata_for_sharing(self) -> dict | None:
        """Return name/size/mtime of the active LoRA for P2P sharing.

        Returns None when no adapter is configured or the file does not exist.
        """
        if not self.lora_path:
            return None
        lp = Path(self.lora_path)
        if not lp.exists():
            return None
        st = lp.stat()
        return {
            "name": lp.stem,
            "size": st.st_size,
            "modified": st.st_mtime,
        }

    async def ingest_url(self, url: str, model: Any, n_pairs: int = 10) -> dict:
        """Fetch an article, generate Q&A training pairs, and save them.

        Pipeline:
          1. Download & extract readable text from *url* (httpx + BeautifulSoup).
          2. Prompt the local model to produce *n_pairs* FRAGE/ANTWORT pairs.
          3. Save each pair as a training conversation (JSONL).
          4. Return the result including *raw_text* so the caller can add to RAG.

        Args:
            url:      HTTP/HTTPS address of the article or documentation page.
            model:    Loaded Model instance used for Q&A generation.
            n_pairs:  Number of Q&A pairs to generate (clamped to 3–30).

        Returns:
            dict with keys: url, title, text_length, qa_pairs, saved_count, raw_text

        Raises:
            ValueError: too little text was extracted, or no Q&A pairs were
                generated (German user-facing messages).
        """
        # Enforce the documented range so callers cannot request 0 or
        # thousands of pairs from the model.
        n_pairs = max(3, min(30, int(n_pairs)))

        title, text = await _fetch_article_text(url)
        if len(text) < 200:
            raise ValueError(
                "Zu wenig Text extrahiert (< 200 Zeichen). "
                "Prüfe die URL — JavaScript-lastige Seiten werden ggf. nicht unterstützt."
            )

        qa_pairs = await _generate_qa_pairs_async(text, model, n_pairs)
        if not qa_pairs:
            raise ValueError(
                "Keine Q&A-Paare generiert. "
                "Mögliche Ursachen: Modell-Timeout (> 180 s) oder unerwartetes Ausgabe-Format."
            )

        saved_files: list[str] = []
        for qa in qa_pairs:
            fname = self.save_conversation(
                [
                    {"role": "user", "content": qa["question"]},
                    {"role": "assistant", "content": qa["answer"]},
                ],
                metadata={"source": url, "source_title": title, "generated": True},
            )
            if fname:
                saved_files.append(fname)

        log.info("ingest_url: %s → %d Q&A-Paare gespeichert", url, len(saved_files))
        return {
            "url": url,
            "title": title,
            "text_length": len(text),
            "qa_pairs": qa_pairs,
            "saved_count": len(saved_files),
            "raw_text": text,
        }


# ──────────────────────────────────────────────────────────────────────────────
# Article ingestion helpers
# ──────────────────────────────────────────────────────────────────────────────

async def _fetch_article_text(url: str, max_chars: int = 20_000) -> tuple[str, str]:
    """Download *url* and extract clean, readable text.

    Uses httpx (already a project dependency) and BeautifulSoup, both
    imported lazily inside this function.

    Args:
        url:       Page to download (redirects followed, 15 s timeout).
        max_chars: Upper bound on the length of the returned text
                   (default 20 000, the previous hard-coded limit).

    Returns:
        ``(title, text)``. *title* is the page ``<title>`` cut to 200
        characters, or *url* when the page has no non-empty title. *text* is
        at most *max_chars* characters long.

    Raises:
        httpx.HTTPStatusError: the server answered with an error status
            (via ``raise_for_status``); other httpx errors propagate as well.
    """
    import httpx
    from bs4 import BeautifulSoup

    headers = {"User-Agent": "Mozilla/5.0 (compatible; HiveMind/1.0)"}
    async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
        resp = await client.get(url, headers=headers)
        resp.raise_for_status()
        html = resp.text

    soup = BeautifulSoup(html, "html.parser")

    # Remove noise elements (scripts, navigation, forms, ...).
    for tag in soup(["script", "style", "nav", "header", "footer",
                     "aside", "form", "iframe", "noscript"]):
        tag.decompose()

    title_tag = soup.find("title")
    # An empty <title></title> would otherwise yield "" — use the URL instead.
    title = (title_tag.get_text(strip=True)[:200] if title_tag else "") or url

    # Try to find the main article container, most specific first.
    article = (
        soup.find("article")
        or soup.find("main")
        or soup.find(id=re.compile(r"content|article|main|post", re.I))
        or soup.find(class_=re.compile(r"content|article|post|entry|body", re.I))
        or soup.body
    )

    if article is not None:
        blocks = article.find_all(["p", "h1", "h2", "h3", "h4", "li"])
        # Very short blocks (< 31 chars) are mostly buttons, captions, menus.
        text = "\n".join(
            block.get_text(" ", strip=True)
            for block in blocks
            if len(block.get_text(strip=True)) > 30
        )
        if not text:
            # Pages that keep their prose in bare <div>s have no matching
            # block elements; fall back to the container's whole text.
            text = article.get_text(" ", strip=True)
    else:
        text = soup.get_text(" ", strip=True)

    text = re.sub(r"\n{3,}", "\n\n", text).strip()
    return title, text[:max_chars]


def _parse_qa_output(raw: str) -> list[dict]:
    """Parse FRAGE:/ANTWORT: formatted model output into list of dicts."""
    pairs: list[dict] = []
    current_q: str | None = None
    current_a: list[str] = []

    for line in raw.split("\n"):
        stripped = line.strip()
        upper = stripped.upper()
        if upper.startswith("FRAGE:"):
            if current_q and current_a:
                pairs.append({"question": current_q,
                               "answer": " ".join(current_a).strip()})
            current_q = stripped[6:].strip()
            current_a = []
        elif upper.startswith("ANTWORT:") and current_q is not None:
            current_a = [stripped[8:].strip()]
        elif current_a and stripped and not stripped.startswith("---"):
            current_a.append(stripped)

    if current_q and current_a:
        pairs.append({"question": current_q, "answer": " ".join(current_a).strip()})

    return [p for p in pairs if p["question"] and p["answer"]]


async def _generate_qa_pairs_async(
    text: str, model: Any, n_pairs: int, timeout: float = 180.0
) -> list[dict]:
    """Prompt the local model to generate *n_pairs* Q&A pairs from *text*.

    ``model.generate()`` is a blocking call, so it runs in the default
    thread-pool executor. Only the first 6000 characters of *text* are put
    into the prompt.

    Args:
        text:    Source text the questions must be based on.
        model:   Object exposing ``generate(messages, max_tokens=..., temperature=...)``
                 that returns the raw completion (presumably a str — it is
                 fed to ``_parse_qa_output``).
        n_pairs: Number of pairs requested in the prompt.
        timeout: Seconds to wait for the model (default 180).

    Returns:
        Parsed ``{"question", "answer"}`` dicts; an empty list on timeout or
        any failure (the error is logged, never raised).
    """
    prompt = (
        f"Analysiere den folgenden Text und erstelle genau {n_pairs} "
        "Frage-Antwort-Paare.\n"
        "Die Fragen sollen das Wissen aus dem Text prüfen. "
        "Verwende AUSSCHLIESSLICH Informationen aus dem Text.\n\n"
        "Halte dich strikt an dieses Format — eine Zeile pro Feld, "
        "getrennt durch ---:\n"
        "FRAGE: [Frage hier]\n"
        "ANTWORT: [Antwort hier]\n"
        "---\n\n"
        f"Text:\n{text[:6000]}"
    )
    messages = [
        {
            "role": "system",
            "content": "Du bist ein Experte für das Erstellen von Trainingsdaten aus Fachtexten.",
        },
        {"role": "user", "content": prompt},
    ]
    # Inside a coroutine the running loop is the one to schedule work on.
    loop = asyncio.get_running_loop()
    try:
        raw = await asyncio.wait_for(
            loop.run_in_executor(
                None,
                lambda: model.generate(messages, max_tokens=2500, temperature=0.2),
            ),
            timeout=timeout,
        )
        return _parse_qa_output(raw)
    except asyncio.TimeoutError:
        # wait_for cannot stop the worker thread: generate() keeps running in
        # the background until it finishes; only its result is discarded.
        log.warning("Q&A-Generierung nach %.0f s abgebrochen", timeout)
        return []
    except Exception as exc:
        # Best-effort by design (caller turns [] into a user-facing error),
        # but keep the traceback for debugging.
        log.exception("Q&A-Generierung fehlgeschlagen: %s", exc)
        return []


# ──────────────────────────────────────────────────────────────────────────────
# Simple keyword → topic mapping for analysis
# ──────────────────────────────────────────────────────────────────────────────

# Lower-case keyword → German topic label, consumed by
# TrainingManager.topic_analysis(). That method tests the keywords as plain
# substrings of the lower-cased user message and stops at the FIRST hit, so
# insertion order matters: earlier entries win when several keywords match.
# NOTE(review): substring matching means short keywords can also fire inside
# unrelated words (e.g. "api" in "kapitel", "code" in "decode").
_TOPIC_KEYWORDS = {
    "code": "Programmierung",
    "python": "Programmierung",
    "javascript": "Programmierung",
    "html": "Webentwicklung",
    "css": "Webentwicklung",
    "api": "Programmierung",
    "datenbank": "Datenbanken",
    "sql": "Datenbanken",
    "linux": "System/DevOps",
    "docker": "System/DevOps",
    "rezept": "Kochen",
    "kochen": "Kochen",
    "mathe": "Mathematik",
    "rechne": "Mathematik",
    "übersetze": "Sprache",
    "translate": "Sprache",
    "email": "Kommunikation",
    "brief": "Kommunikation",
    "zusammenfassung": "Text/Analyse",
    "erkläre": "Bildung",
    "explain": "Bildung",
}
