"""HiveMind Node — the core unit that ties everything together."""
from __future__ import annotations

import asyncio
import logging
import uuid
from pathlib import Path

from hivemind import __version__
from hivemind.config import Config
from hivemind.model import Model
from hivemind.cache import SemanticCache
from hivemind.plugins import PluginManager
from hivemind.confidence import ConfidenceScorer, should_ask_network
from hivemind.rag import RAGStore
from hivemind.training import TrainingManager
from hivemind.sessions import SessionManager
from hivemind.memory import GlobalMemory

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Upper bound on how long Node._ask_network waits for a peer answer
# (passed as the timeout to asyncio.wait_for).
NETWORK_QUERY_TIMEOUT = 15  # seconds


class Node:
    """A single HiveMind node — local AI + cache + plugins + P2P."""

    def __init__(self, config: Config, base_dir: Path | None = None):
        self.config = config
        self.base_dir = base_dir or Path(".")
        self.id = str(uuid.uuid4())[:8]
        self.name = config.node.name or f"node-{self.id}"
        self.version = __version__

        # Core components
        self.model = Model(config.model)
        self.cache = SemanticCache(config.cache, cache_dir=self.base_dir / "cache")
        self.plugins = PluginManager(node=self)
        self.network = None

        # Confidence scorer
        self.scorer = ConfidenceScorer(
            expertise_tags=config.node.expertise_tags,
            specialization=config.node.specialization,
        )

        # RAG store
        self.rag = RAGStore(
            data_dir=self.base_dir / config.rag.data_dir,
            chunk_size=config.rag.chunk_size,
        )

        # Training manager
        self.training = TrainingManager(
            data_dir=self.base_dir / config.training.data_dir,
            lora_path=config.training.lora_path,
        )

        # Session manager
        self.sessions = SessionManager(
            data_dir=self.base_dir / "data" / "sessions",
        )

        # Global memory
        self.memory = GlobalMemory(
            data_dir=self.base_dir / "data",
        )

        # Conversation history (points to active session)
        self.max_history = 50

        # Relay & Updates
        self.relay = None
        self.updater = None

        self._running = False
        self._pending_queries: dict[str, asyncio.Future] = {}

    async def start(self) -> None:
        """Initialize the node — load model, plugins, and network."""
        log.info("Starting node: %s (%s) v%s", self.name, self.id, self.version)

        # Load model if configured
        if self.config.model.path:
            try:
                self.model.load()
                # Load LoRA adapter if configured
                if self.config.training.lora_path:
                    try:
                        self.model.load_lora(self.config.training.lora_path)
                    except Exception as e:
                        log.warning("LoRA load failed: %s", e)
            except FileNotFoundError as e:
                log.error("Model not found: %s", e)
                log.info("Node will run without local model (cache + plugins only)")
        else:
            log.info("No model configured — running in plugin-only mode")

        # Load plugins
        await self.plugins.load(
            self.config.plugins.get("enabled", []),
            self.config.plugins.get("directory", "./plugins"),
        )

        # Start P2P network if enabled
        if self.config.network.enabled:
            from hivemind.network.peerlist import PeerList
            from hivemind.network.peer import P2PNetwork
            from hivemind.network.updater import AutoUpdater

            peer_list = PeerList(
                path=self.base_dir / "peers.json",
                own_id=self.id,
            )

            for addr in self.config.network.bootstrap_nodes:
                parts = addr.split(":")
                if len(parts) == 2:
                    peer_list.add_manual(parts[0], int(parts[1]))

            self.network = P2PNetwork(
                node=self,
                peer_list=peer_list,
                listen_port=self.config.network.listen_port,
            )

            # Setup auto-updater
            self.updater = AutoUpdater(self, self.base_dir)
            self.updater.register_handlers(self.network)

            # Register query handler
            from hivemind.network.protocol import MsgType
            self.network.on_message(MsgType.QUERY, self._handle_network_query)
            self.network.on_message(MsgType.RESPONSE, self._handle_network_response)

            await self.network.start()
            log.info("P2P network started on port %d", self.config.network.listen_port)

        # Start relay client if relay servers configured
        relay_urls = getattr(self.config.network, 'relay_servers', []) or []
        if relay_urls:
            from hivemind.network.relay_client import RelayConnection
            self.relay = RelayConnection(
                relay_url=relay_urls[0],
                node=self,
                on_message=self._handle_relay_message,
                on_peers=self._handle_relay_peers,
            )
            await self.relay.start()
            log.info("Relay client started: %s", relay_urls[0])

        self._running = True
        spec = f" | Spec: {self.config.node.specialization}" if self.config.node.specialization else ""
        rag_info = f" | RAG: {self.rag.stats['documents']} docs" if self.rag.stats['documents'] else ""
        log.info(
            "Node ready: %s | Model: %s | Plugins: %s | Cache: %d%s%s",
            self.name,
            "loaded" if self.model.loaded else "none",
            ", ".join(self.plugins.loaded) or "none",
            self.cache.size,
            spec,
            rag_info,
        )

    async def chat(self, user_message: str) -> str:
        """Process a user message with confidence-based routing.

        Flow:
        1. Check cache
        2. Build RAG context if available
        3. Generate local response
        4. Score confidence
        5. If low confidence + network available → ask peers
        6. Pick best response
        7. Save to cache + training data
        """
        # 1. Cache lookup
        cached = self.cache.lookup(user_message)
        if cached:
            log.info("Cache hit for: %s", user_message[:50])
            return cached

        # 2. Add to session history + extract memory facts
        self.sessions.add_message("user", user_message)
        self.memory.process_message("user", user_message)
        history = self.sessions.get_history()
        if len(history) > self.max_history:
            history = history[-self.max_history:]

        # 3. Build messages with RAG context
        messages = self._build_messages(user_message)

        # 4. Generate local response
        local_response = await self._generate_local(messages)

        # 5. Score confidence
        confidence = self.scorer.score(user_message, local_response)
        routing = should_ask_network(confidence)
        log.info("Confidence: %.2f → routing: %s", confidence, routing)

        # 6. Network routing
        best_response = local_response
        has_peers = (self.network and self.network.connected_count > 0) or (self.relay and self.relay.connected)
        if routing in ("both", "network") and has_peers:
            network_response = await self._ask_network(user_message)
            if network_response:
                net_text, net_confidence = network_response
                if routing == "network" or net_confidence > confidence:
                    best_response = net_text
                    log.info("Using network response (confidence: %.2f vs local %.2f)",
                             net_confidence, confidence)

        # 6b. Check for PDF export request
        pdf_plugin = self.plugins.get("pdf_export")
        if pdf_plugin:
            try:
                history_for_pdf = self.sessions.get_history() + [
                    {"role": "user", "content": user_message},
                    {"role": "assistant", "content": best_response},
                ]
                pdf_result = await pdf_plugin.capabilities[0].handler(messages=history_for_pdf)
                if pdf_result:
                    best_response = best_response + "\n\n" + pdf_result
            except Exception as e:
                log.debug("PDF plugin error: %s", e)

        # 7. Store results
        self.sessions.add_message("assistant", best_response)
        self.memory.process_message("assistant", best_response)
        self.cache.store(user_message, best_response)

        # Save conversation for training (every 5 exchanges)
        hist = self.sessions.get_history()
        if len(hist) >= 10 and len(hist) % 10 == 0:
            self.training.save_conversation(list(hist[-10:]))

        return best_response

    def _build_messages(self, user_message: str) -> list[dict]:
        """Build message list with system prompt, memory, RAG context, and history."""
        messages = []

        # System prompt with specialization
        system_parts = ["Du bist ein hilfreicher KI-Assistent."]
        if self.config.node.specialization:
            system_parts.append(
                f"Deine Spezialisierung: {self.config.node.specialization}"
            )
        if self.config.node.expertise_tags:
            system_parts.append(
                f"Deine Expertise: {', '.join(self.config.node.expertise_tags)}"
            )

        # Global memory
        memory_ctx = self.memory.build_context()
        if memory_ctx:
            system_parts.append(f"\n{memory_ctx}")

        # RAG context
        rag_context = self.rag.build_context(
            user_message,
            top_k=self.config.rag.top_k,
        )
        if rag_context:
            system_parts.append(
                f"\nRelevantes Wissen aus deinen Dokumenten:\n{rag_context}"
            )

        messages.append({"role": "system", "content": "\n".join(system_parts)})

        # Add conversation history from active session
        history = self.sessions.get_history()
        if len(history) > self.max_history:
            history = history[-self.max_history:]
        messages.extend(history)

        return messages

    async def _generate_local(self, messages: list[dict]) -> str:
        """Generate response using local model or plugins."""
        chat_plugin = self.plugins.get("chat")
        if chat_plugin:
            return await chat_plugin.capabilities[0].handler(messages=messages)
        elif self.model.loaded:
            return self.model.generate(messages)
        else:
            return (
                "⚠️ Kein Modell geladen und kein Chat-Plugin verfügbar.\n"
                "Konfiguriere model.path in config.yaml oder installiere ein Plugin."
            )

    async def _ask_network(self, query: str) -> tuple[str, float] | None:
        """Ask the P2P network for a response."""
        if not self.network:
            return None

        from hivemind.network.protocol import Message, MsgType
        import uuid as _uuid

        query_id = str(_uuid.uuid4())[:8]
        future: asyncio.Future = asyncio.get_event_loop().create_future()
        self._pending_queries[query_id] = future

        # Broadcast query via direct P2P
        msg = Message(
            type=MsgType.QUERY,
            sender_id=self.id,
            payload={
                "query_id": query_id,
                "query": query,
                "expertise_wanted": self.config.node.expertise_tags,
            },
        )
        sent = await self.network.broadcast(msg) if self.network else 0

        # Also broadcast via relay
        if self.relay and self.relay.connected:
            relay_ok = await self.relay.broadcast({
                "type": "QUERY",
                "query_id": query_id,
                "query": query,
                "expertise_wanted": self.config.node.expertise_tags,
            })
            if relay_ok:
                sent += 1
        log.info("Query sent to %d channels (id: %s)", sent, query_id)

        if sent == 0:
            del self._pending_queries[query_id]
            return None

        # Wait for best response
        try:
            result = await asyncio.wait_for(future, timeout=NETWORK_QUERY_TIMEOUT)
            return result
        except asyncio.TimeoutError:
            log.info("Network query timeout for %s", query_id)
            return self._pending_queries.pop(query_id, None)
        finally:
            self._pending_queries.pop(query_id, None)

    async def _handle_network_query(self, conn, msg):
        """Handle incoming query from a peer."""
        from hivemind.network.protocol import Message, MsgType

        query = msg.payload.get("query", "")
        query_id = msg.payload.get("query_id", "")

        if not query or not self.model.loaded:
            return

        # Generate local response
        messages = self._build_messages(query)
        response = await self._generate_local(messages)
        confidence = self.scorer.score(query, response)

        # Send response back
        await conn.send(Message(
            type=MsgType.RESPONSE,
            sender_id=self.id,
            payload={
                "query_id": query_id,
                "response": response,
                "confidence": confidence,
                "specialization": self.config.node.specialization,
            },
        ))

    async def _handle_network_response(self, conn, msg):
        """Handle response from a peer to our query."""
        query_id = msg.payload.get("query_id", "")
        response = msg.payload.get("response", "")
        confidence = msg.payload.get("confidence", 0.0)

        future = self._pending_queries.get(query_id)
        if future and not future.done():
            # Accept first good response, or best after timeout
            future.set_result((response, confidence))

    async def _handle_relay_peers(self, nodes: list[dict]):
        """Try direct P2P connections to nodes discovered via relay."""
        if not self.network:
            return
        for n in nodes:
            info = n.get("info", n)
            host = info.get("host", "")
            port = info.get("port", 0)
            node_id = n.get("node_id", info.get("node_id", ""))
            if not host or host == "0.0.0.0" or not port or node_id == self.id:
                continue
            # Don't retry if already connected
            from hivemind.network.protocol import PeerInfo
            addr = PeerInfo.format_address(host, port)
            if addr in self.network._connections and self.network._connections[addr].alive:
                continue
            log.info("Relay: trying direct P2P to %s (%s)", n.get("name", "?"), PeerInfo.format_address(host, port))
            asyncio.create_task(self.network.connect_to(host, port))

    async def _handle_relay_message(self, sender_id: str, sender_name: str, payload: dict):
        """Handle a message received via relay."""
        msg_type = payload.get("type", "")

        if msg_type == "QUERY":
            # Someone asks us a question via relay
            query = payload.get("query", "")
            query_id = payload.get("query_id", "")
            if not query or not self.model.loaded:
                return
            messages = self._build_messages(query)
            response = await self._generate_local(messages)
            confidence = self.scorer.score(query, response)
            if self.relay:
                await self.relay.send_to(sender_id, {
                    "type": "RESPONSE",
                    "query_id": query_id,
                    "response": response,
                    "confidence": confidence,
                    "specialization": self.config.node.specialization,
                })

        elif msg_type == "RESPONSE":
            # Response to our query via relay
            query_id = payload.get("query_id", "")
            response = payload.get("response", "")
            confidence = payload.get("confidence", 0.0)
            future = self._pending_queries.get(query_id)
            if future and not future.done():
                future.set_result((response, confidence))

        elif msg_type == "UPDATE_DATA":
            # Update received via relay — delegate to updater
            if self.updater:
                from hivemind.network.protocol import Message, MsgType
                fake_msg = Message(
                    type=MsgType.UPDATE_DATA,
                    sender_id=sender_id,
                    payload={
                        "manifest": payload.get("manifest", {}),
                        "data_b64": payload.get("data_b64", ""),
                    },
                )
                await self.updater._handle_update_data(None, fake_msg)

    async def stop(self) -> None:
        """Shutdown the node gracefully."""
        log.info("Stopping node: %s", self.name)

        # Save remaining conversation data
        hist = self.sessions.get_history()
        if len(hist) >= 4:
            self.training.save_conversation(list(hist))
        self.sessions.save_active()

        if self.relay:
            await self.relay.stop()
        if self.network:
            await self.network.stop()
        await self.plugins.shutdown_all()
        self._running = False

    @property
    def status(self) -> dict:
        s = {
            "id": self.id,
            "name": self.name,
            "version": self.version,
            "model_loaded": self.model.loaded,
            "plugins": self.plugins.loaded,
            "cache_size": self.cache.size,
            "history_length": len(self.sessions.get_history()),
            "session": self.sessions.active.summary_info,
            "running": self._running,
            "specialization": self.config.node.specialization,
            "expertise_tags": self.config.node.expertise_tags,
            "rag": self.rag.stats,
            "training": self.training.stats,
            "memory": self.memory.stats,
        }
        if self.network:
            s["network"] = self.network.status
        if self.relay:
            s["relay"] = self.relay.status
        return s
