"""Relay Server — TCP hub for nodes behind NAT/CGN.

Any HiveMind node with a public IP can run as relay.
Nodes connect outbound, relay routes messages between them.
Pure asyncio — no external dependencies.
"""
from __future__ import annotations

import asyncio
import json
import logging
import time

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# TCP port RelayHub listens on when it is constructed without an explicit port.
DEFAULT_RELAY_PORT = 9421


class RelayNode:
    """One client connection to the relay plus the identity it registered with.

    ``node_id``, ``name`` and ``info`` stay empty until the hub processes a
    ``RELAY_HELLO`` from this connection. ``alive`` becomes False once a write
    fails or the connection is closed.
    """
    __slots__ = ("reader", "writer", "node_id", "name", "info", "connected_at", "last_seen", "alive")

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        self.reader = reader
        self.writer = writer
        self.node_id = ""       # filled in on RELAY_HELLO
        self.name = ""
        self.info: dict = {}    # raw RELAY_HELLO payload
        now = time.time()
        self.connected_at = now
        self.last_seen = now    # refreshed by the hub on every decoded message
        self.alive = True

    async def send(self, msg: dict) -> bool:
        """Write *msg* as one newline-terminated JSON line and drain the writer.

        Returns True on success. Returns False without touching the socket if
        the node is already dead, or if *msg* cannot be serialized (the
        connection stays usable in that case: the message is at fault, not
        the peer). Any error while writing/draining marks the node dead and
        returns False.
        """
        if not self.alive:
            return False
        try:
            text = json.dumps(msg, ensure_ascii=False)
        except (TypeError, ValueError):
            return False
        try:
            data = (text + "\n").encode()
        except UnicodeEncodeError:
            # Lone surrogates (legal in JSON via escapes like "\ud800") cannot
            # be UTF-8 encoded; the ASCII-escaped form is still valid JSON and
            # decodes back to the same value on the receiving side.
            data = (json.dumps(msg) + "\n").encode()
        try:
            self.writer.write(data)
            await self.writer.drain()
        except Exception:
            self.alive = False
            return False
        return True

    async def close(self):
        """Mark the node dead and close its stream; errors while closing are ignored."""
        self.alive = False
        try:
            self.writer.close()
            await self.writer.wait_closed()
        except Exception:
            # Best effort: the peer may already be gone.
            pass


class RelayHub:
    """TCP relay hub — routes messages between connected nodes.

    Wire format: one JSON object per line. Handled message types:

    - ``RELAY_HELLO`` (``sender_id``, ``payload``): register the connection
      under ``sender_id``. The sender receives ``RELAY_PEERS``; every other
      node receives ``RELAY_NODE_JOINED``.
    - ``RELAY_SEND`` (``target_id``, ``payload``): deliver ``payload`` to one
      node as ``RELAY_MESSAGE``, or answer the sender with ``RELAY_ERROR``.
    - ``RELAY_BROADCAST`` (``payload``): deliver ``payload`` as
      ``RELAY_MESSAGE`` to every other registered node.
    - ``RELAY_PING``: answer ``RELAY_PONG``.

    SEND and BROADCAST require a prior HELLO. When a registered node
    disconnects, the remaining nodes receive ``RELAY_NODE_LEFT``.
    """

    # HELLO-payload keys that are shared with other nodes.
    _PUBLIC_INFO_KEYS = ("specialization", "expertise_tags", "model_name", "version")

    def __init__(self, port: int = DEFAULT_RELAY_PORT):
        self.port = port
        self._nodes: dict[str, RelayNode] = {}  # node_id → registered RelayNode
        # Every open connection, registered or not, so stop() can close them all.
        self._connections: set[RelayNode] = set()
        self._server: asyncio.Server | None = None

    @property
    def connected_count(self) -> int:
        """Number of registered nodes (connections without a HELLO are not counted)."""
        return len(self._nodes)

    @property
    def node_list(self) -> list[dict]:
        """Public description of each registered node."""
        return [
            {
                "node_id": n.node_id,
                "name": n.name,
                "connected_since": n.connected_at,
                "last_seen": n.last_seen,
                **self._public_info(n),
            }
            for n in self._nodes.values()
        ]

    def _public_info(self, node: RelayNode) -> dict:
        """Return the part of *node*'s HELLO payload that peers may see."""
        return {k: v for k, v in node.info.items() if k in self._PUBLIC_INFO_KEYS}

    async def start(self):
        """Start the relay TCP server on ``self.port``.

        Binds the IPv6 wildcard first (which also accepts IPv4 on dual-stack
        hosts, depending on the OS's IPV6_V6ONLY default) and falls back to
        IPv4-only if that bind fails.
        """
        try:
            self._server = await asyncio.start_server(
                self._on_connection, "::", self.port
            )
            log.info("Relay hub listening on port %d (IPv4+IPv6)", self.port)
        except OSError:
            self._server = await asyncio.start_server(
                self._on_connection, "0.0.0.0", self.port
            )
            log.info("Relay hub listening on port %d (IPv4)", self.port)

    async def stop(self):
        """Stop listening, then close every open connection.

        Unregistered connections are closed too: on Python versions where
        ``Server.wait_closed()`` waits for active connections, leaving one
        open would make this call hang.
        """
        server, self._server = self._server, None
        if server:
            server.close()  # stop accepting before tearing connections down
        # Drop registrations first so the per-connection cleanup does not
        # broadcast RELAY_NODE_LEFT to nodes that are shutting down as well.
        self._nodes.clear()
        # Loop in case a connection accepted just before close() registers
        # itself while the previous batch is being closed.
        while self._connections:
            batch = list(self._connections)
            self._connections.clear()
            await asyncio.gather(*(n.close() for n in batch))
        self._nodes.clear()
        if server:
            await server.wait_closed()

    async def _on_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Serve one client connection until EOF, a transport error, or close."""
        addr = writer.get_extra_info("peername")
        log.info("Relay: incoming connection from %s", addr)
        node = RelayNode(reader, writer)
        self._connections.add(node)

        try:
            while node.alive:
                try:
                    line = await asyncio.wait_for(reader.readline(), timeout=120)
                except asyncio.TimeoutError:
                    # Idle peers are kept; the timeout only lets the loop
                    # re-check node.alive periodically.
                    continue
                if not line:
                    break  # EOF

                try:
                    msg = json.loads(line)
                except ValueError:
                    # Covers malformed JSON and undecodable bytes
                    # (UnicodeDecodeError): drop the line, keep the peer.
                    continue
                if not isinstance(msg, dict):
                    continue  # valid JSON, but not a protocol message

                node.last_seen = time.time()
                await self._dispatch(node, msg)

        except Exception as e:
            log.debug("Relay connection error: %s", e)
        finally:
            self._connections.discard(node)
            await node.close()
            # Only the connection that still owns the id unregisters it; a
            # newer connection that took over the same id stays registered.
            if node.node_id and self._nodes.get(node.node_id) is node:
                del self._nodes[node.node_id]
                log.info("Relay: node disconnected: %s (%s)", node.name, node.node_id[:12])
                await self._broadcast({
                    "type": "RELAY_NODE_LEFT",
                    "payload": {"node_id": node.node_id, "name": node.name},
                })

    async def _dispatch(self, node: RelayNode, msg: dict):
        """Act on one decoded message from *node*."""
        msg_type = msg.get("type", "")

        if msg_type == "RELAY_HELLO":
            await self._register(node, msg.get("sender_id", ""), msg.get("payload", {}))

        elif msg_type == "RELAY_PING":
            await node.send({"type": "RELAY_PONG"})

        elif msg_type in ("RELAY_SEND", "RELAY_BROADCAST"):
            if not node.node_id:
                # Without a HELLO the message would go out with an empty
                # sender_id that no recipient could reply to.
                await node.send({
                    "type": "RELAY_ERROR",
                    "payload": {"error": "Nicht registriert — zuerst RELAY_HELLO senden"},
                })
            elif msg_type == "RELAY_SEND":
                await self._route(node, msg.get("target_id", ""), msg.get("payload", {}))
            else:
                await self._broadcast({
                    "type": "RELAY_MESSAGE",
                    "sender_id": node.node_id,
                    "sender_name": node.name,
                    "payload": msg.get("payload", {}),
                }, exclude=node.node_id)

    async def _register(self, node: RelayNode, node_id: object, info: object):
        """Handle RELAY_HELLO: bind *node* to *node_id* and announce it.

        A HELLO repeated under a different id releases the old id (peers get
        RELAY_NODE_LEFT for it). If another connection already holds
        *node_id*, the new connection takes over and the old one is closed.
        """
        if not isinstance(node_id, str) or not node_id:
            # An empty id could never be unregistered again and non-string
            # ids break routing: refuse the HELLO but keep the connection.
            await node.send({
                "type": "RELAY_ERROR",
                "payload": {"error": "RELAY_HELLO ohne gültige sender_id"},
            })
            return
        if not isinstance(info, dict):
            info = {}

        released_id, released_name = "", ""
        if node.node_id and node.node_id != node_id and self._nodes.get(node.node_id) is node:
            released_id, released_name = node.node_id, node.name
            del self._nodes[released_id]
        previous = self._nodes.get(node_id)

        # Update the routing table before the first await so no other
        # connection's cleanup can interleave with a half-done registration.
        node.node_id = node_id
        node.name = info.get("name", "")
        node.info = info
        self._nodes[node_id] = node
        log.info("Relay: node registered: %s (%s)", node.name, node.node_id[:12])

        if previous is not None and previous is not node:
            # Same id on a new connection (e.g. a client reconnecting before
            # its stale socket was noticed): the new connection wins. Not
            # awaited, since a stale peer may never drain its socket; the old
            # handler sees alive=False and finishes its own cleanup.
            previous.alive = False
            previous.writer.close()

        # Send peer list
        await node.send({
            "type": "RELAY_PEERS",
            "payload": {"nodes": self.node_list},
        })

        # Notify others
        if released_id:
            await self._broadcast({
                "type": "RELAY_NODE_LEFT",
                "payload": {"node_id": released_id, "name": released_name},
            }, exclude=node.node_id)
        await self._broadcast({
            "type": "RELAY_NODE_JOINED",
            "payload": {
                "node_id": node.node_id,
                "name": node.name,
                **self._public_info(node),
            },
        }, exclude=node.node_id)

    async def _route(self, node: RelayNode, target_id: object, inner: object):
        """Deliver *inner* from *node* to *target_id*, or report failure back to *node*."""
        target = self._nodes.get(target_id) if isinstance(target_id, str) else None
        if target is not None and target.alive:
            delivered = await target.send({
                "type": "RELAY_MESSAGE",
                "sender_id": node.node_id,
                "sender_name": node.name,
                "payload": inner,
            })
            if delivered:
                return
        await node.send({
            "type": "RELAY_ERROR",
            "payload": {"error": f"Node {str(target_id)[:12]} nicht verbunden"},
        })

    async def _broadcast(self, msg: dict, exclude: str = ""):
        """Send *msg* to every live registered node except *exclude*.

        Sends run concurrently, so one slow peer does not delay the others.
        """
        targets = [n for nid, n in self._nodes.items() if nid != exclude and n.alive]
        if targets:
            await asyncio.gather(*(n.send(msg) for n in targets))

    @property
    def status(self) -> dict:
        """Summary of the relay: port, registered-node count and node list."""
        return {
            "relay": True,
            "port": self.port,
            "connected_nodes": self.connected_count,
            "nodes": self.node_list,
        }
