"""Relay Client — connects to a relay server via HTTP polling.

Works through any NAT/CGN/firewall since all connections are outbound HTTP.
No external dependencies — pure urllib + asyncio.
"""
from __future__ import annotations

import asyncio
import inspect
import json
import logging
import time
import urllib.error
import urllib.parse
import urllib.request
from pathlib import Path
from typing import Any, Callable

log = logging.getLogger(__name__)

# Timing knobs for the relay client, all expressed in seconds.
POLL_INTERVAL: int = 3        # pause after each successful /poll request
HEARTBEAT_INTERVAL: int = 30  # pause between consecutive /heartbeat requests
INITIAL_RETRY: int = 5        # first reconnect delay after a session ends
MAX_RETRY: int = 120          # upper bound for the doubling reconnect delay


class RelayConnection:
    """HTTP-polling relay client.

    Registers the local node with a relay server, then repeatedly polls it for
    messages addressed to this node. A background task sends heartbeats and
    refreshes the relay's node list. All traffic is outbound HTTP.

    Args:
        relay_url: Base URL of the relay server; a trailing ``/`` is stripped.
        node: The local node object. Attributes read here: ``id``, ``name``,
            ``version`` and, when present, ``config`` (``network.listen_port``,
            ``node.specialization``, ``node.expertise_tags``, ``model.path``)
            and ``network`` (``_resolve_public_ip()``).
        on_message: Optional callback ``(sender_id, sender_name, payload)``
            invoked for every polled message. May be sync or async.
        on_peers: Optional callback ``(nodes)`` invoked with the relay's node
            list after registration (when non-empty) and with newly appeared
            nodes after each periodic refresh. May be sync or async.
    """

    def __init__(self, relay_url: str, node: Any,
                 on_message: Callable | None = None,
                 on_peers: Callable | None = None):
        self.relay_url = relay_url.rstrip("/")
        self.node = node
        self.on_message = on_message
        self.on_peers = on_peers
        self._running = False
        self._connected = False
        self._registered = False
        self._retry_delay = INITIAL_RETRY
        self._relay_nodes: list[dict] = []
        self._last_msg_id = 0
        # Strong reference to the background loop: the event loop only keeps
        # weak references to tasks, so an unreferenced task may be collected.
        self._task: asyncio.Task | None = None

    @property
    def connected(self) -> bool:
        """True from a successful registration until the session ends.

        A session ends on a failed heartbeat, five consecutive failed polls,
        or stop().
        """
        return self._connected

    @property
    def relay_nodes(self) -> list[dict]:
        """Shallow copy of the node list most recently reported by the relay."""
        return list(self._relay_nodes)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    async def _call_handler(callback: Callable, *args: Any) -> Any:
        """Call a sync or async *callback* with *args*, awaiting it if needed."""
        result = callback(*args)
        if inspect.isawaitable(result):
            result = await result
        return result

    def _build_info(self) -> dict:
        """Collect the profile data sent to the relay on (re-)registration."""
        node = self.node
        has_config = hasattr(node, "config")
        info = {
            "node_id": node.id,
            "name": node.name,
            "version": node.version,
            "port": node.config.network.listen_port if has_config else 9420,
        }
        if has_config:
            info["specialization"] = node.config.node.specialization
            info["expertise_tags"] = node.config.node.expertise_tags
            if node.config.model.path:
                info["model_name"] = Path(node.config.model.path).stem
        # Include the public IP (when resolvable) so peers can try direct P2P.
        # NOTE(review): _resolve_public_ip() runs synchronously on the event
        # loop; if it performs network I/O it blocks the loop — consider
        # moving it into an executor once its thread-safety is confirmed.
        if getattr(node, "network", None):
            try:
                info["host"] = node.network._resolve_public_ip()
            except Exception as e:
                log.debug("Relay: could not resolve public IP: %s", e)
        return info

    def _poll_endpoint(self) -> str:
        """Build the ``/poll`` path with percent-encoded query parameters."""
        query = urllib.parse.urlencode({
            "node_id": self.node.id,
            "since_id": self._last_msg_id,
        })
        return f"/poll?{query}"

    async def _register(self) -> dict | None:
        """POST the current node info to ``/register``; return the reply or None."""
        return await self._api_async("/register", {
            "node_id": self.node.id,
            "name": self.node.name,
            "info": self._build_info(),
        })

    async def _notify_peers(self, nodes: list[dict]) -> None:
        """Pass *nodes* to ``on_peers`` if set; handler errors are logged, not raised."""
        if not self.on_peers:
            return
        try:
            await self._call_handler(self.on_peers, nodes)
        except Exception as e:
            log.debug("Relay peers handler error: %s", e)

    # --------------------------------------------------------------- transport

    def _api(self, endpoint: str, data: dict | None = None, method: str = "GET") -> dict | None:
        """Make a blocking HTTP request to the relay and decode the JSON reply.

        Blocking — call through :meth:`_api_async` from async code.

        Args:
            endpoint: Path relative to ``relay_url`` (leading ``/`` optional);
                may include a query string.
            data: JSON-serialisable request body. When given, the request is
                sent as POST unless *method* names a different non-GET verb.
            method: HTTP method to use (``GET`` by default).

        Returns:
            The decoded JSON reply, or None on any failure (HTTP error status,
            network error after one retry, non-serialisable body, bad JSON).
            Failures are logged at DEBUG level.
        """
        url = f"{self.relay_url}/{endpoint.lstrip('/')}"
        headers = {"User-Agent": "HiveMind", "Connection": "close"}
        http_method = method.upper()
        try:
            body = None if data is None else json.dumps(data, ensure_ascii=False).encode()
        except (TypeError, ValueError) as e:
            log.debug("Relay API error %s: %s", endpoint, e)
            return None
        if body is not None:
            headers["Content-Type"] = "application/json"
            if http_method == "GET":
                http_method = "POST"  # a body implies POST unless told otherwise

        for attempt in range(2):  # one retry on network-level failures
            try:
                req = urllib.request.Request(url, data=body, method=http_method, headers=headers)
                with urllib.request.urlopen(req, timeout=10) as resp:
                    return json.loads(resp.read())
            except urllib.error.HTTPError as e:
                # Must precede OSError (HTTPError subclasses it): an HTTP error
                # status is a definitive answer from the relay, so no retry.
                try:
                    detail = json.loads(e.read()).get("error", "")
                    log.debug("Relay API error %s: %s", endpoint, detail)
                except Exception:
                    log.debug("Relay API error %s: HTTP %d", endpoint, e.code)
                return None
            except OSError as e:
                # Covers ConnectionResetError, URLError and timeouts.
                # NOTE(review): retrying a POST after a timeout may deliver a
                # /send or /broadcast twice if the relay already processed it.
                if attempt == 0:
                    continue
                log.debug("Relay API error %s: %s", endpoint, e)
                return None
            except Exception as e:
                log.debug("Relay API error %s: %s", endpoint, e)
                return None
        return None

    async def _api_async(self, endpoint: str, data: dict | None = None) -> dict | None:
        """Run :meth:`_api` in the default thread-pool executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._api, endpoint, data)

    # --------------------------------------------------------------- lifecycle

    async def start(self):
        """Start the background connect/poll loop; no-op if it is already running."""
        if self._task is not None and not self._task.done():
            return
        self._running = True
        self._task = asyncio.create_task(self._main_loop())

    async def stop(self):
        """Stop the client and cancel the background loop.

        Cancellation interrupts a pending retry sleep or poll wait immediately,
        and may interrupt an async ``on_message``/``on_peers`` handler that is
        mid-flight. A request already executing in the executor thread runs to
        completion, but its result is discarded.
        """
        self._running = False
        self._connected = False
        task, self._task = self._task, None
        if task is not None and not task.done():
            task.cancel()

    async def reregister(self) -> bool:
        """Send the current node info to the relay immediately (e.g. after a profile change).

        Returns True if the registration succeeded. Has no effect and returns
        False when the client is not connected.
        """
        if not self._connected:
            return False
        result = await self._register()
        if result is not None:
            log.info("Relay: re-registriert mit aktuellen Profildaten")
            return True
        log.warning("Relay: Re-Registrierung fehlgeschlagen")
        return False

    # ---------------------------------------------------------------- messaging

    async def send_to(self, target_id: str, payload: dict) -> bool:
        """Relay *payload* to one node; True if the relay reports success."""
        if not self._connected:
            return False
        result = await self._api_async("/send", {
            "sender_id": self.node.id,
            "target_id": target_id,
            "payload": payload,
        })
        return result is not None and bool(result.get("success", False))

    async def broadcast(self, payload: dict) -> bool:
        """Relay *payload* via ``/broadcast``; True if the relay reports success."""
        if not self._connected:
            return False
        result = await self._api_async("/broadcast", {
            "sender_id": self.node.id,
            "payload": payload,
        })
        return result is not None and bool(result.get("success", False))

    # -------------------------------------------------------------- main loops

    async def _main_loop(self):
        """Run sessions until stopped, backing off exponentially between them."""
        while self._running:
            try:
                await self._run_session()
            except Exception as e:
                log.debug("Relay session error: %s", e)

            self._connected = False
            self._registered = False
            if not self._running:
                break

            log.info("Relay: retry in %ds...", self._retry_delay)
            await asyncio.sleep(self._retry_delay)
            self._retry_delay = min(self._retry_delay * 2, MAX_RETRY)

    async def _run_session(self):
        """Register with the relay, then poll until the session ends.

        Raises:
            ConnectionError: if registration fails.
        """
        result = await self._register()
        if result is None:
            raise ConnectionError("Registration failed")

        self._connected = True
        self._registered = True
        self._retry_delay = INITIAL_RETRY  # successful contact resets backoff
        self._relay_nodes = result.get("nodes", [])
        log.info("Relay: connected to %s (%d nodes online)", self.relay_url, len(self._relay_nodes))

        # Report discovered peers so the node can attempt direct P2P.
        if self._relay_nodes:
            await self._notify_peers(self._relay_nodes)

        heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        try:
            await self._poll_loop()
        finally:
            heartbeat_task.cancel()

    async def _poll_loop(self):
        """Fetch and dispatch messages until stopped or disconnected.

        Five consecutive failed polls mark the session disconnected; failed
        polls back off for twice the normal interval.
        """
        poll_errors = 0
        while self._running and self._connected:
            result = await self._api_async(self._poll_endpoint())

            if result is None:
                poll_errors += 1
                if poll_errors >= 5:
                    log.warning("Relay: too many poll errors, reconnecting...")
                    self._connected = False
                    return
                await asyncio.sleep(POLL_INTERVAL * 2)
                continue
            poll_errors = 0

            messages = result.get("messages", [])
            if result.get("last_id", 0) > self._last_msg_id:
                self._last_msg_id = result["last_id"]

            for msg in messages:
                sender_id = msg.get("sender_id", "")
                sender_name = msg.get("sender_name", "")
                payload = msg.get("payload", {})
                if self.on_message:
                    try:
                        await self._call_handler(self.on_message, sender_id, sender_name, payload)
                    except Exception as e:
                        log.exception("Relay message handler error: %s", e)

            await asyncio.sleep(POLL_INTERVAL)

    async def _heartbeat_loop(self):
        """Send heartbeats and refresh the node list every HEARTBEAT_INTERVAL seconds.

        A failed heartbeat marks the session disconnected, which ends the poll
        loop. Nodes that were not in the previous list are passed to
        ``on_peers``.
        """
        try:
            while self._running and self._connected:
                await asyncio.sleep(HEARTBEAT_INTERVAL)
                result = await self._api_async("/heartbeat", {
                    "node_id": self.node.id,
                })
                if result is None:
                    self._connected = False
                    return
                nodes_result = await self._api_async("/nodes")
                if nodes_result:
                    old_ids = {n.get("node_id") for n in self._relay_nodes}
                    self._relay_nodes = nodes_result.get("nodes", [])
                    added = {n.get("node_id") for n in self._relay_nodes} - old_ids
                    if added and self.on_peers:
                        new_nodes = [n for n in self._relay_nodes if n.get("node_id") in added]
                        await self._notify_peers(new_nodes)
        except asyncio.CancelledError:
            pass  # cancelled by _run_session when the poll loop exits

    @property
    def status(self) -> dict:
        """Snapshot of the connection state for diagnostics."""
        return {
            "relay_url": self.relay_url,
            "connected": self._connected,
            "relay_nodes": len(self._relay_nodes),
            "nodes": list(self._relay_nodes),
        }
