"""Relay Client — connects outbound to a relay server via TCP.

Works behind NAT, CGN, and firewalls that permit outbound TCP, since the client
initiates the connection. Messages are exchanged as newline-delimited JSON objects.
No external dependencies — pure asyncio.
"""
from __future__ import annotations

import asyncio
import inspect
import json
import logging
import time
from pathlib import Path
from typing import Any, Callable

# Module logger for connection and message events.
log: logging.Logger = logging.getLogger(__name__)

# Reconnect back-off in seconds: the wait starts at INITIAL_RETRY, doubles
# after every lost or failed connection, is capped at MAX_RETRY, and is reset
# to INITIAL_RETRY once a connection succeeds.
INITIAL_RETRY: int = 5
MAX_RETRY: int = 120

# Seconds between RELAY_PING keep-alives while connected.
PING_INTERVAL: int = 30


class RelayConnection:
    """Manage an outbound TCP connection to a relay server.

    The client speaks newline-delimited JSON. After connecting it sends a
    ``RELAY_HELLO`` describing ``node``, sends ``RELAY_PING`` every
    ``PING_INTERVAL`` seconds, tracks the relay's list of online nodes, and
    forwards ``RELAY_MESSAGE`` payloads to
    ``on_message(sender_id, sender_name, payload)``, which may be a plain or an
    async callable. The connection is re-established with exponential back-off
    until ``stop()`` is called.

    ``node`` must expose ``id``, ``name`` and ``version``; ``plugins.loaded``
    and ``config`` (``config.node.specialization``,
    ``config.node.expertise_tags``, ``config.model.path``) are optional.
    """

    # Port used when the relay URL does not specify one.
    DEFAULT_PORT = 9421
    # Seconds to wait for the TCP connection to be established.
    CONNECT_TIMEOUT = 15
    # Seconds without any inbound line before the connection is considered
    # dead (three ping intervals). Presumably the relay answers every
    # RELAY_PING with a RELAY_PONG -- TODO confirm against the server.
    READ_TIMEOUT = 90
    # Seconds to wait for a graceful socket close before aborting it.
    CLOSE_TIMEOUT = 5

    def __init__(self, relay_url: str, node: Any,
                 on_message: Callable | None = None):
        self.relay_url = relay_url
        self.node = node
        self.on_message = on_message
        self._reader: asyncio.StreamReader | None = None
        self._writer: asyncio.StreamWriter | None = None
        self._running = False
        self._connected = False
        self._retry_delay = INITIAL_RETRY
        self._relay_nodes: list[dict] = []
        # Strong reference to the background connection loop (see start()).
        self._loop_task: asyncio.Task | None = None

    @property
    def connected(self) -> bool:
        """True while a relay session is open."""
        return self._connected

    @property
    def relay_nodes(self) -> list[dict]:
        """Copy of the node descriptors last reported by the relay."""
        return list(self._relay_nodes)

    def _parse_url(self) -> tuple[str, int]:
        """Parse ``relay_url`` into ``(host, port)``.

        Accepts ``host``, ``host:port``, ``[ipv6]``, ``[ipv6]:port`` and bare
        IPv6 literals (no port), optionally prefixed with one of the schemes
        below and optionally followed by a path, which is ignored. A missing
        or empty port falls back to ``DEFAULT_PORT``.

        Raises:
            ValueError: if the port is not an integer, an IPv6 bracket is not
                closed, or text follows the closing bracket without a ':'.
        """
        url = self.relay_url.strip()
        for prefix in ("relay://", "tcp://", "ws://", "wss://", "http://", "https://"):
            if url.startswith(prefix):
                url = url[len(prefix):]
                break
        url = url.split("/", 1)[0]

        port_str = ""
        if url.startswith("["):
            bracket_end = url.index("]")
            host = url[1:bracket_end]
            rest = url[bracket_end + 1:]
            if rest:
                if not rest.startswith(":"):
                    raise ValueError(f"Invalid relay URL: {self.relay_url!r}")
                port_str = rest[1:]
        elif url.count(":") > 1:
            # Bare IPv6 literal: every colon belongs to the address.
            host = url
        else:
            host, _, port_str = url.partition(":")
        return host, int(port_str) if port_str else self.DEFAULT_PORT

    async def start(self):
        """Start the background connection loop (auto-reconnect).

        Calling start() while the loop is already running is a no-op.
        """
        if self._loop_task is not None and not self._loop_task.done():
            return
        self._running = True
        # Keep a reference: the event loop only holds weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-run.
        self._loop_task = asyncio.create_task(self._connection_loop())

    async def stop(self):
        """Stop reconnecting, cancel the loop and close the socket.

        Safe to call repeatedly and from inside an ``on_message`` handler
        (in that case the loop exits on its own once the handler returns).
        """
        self._running = False
        self._connected = False
        task, self._loop_task = self._loop_task, None
        if task is not None and task is not asyncio.current_task():
            task.cancel()
            # return_exceptions=True yields the task's CancelledError as a
            # result instead of re-raising it in the caller.
            await asyncio.gather(task, return_exceptions=True)
        await self._close_writer()

    async def _close_writer(self):
        """Close the current stream, if any (best effort, never raises Exception)."""
        writer, self._writer = self._writer, None
        self._reader = None
        if writer is None:
            return
        try:
            writer.close()
            await asyncio.wait_for(writer.wait_closed(), timeout=self.CLOSE_TIMEOUT)
        except asyncio.TimeoutError:
            # Unflushed data to an unresponsive peer can stall a graceful
            # close; drop the connection instead of hanging the reconnect.
            writer.transport.abort()
        except Exception as e:
            log.debug("Relay: error while closing connection: %s", e)

    async def _send(self, msg: dict) -> bool:
        """Write *msg* to the relay as one JSON line.

        Returns True on success. Returns False when not connected, when *msg*
        cannot be serialized (the connection is kept), or when the write
        fails (the connection is then marked as lost).
        """
        writer = self._writer
        if not self._connected or writer is None:
            return False
        try:
            data = (json.dumps(msg, ensure_ascii=False) + "\n").encode()
        except (TypeError, ValueError) as e:
            # A bad payload from one caller must not tear down the link.
            log.warning("Relay: cannot serialize message: %s", e)
            return False
        try:
            writer.write(data)
            await writer.drain()
            return True
        except Exception as e:
            log.debug("Relay send failed: %s", e)
            self._connected = False
            return False

    async def send_to(self, target_id: str, payload: dict) -> bool:
        """Ask the relay to deliver *payload* to node *target_id*."""
        return await self._send({
            "type": "RELAY_SEND",
            "sender_id": self.node.id,
            "target_id": target_id,
            "payload": payload,
        })

    async def broadcast(self, payload: dict) -> bool:
        """Ask the relay to deliver *payload* to every connected node."""
        return await self._send({
            "type": "RELAY_BROADCAST",
            "sender_id": self.node.id,
            "payload": payload,
        })

    async def _connection_loop(self):
        """Connect, serve, and reconnect with exponential back-off while running."""
        while self._running:
            try:
                await self._connect_and_listen()
            except Exception as e:
                log.debug("Relay connection lost: %s", e)

            self._connected = False
            if not self._running:
                break

            log.info("Relay: reconnect in %ds...", self._retry_delay)
            await asyncio.sleep(self._retry_delay)
            self._retry_delay = min(self._retry_delay * 2, MAX_RETRY)

    async def _connect_and_listen(self):
        """Run one relay session: connect, send hello, ping and read until it ends.

        The ping task is cancelled and the socket closed on every exit path
        (EOF, read timeout, error or cancellation).
        """
        host, port = self._parse_url()
        log.info("Relay: connecting to %s:%d", host, port)

        self._reader, self._writer = await asyncio.wait_for(
            asyncio.open_connection(host, port), timeout=self.CONNECT_TIMEOUT
        )
        ping_task: asyncio.Task | None = None
        try:
            self._connected = True
            self._retry_delay = INITIAL_RETRY
            log.info("Relay: connected to %s:%d", host, port)

            await self._send({
                "type": "RELAY_HELLO",
                "sender_id": self.node.id,
                "payload": self._hello_payload(),
            })

            ping_task = asyncio.create_task(self._ping_loop())
            await self._read_loop()
        finally:
            self._connected = False
            if ping_task is not None:
                ping_task.cancel()
            # Close here so a read timeout or EOF does not leave the old
            # socket open while the loop reconnects.
            await self._close_writer()

    def _hello_payload(self) -> dict:
        """Build the RELAY_HELLO payload describing ``self.node``."""
        node = self.node
        payload = {
            "node_id": node.id,
            "name": node.name,
            "version": node.version,
            "capabilities": node.plugins.loaded if hasattr(node, 'plugins') else [],
        }
        if hasattr(node, 'config'):
            payload["specialization"] = node.config.node.specialization
            payload["expertise_tags"] = node.config.node.expertise_tags
            if node.config.model.path:
                payload["model_name"] = Path(node.config.model.path).stem
        return payload

    async def _read_loop(self):
        """Read and dispatch JSON lines until EOF, read timeout, or disconnect."""
        reader = self._reader
        while self._running and self._connected:
            try:
                line = await asyncio.wait_for(reader.readline(), timeout=self.READ_TIMEOUT)
            except asyncio.TimeoutError:
                # No data for READ_TIMEOUT seconds -- connection might be dead
                log.debug("Relay: read timeout, reconnecting...")
                break

            if not line:  # EOF: the relay closed the connection.
                break

            try:
                msg = json.loads(line)
            except ValueError:
                # Covers JSONDecodeError and invalid UTF-8 (UnicodeDecodeError):
                # skip the bad line rather than dropping the connection.
                continue

            await self._dispatch(msg)

    async def _dispatch(self, msg: Any) -> None:
        """Handle one decoded message from the relay.

        Values that are not JSON objects, and unknown message types, are
        ignored. Control messages with a non-object payload are treated as
        having an empty payload. Exceptions raised by ``on_message`` are
        logged and do not interrupt the connection.
        """
        if not isinstance(msg, dict):
            return

        msg_type = msg.get("type", "")
        raw_payload = msg.get("payload", {})
        payload = raw_payload if isinstance(raw_payload, dict) else {}

        if msg_type == "RELAY_PEERS":
            nodes = payload.get("nodes", [])
            if not isinstance(nodes, list):
                nodes = []
            self._relay_nodes = [n for n in nodes if isinstance(n, dict)]
            log.info("Relay: %d nodes online", len(self._relay_nodes))

        elif msg_type == "RELAY_NODE_JOINED":
            joined_id = payload.get("node_id")
            self._relay_nodes = [n for n in self._relay_nodes if n.get("node_id") != joined_id]
            self._relay_nodes.append(payload)
            log.info("Relay: node joined: %s", payload.get("name", "?"))

        elif msg_type == "RELAY_NODE_LEFT":
            left_id = payload.get("node_id", "")
            self._relay_nodes = [n for n in self._relay_nodes if n.get("node_id") != left_id]
            log.info("Relay: node left: %s", payload.get("name", "?"))

        elif msg_type == "RELAY_MESSAGE":
            if self.on_message:
                try:
                    # Await whatever is awaitable, so coroutine functions and
                    # plain callables that return a coroutine (e.g. a lambda)
                    # are both handled.
                    result = self.on_message(
                        msg.get("sender_id", ""), msg.get("sender_name", ""), raw_payload
                    )
                    if inspect.isawaitable(result):
                        await result
                except Exception as e:
                    log.exception("Relay message handler error: %s", e)

        elif msg_type == "RELAY_ERROR":
            log.warning("Relay error: %s", payload.get("error", ""))

        # RELAY_PONG needs no handling: receiving any line already resets
        # the read timeout in _read_loop.

    async def _ping_loop(self):
        """Send RELAY_PING every PING_INTERVAL seconds while connected."""
        try:
            while self._running and self._connected:
                await asyncio.sleep(PING_INTERVAL)
                await self._send({"type": "RELAY_PING"})
        except asyncio.CancelledError:
            pass

    @property
    def status(self) -> dict:
        """Snapshot of the connection state; ``nodes`` is a copy of the node list."""
        return {
            "relay_url": self.relay_url,
            "connected": self._connected,
            "relay_nodes": len(self._relay_nodes),
            "nodes": list(self._relay_nodes),
        }
