"""RAG — Retrieval Augmented Generation.

Simple document store with TF-IDF-style keyword search.
No external dependencies — pure Python standard library.
"""
from __future__ import annotations

import hashlib
import json
import logging
import math
import re
import time
from pathlib import Path
from typing import Any

log = logging.getLogger(__name__)


def _tokenize(text: str) -> list[str]:
    """Simple whitespace + punctuation tokenizer."""
    return re.findall(r'\w+', text.lower())


def _compute_tf(tokens: list[str]) -> dict[str, float]:
    """Term frequency: count / total."""
    counts: dict[str, int] = {}
    for t in tokens:
        counts[t] = counts.get(t, 0) + 1
    total = len(tokens) or 1
    return {t: c / total for t, c in counts.items()}


class Chunk:
    """One retrievable slice of a document.

    The token list and term-frequency map are computed once at construction
    so searches never re-tokenize the text.
    """
    __slots__ = ("doc_name", "index", "text", "tokens", "tf")

    def __init__(self, doc_name: str, index: int, text: str):
        tokens = _tokenize(text)
        self.doc_name, self.index, self.text = doc_name, index, text
        self.tokens = tokens
        self.tf = _compute_tf(tokens)


class RAGStore:
    """Document store with TF-IDF keyword search for RAG.

    Document text is written to files under ``<data_dir>/documents`` and
    per-document metadata is persisted in ``<data_dir>/index.json``. Chunks and
    IDF weights live only in memory and are rebuilt from the document files
    when the store is opened.
    """

    def __init__(self, data_dir: str | Path = "data/rag", chunk_size: int = 512):
        """Open (or create) a store rooted at *data_dir*.

        Args:
            data_dir: Directory holding ``index.json`` and ``documents/``.
            chunk_size: Words per chunk; consecutive chunks overlap by half.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            # 0 silently indexed nothing and negative values produced
            # nonsense slices (words[0:-n]), so reject them up front.
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        self.data_dir = Path(data_dir)
        self.docs_dir = self.data_dir / "documents"
        self.index_path = self.data_dir / "index.json"
        self.chunk_size = chunk_size

        self._chunks: list[Chunk] = []
        self._documents: dict[str, dict] = {}  # name → metadata
        self._idf: dict[str, float] = {}

        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.docs_dir.mkdir(parents=True, exist_ok=True)
        self._load_index()

    def _safe_doc_path(self, name: str) -> Path | None:
        """Return the file path for document *name*, or None if it is unsafe.

        A name is unsafe when it is empty or resolves to ``docs_dir`` itself or
        anywhere outside it (e.g. ``"../x"`` or an absolute path). Using such a
        name directly would let a caller read, overwrite or delete arbitrary
        files.
        """
        if not name:
            return None
        root = self.docs_dir.resolve()
        candidate = (self.docs_dir / name).resolve()
        if candidate == root:
            return None
        try:
            candidate.relative_to(root)
        except ValueError:
            return None
        return candidate

    def _load_index(self):
        """Load document metadata from disk and rebuild chunks and IDF.

        Best effort: any error is logged and leaves the store empty, rather
        than half-loaded. Documents whose file is missing, or whose name is
        unsafe, stay listed in the metadata but contribute no chunks.
        """
        if not self.index_path.exists():
            return

        try:
            meta = json.loads(self.index_path.read_text(encoding="utf-8"))
            documents = meta.get("documents", {})

            # Rebuild chunks from stored documents into locals so a failure
            # midway cannot leave the store in an inconsistent state.
            chunks: list[Chunk] = []
            for name in documents:
                doc_path = self._safe_doc_path(name)
                if doc_path is None:
                    log.warning("RAG: ignoring unsafe document name in index: %r", name)
                    continue
                if not doc_path.exists():
                    log.warning("RAG: document file missing: %s", doc_path)
                    continue
                text = doc_path.read_text(encoding="utf-8", errors="replace")
                chunks.extend(self._split_text(text, name))
        except Exception as e:
            log.error("Failed to load RAG index: %s", e)
            return

        self._documents = documents
        self._chunks = chunks
        self._rebuild_idf()
        log.info("RAG loaded: %d documents, %d chunks",
                 len(self._documents), len(self._chunks))

    def _save_index(self):
        """Persist document metadata.

        The JSON is written to a sibling temp file and then moved over
        ``index.json``, so a crash mid-write cannot leave a truncated index.
        """
        data = {"documents": self._documents}
        tmp_path = self.index_path.with_name(self.index_path.name + ".tmp")
        tmp_path.write_text(
            json.dumps(data, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        tmp_path.replace(self.index_path)

    @staticmethod
    def _chunk_bounds(n_words: int, chunk_size: int) -> list[tuple[int, int]]:
        """Return ``(start, end)`` word offsets of half-overlapping windows.

        Windows advance by ``chunk_size // 2`` (at least 1). Generation stops at
        the first window that reaches ``n_words``, so no window is ever fully
        contained in its predecessor. Returns ``[]`` when ``n_words`` is 0 or
        ``chunk_size`` < 1.
        """
        if n_words <= 0 or chunk_size < 1:
            return []
        step = max(1, chunk_size // 2)  # 50% overlap
        bounds: list[tuple[int, int]] = []
        for start in range(0, n_words, step):
            end = min(start + chunk_size, n_words)
            bounds.append((start, end))
            if end >= n_words:
                # Any later window would be a suffix of this one.
                break
        return bounds

    def _split_text(self, text: str, doc_name: str) -> list[Chunk]:
        """Split whitespace-separated *text* into half-overlapping chunks.

        Each chunk holds up to ``chunk_size`` words and is numbered from 0 in
        document order.
        """
        words = text.split()
        return [
            Chunk(doc_name, idx, " ".join(words[start:end]))
            for idx, (start, end) in enumerate(
                self._chunk_bounds(len(words), self.chunk_size))
        ]

    def _rebuild_idf(self):
        """Recompute inverse document frequency across all chunks.

        Each chunk counts as one "document". The smoothed form
        ``log((n + 1) / (df + 1)) + 1`` is always >= 1 (df <= n), so every term
        a query shares with a chunk adds a positive amount to its score.
        """
        if not self._chunks:
            self._idf = {}
            return

        n = len(self._chunks)
        df: dict[str, int] = {}
        for chunk in self._chunks:
            for t in set(chunk.tokens):
                df[t] = df.get(t, 0) + 1

        self._idf = {
            t: math.log((n + 1) / (count + 1)) + 1
            for t, count in df.items()
        }

    def add_document(self, name: str, content: str | None = None,
                     file_path: str | Path | None = None) -> int:
        """Add (or replace) a document in the store.

        Args:
            name: Document name (used as filename inside ``docs_dir``)
            content: Text content (if provided directly)
            file_path: Path to read content from; takes precedence over
                ``content`` when given.

        Returns:
            Number of chunks created (0 if the content is empty, in which case
            nothing is stored).

        Raises:
            ValueError: If ``name`` is empty or would resolve outside
                ``docs_dir``.
        """
        if file_path:
            content = Path(file_path).read_text(encoding="utf-8", errors="replace")

        if not content:
            return 0

        doc_path = self._safe_doc_path(name)
        if doc_path is None:
            raise ValueError(f"Invalid document name: {name!r}")

        # Save document
        doc_path.write_text(content, encoding="utf-8")

        # Replace any old chunks for this document
        self._chunks = [c for c in self._chunks if c.doc_name != name]
        new_chunks = self._split_text(content, name)
        self._chunks.extend(new_chunks)

        # Update metadata; the md5 prefix is a content fingerprint, not a
        # security measure.
        self._documents[name] = {
            "added": time.time(),
            "size": len(content),
            "chunks": len(new_chunks),
            "hash": hashlib.md5(content.encode("utf-8")).hexdigest()[:12],
        }

        self._rebuild_idf()
        self._save_index()

        log.info("RAG: added '%s' (%d chunks)", name, len(new_chunks))
        return len(new_chunks)

    def remove_document(self, name: str) -> bool:
        """Remove a document, its chunks and its file from the store.

        Returns:
            True if the document was indexed and has been removed, else False.
        """
        if name not in self._documents:
            return False

        self._chunks = [c for c in self._chunks if c.doc_name != name]
        del self._documents[name]

        # Never unlink outside docs_dir, even if the index holds a bad name.
        doc_path = self._safe_doc_path(name)
        if doc_path is not None and doc_path.exists():
            doc_path.unlink()

        self._rebuild_idf()
        self._save_index()
        log.info("RAG: removed '%s'", name)
        return True

    def search(self, query: str, top_k: int = 3) -> list[dict]:
        """Search for relevant chunks using TF-IDF similarity.

        Each chunk scores ``sum(query_tf * chunk_tf * idf**2)`` over the terms
        it shares with the query. This is a raw dot product, not cosine
        similarity. Chunks sharing no term are dropped.

        Args:
            query: Free-text query (tokenized like the documents).
            top_k: Maximum number of results; values < 1 return ``[]``.

        Returns:
            list of {text, doc_name, score}, highest score first.
        """
        # A negative top_k used to slice results[:-n], returning nearly every
        # match instead of none.
        if top_k < 1 or not self._chunks:
            return []

        query_tf = _compute_tf(_tokenize(query))

        # Fold the query-side factors into one weight per term, computed once
        # per query rather than once per chunk.
        weights: dict[str, float] = {}
        for term, qtf in query_tf.items():
            term_idf = self._idf.get(term, 1.0)
            weights[term] = qtf * term_idf * term_idf

        # Score each chunk
        results: list[tuple[float, Chunk]] = []
        for chunk in self._chunks:
            chunk_tf = chunk.tf
            score = 0.0
            for term, weight in weights.items():
                tf = chunk_tf.get(term)
                if tf is not None:
                    score += weight * tf
            if score > 0:
                results.append((score, chunk))

        # Sort by score (stable for ties), return top_k
        results.sort(key=lambda x: x[0], reverse=True)
        return [
            {
                "text": chunk.text,
                "doc_name": chunk.doc_name,
                "score": round(score, 4),
            }
            for score, chunk in results[:top_k]
        ]

    def build_context(self, query: str, top_k: int = 3,
                      max_chars: int = 2000) -> str:
        """Build a context string from relevant chunks for RAG.

        Each hit is rendered as ``[doc_name]: text`` and hits are separated by
        a blank line. The returned string, including headers and separators,
        is at most ``max_chars`` long. The last hit that fits is truncated, and
        hits with no room left are omitted.

        Returns empty string if nothing relevant found or ``max_chars`` < 1.
        """
        if max_chars < 1:
            return ""

        results = self.search(query, top_k)
        parts: list[str] = []
        used = 0  # length of "\n\n".join(parts) so far
        for r in results:
            sep_len = 2 if parts else 0  # "\n\n" between entries
            header = f"[{r['doc_name']}]: "
            room = max_chars - used - sep_len - len(header)
            if room <= 0:
                break
            text = r["text"][:room]
            parts.append(header + text)
            used += sep_len + len(header) + len(text)

        return "\n\n".join(parts)

    # ─── HuggingFace Dataset Import ─────────────────────────────

    @staticmethod
    def _hf_get_json(endpoint: str, params: dict[str, Any], timeout: float) -> Any:
        """GET ``https://datasets-server.huggingface.co/<endpoint>`` and decode JSON.

        Query parameters are URL-encoded (slashes kept literal, as in repo
        ids). Network/HTTP errors propagate as ``OSError`` subclasses
        (``HTTPError`` ⊂ ``URLError`` ⊂ ``OSError``); invalid JSON raises
        ``ValueError``.
        """
        from urllib.parse import urlencode
        from urllib.request import Request, urlopen

        url = (f"https://datasets-server.huggingface.co/{endpoint}"
               f"?{urlencode(params, safe='/')}")
        req = Request(url, headers={"User-Agent": "HiveMind/0.3"})
        with urlopen(req, timeout=timeout) as resp:
            return json.loads(resp.read())

    def _hf_fetch_texts(self, repo_id: str, config_name: str, split: str,
                        text_fields: list[str], fetch_rows: int,
                        timeout: float) -> list[str]:
        """Page through the rows API and return one text per usable row.

        A row's text is its non-empty string ``text_fields`` values, stripped
        and joined by newlines. Rows with no such value are skipped. Paging
        stops early, keeping what was fetched, on a request/JSON error, an
        empty page, or a short page.
        """
        texts: list[str] = []
        batch_size = 100  # rows requested per call
        offset = 0

        while offset < fetch_rows:
            length = min(batch_size, fetch_rows - offset)
            params = {
                "dataset": repo_id,
                "config": config_name,
                "split": split,
                "offset": offset,
                "length": length,
            }
            try:
                data = self._hf_get_json("rows", params, timeout)
            except (OSError, ValueError) as e:
                # Best effort: keep the rows already fetched, but say why.
                log.warning("HF import: stopped at offset %d: %s", offset, e)
                break

            rows = data.get("rows", [])
            if not rows:
                break

            for row_obj in rows:
                row = row_obj.get("row", {})
                parts = []
                for field in text_fields:
                    val = row.get(field, "")
                    if val and isinstance(val, str):
                        parts.append(val.strip())
                if parts:
                    texts.append("\n".join(parts))

            offset += len(rows)
            if len(rows) < length:
                break

        return texts

    def import_huggingface(self, url: str, max_rows: int = 5000,
                           text_fields: list[str] | None = None,
                           split: str = "train",
                           timeout: float = 30.0) -> dict:
        """Import a HuggingFace dataset into RAG via the datasets-server API.

        Only the first config reported by the ``/info`` endpoint is used. All
        extracted row texts are joined with ``---`` separators and stored as a
        single document named ``hf_<repo>_<split>.txt``, replacing any earlier
        import of the same repo and split.

        Args:
            url: HuggingFace dataset URL or repo id (e.g. 'user/dataset')
            max_rows: Maximum rows to import (default 5000)
            text_fields: Which fields to extract. If None or empty, all
                features with dtype ``string`` are used.
            split: Dataset split to use (default 'train')
            timeout: Per-request network timeout in seconds (default 30)

        Returns:
            dict with import stats on success, or a dict with an ``"error"``
            key explaining why nothing was imported.
        """
        from urllib.error import HTTPError

        # Parse repo_id from URL or plain id
        repo_id = url.strip().rstrip("/")
        m = re.search(r'huggingface\.co/datasets/([^/]+/[^/?#]+)', repo_id)
        if m:
            repo_id = m.group(1)

        log.info("HF import: repo=%s split=%s max=%d", repo_id, split, max_rows)

        # 1. Get dataset info (configs, features, splits)
        try:
            info = self._hf_get_json("info", {"dataset": repo_id}, timeout)
        except HTTPError as e:
            return {"error": f"Dataset nicht gefunden: {e.code}", "repo": repo_id}
        except (OSError, ValueError) as e:
            # URLError and timeouts are OSError subclasses; ValueError = bad JSON.
            return {"error": f"Dataset-Info nicht abrufbar: {e}", "repo": repo_id}

        configs = info.get("dataset_info", {})
        config_name = next(iter(configs), "default")
        config = configs.get(config_name, {})
        features = config.get("features", {})
        split_info = config.get("splits", {})

        if split not in split_info:
            available = list(split_info.keys())
            return {"error": f"Split '{split}' nicht gefunden. Verfuegbar: {available}"}

        total_rows = split_info[split].get("num_examples", 0)
        fetch_rows = min(max_rows, total_rows)

        # Auto-detect text fields: plain string-valued features
        if not text_fields:
            text_fields = [k for k, v in features.items()
                           if isinstance(v, dict) and v.get("dtype") == "string"]
        if not text_fields:
            return {"error": "Keine Text-Felder im Dataset gefunden.", "features": list(features.keys())}

        # 2. Fetch rows in batches via rows API
        all_texts = self._hf_fetch_texts(repo_id, config_name, split,
                                         text_fields, fetch_rows, timeout)
        if not all_texts:
            return {"error": "Keine Texte extrahiert.", "fields_tried": text_fields}

        # 3. Combine and add as single document. Every non-word character
        # other than "." and "-" becomes "_", so the name is always a single
        # filename. This matches the old "/" -> "_" mapping for ordinary ids.
        safe_id = re.sub(r"[^\w.\-]", "_", f"{repo_id}_{split}")
        doc_name = f"hf_{safe_id}.txt"
        combined = "\n\n---\n\n".join(all_texts)
        chunks = self.add_document(doc_name, content=combined)

        result = {
            "success": True,
            "repo": repo_id,
            "split": split,
            "rows_imported": len(all_texts),
            "rows_total": total_rows,
            "text_fields": text_fields,
            "document": doc_name,
            "chunks": chunks,
            "size_mb": round(len(combined) / 1048576, 2),
        }
        log.info("HF import done: %d rows, %d chunks, %.1f MB",
                 len(all_texts), chunks, len(combined) / 1048576)
        return result

    @property
    def document_list(self) -> list[dict]:
        """List all indexed documents with metadata."""
        return [
            {"name": name, **info}
            for name, info in self._documents.items()
        ]

    @property
    def stats(self) -> dict:
        """Counts of documents, in-memory chunks and distinct indexed terms."""
        return {
            "documents": len(self._documents),
            "chunks": len(self._chunks),
            "vocabulary": len(self._idf),
        }
