"""P2P connection manager — handles connections to other nodes."""
from __future__ import annotations

import asyncio
import logging
import time
from pathlib import Path
from typing import Any, Callable

from hivemind.network.protocol import Message, MsgType, PeerInfo
from hivemind.network.peerlist import PeerList

log = logging.getLogger(__name__)

# Background-task cadences, in seconds.
PING_INTERVAL = 60           # ping connected peers / redial unconnected ones each minute
CLEANUP_INTERVAL = 60 * 60   # prune dead peers from the peer list hourly
SYNC_INTERVAL = 5 * 60       # push our peer list to connected peers every five minutes


class PeerConnection:
    """A single connection to a remote peer.

    ``recv`` reads one line per message via ``readline`` and hands it to
    ``Message.decode``; ``send`` writes ``Message.encode()`` (presumably
    newline-terminated to match — confirm in the protocol module).
    Any I/O or decode failure flips the connection to not-alive.
    """

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter,
                 address: str, is_outbound: bool = True):
        self.reader = reader
        self.writer = writer
        # "host:port" that was dialed (outbound) or the TCP peername (inbound).
        self.address = address
        self.is_outbound = is_outbound
        self.peer_info: PeerInfo | None = None  # set once the peer's HELLO is handled
        self.connected_at = time.time()
        self._alive = True

    async def send(self, msg: Message) -> bool:
        """Send a message to this peer.

        Returns True on success. On a socket error, returns False and marks
        the connection dead.
        """
        try:
            self.writer.write(msg.encode())
            await self.writer.drain()
        except (ConnectionError, OSError) as e:
            log.debug("Send failed to %s: %s", self.address, e)
            self._alive = False
            return False
        return True

    async def recv(self) -> Message | None:
        """Receive one message from this peer.

        Returns None in two situations:
        * a 30-second read timeout — the connection is merely idle and stays
          alive, so the caller should keep polling;
        * the connection is unusable (EOF, socket error, a line longer than
          the stream limit, or an undecodable payload) — ``alive`` becomes
          False.
        """
        try:
            data = await asyncio.wait_for(self.reader.readline(), timeout=30)
        except asyncio.TimeoutError:
            return None
        except (ConnectionError, OSError, ValueError) as e:
            # readline() raises ValueError when a line exceeds the reader's limit.
            log.debug("Recv failed from %s: %s", self.address, e)
            self._alive = False
            return None

        if not data:
            # EOF: the remote side closed the connection.
            self._alive = False
            return None

        try:
            return Message.decode(data)
        except ValueError as e:
            # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError
            # subclasses) without depending on a module-level `json` import.
            log.debug("Recv failed from %s: %s", self.address, e)
            self._alive = False
            return None

    async def close(self):
        """Close the underlying stream (best effort) and mark the connection dead."""
        try:
            self.writer.close()
            await self.writer.wait_closed()
        except Exception as e:
            # Closing an already-broken socket can fail; that is not actionable.
            log.debug("Close failed for %s: %s", self.address, e)
        self._alive = False

    @property
    def alive(self) -> bool:
        """True until a failure/close was observed and the transport is not closing."""
        return self._alive and not self.writer.is_closing()


import json


class P2PNetwork:
    """Manages all P2P connections, discovery, and sync.

    Registry invariant: a single ``PeerConnection`` may be stored in
    ``_connections`` under more than one key — the address it was dialed on
    and the listen address the remote announces in its HELLO. Iteration
    helpers therefore de-duplicate by identity. Teardown removes every key
    that refers to the closing connection, so no stale entry survives a
    disconnect and blocks a redial.
    """

    def __init__(self, node: Any, peer_list: PeerList, listen_port: int = 9420):
        self.node = node
        self.peers = peer_list
        self.listen_port = listen_port
        self._connections: dict[str, PeerConnection] = {}
        self._server: asyncio.Server | None = None
        self._running = False
        self._message_handlers: dict[str, Callable] = {}
        self._fail_count: dict[str, int] = {}  # address → consecutive failures
        self._retry_at: dict[str, float] = {}  # address → time.monotonic() of next allowed dial
        self._dialing: set[str] = set()  # addresses with a dial currently in flight
        # The event loop keeps only weak references to tasks, so every task we
        # spawn is held here until it finishes; stop() also cancels them.
        self._tasks: set[asyncio.Task] = set()
        self._cached_ip: str | None = None  # filled lazily by _resolve_public_ip

        # Register default handlers
        self._register_defaults()

    def _register_defaults(self):
        """Register built-in message handlers."""
        self._message_handlers[MsgType.HELLO] = self._handle_hello
        self._message_handlers[MsgType.PING] = self._handle_ping
        self._message_handlers[MsgType.PONG] = self._handle_pong
        self._message_handlers[MsgType.PEER_LIST] = self._handle_peer_list
        self._message_handlers[MsgType.VERSION_INFO] = self._handle_version_info

    def on_message(self, msg_type: str, handler: Callable):
        """Register a custom message handler (replaces any existing one for msg_type)."""
        self._message_handlers[msg_type] = handler

    # ─── Task bookkeeping ────────────────────────────────────────────

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule *coro* as a task and keep a strong reference until it finishes."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    def _unique_connections(self, alive_only: bool = False) -> list[PeerConnection]:
        """Distinct registered connections (by identity), optionally only live ones."""
        seen: set[int] = set()
        result = []
        for conn in self._connections.values():
            if id(conn) in seen:
                continue
            seen.add(id(conn))
            if alive_only and not conn.alive:
                continue
            result.append(conn)
        return result

    def _forget_connection(self, conn: PeerConnection) -> None:
        """Drop every registry key that points at *conn* (dialed and announced address)."""
        for key in [k for k, c in self._connections.items() if c is conn]:
            del self._connections[key]

    # ─── Server ──────────────────────────────────────────────────────

    async def start(self) -> None:
        """Start listening for connections and connect to known peers."""
        self._running = True

        # Start TCP server — dual-stack (IPv4 + IPv6)
        try:
            self._server = await asyncio.start_server(
                self._on_incoming, "::", self.listen_port
            )
            log.info("P2P listening on port %d (IPv4+IPv6)", self.listen_port)
        except OSError:
            # Fallback: IPv4 only
            self._server = await asyncio.start_server(
                self._on_incoming, "0.0.0.0", self.listen_port
            )
            log.info("P2P listening on port %d (IPv4 only)", self.listen_port)

        # Resolve the public IP off the event loop: _resolve_public_ip does
        # blocking HTTP requests (up to 3 × 5 s) and _our_info() calls it.
        self._spawn(self._warm_public_ip())

        # Start background tasks
        self._spawn(self._ping_loop())
        self._spawn(self._sync_loop())
        self._spawn(self._cleanup_loop())

        # Connect to known peers
        self._spawn(self._connect_to_known_peers())

    async def stop(self) -> None:
        """Shutdown all connections and background tasks."""
        self._running = False
        # Stop accepting new connections before tearing down existing ones.
        if self._server:
            self._server.close()
        # Background loops may be asleep for up to CLEANUP_INTERVAL; cancel them,
        # along with connection loops and in-flight dials, instead of waiting.
        current = asyncio.current_task()
        pending = [task for task in self._tasks if task is not current]
        for task in pending:
            task.cancel()
        for conn in self._unique_connections():
            await conn.close()
        self._connections.clear()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
        if self._server:
            await self._server.wait_closed()
        log.info("P2P network stopped")

    # ─── Incoming connections ────────────────────────────────────────

    async def _on_incoming(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Handle a new incoming connection: send our HELLO, then start reading.

        The connection is registered in ``_connections`` only once the peer's
        HELLO arrives (see ``_handle_hello``).
        """
        addr = writer.get_extra_info("peername")
        address = PeerInfo.format_address(addr[0], addr[1]) if addr else "unknown"
        log.info("Incoming connection from %s", address)

        conn = PeerConnection(reader, writer, address, is_outbound=False)

        # Send our hello
        await conn.send(Message(
            type=MsgType.HELLO,
            sender_id=self.node.id,
            payload=self._our_info(),
        ))

        # Handle messages
        self._spawn(self._connection_loop(conn))

    # ─── Outgoing connections ────────────────────────────────────────

    @staticmethod
    def _backoff_seconds(fails: int, cycle: float) -> float:
        """Delay before the next dial after *fails* consecutive failures.

        Doubles per failure (2, 4, 8, 16 cycles) and is capped at 30 cycles,
        i.e. 30 minutes when *cycle* is the 60-second ping interval.
        """
        if fails <= 0:
            return 0
        return min(2 ** min(fails, 5), 30) * cycle

    def _record_dial_failure(self, address: str, reason: object) -> None:
        """Count a failed dial, schedule the next allowed attempt, and mark the peer offline."""
        fails = self._fail_count.get(address, 0) + 1
        self._fail_count[address] = fails
        delay = self._backoff_seconds(fails, PING_INTERVAL)
        self._retry_at[address] = time.monotonic() + delay
        log.debug("Failed to connect to %s: %s (retry in ~%d min)", address, reason, delay / 60)
        self.peers.mark_offline(address)

    async def connect_to(self, host: str, port: int) -> bool:
        """Connect to a specific peer. Uses exponential backoff on repeated failures.

        Returns True if a live connection to the address exists or was just
        established. Returns False in three cases: the dial failed, the
        address is still backing off after earlier failures, or another task
        is already dialing it.
        """
        address = PeerInfo.format_address(host, port)
        existing = self._connections.get(address)
        if existing is not None and existing.alive:
            return True  # Already connected
        if address in self._dialing:
            return False  # Another task is mid-dial; avoid a duplicate connection
        retry_at = self._retry_at.get(address)
        if retry_at is not None and time.monotonic() < retry_at:
            return False  # Still backing off

        self._dialing.add(address)
        try:
            try:
                reader, writer = await asyncio.wait_for(
                    asyncio.open_connection(host, port), timeout=10
                )
            except (ConnectionError, OSError, asyncio.TimeoutError) as e:
                self._record_dial_failure(address, e)
                return False

            conn = PeerConnection(reader, writer, address, is_outbound=True)

            # Send hello; if the socket already died, don't report the peer online.
            sent = await conn.send(Message(
                type=MsgType.HELLO,
                sender_id=self.node.id,
                payload=self._our_info(),
            ))
            if not sent:
                await conn.close()
                self._record_dial_failure(address, "HELLO could not be sent")
                return False

            self._connections[address] = conn
            self.peers.mark_online(address)
            self._fail_count.pop(address, None)  # Reset on success
            self._retry_at.pop(address, None)
            log.info("Connected to %s", address)

            # Handle messages
            self._spawn(self._connection_loop(conn))
            return True
        finally:
            self._dialing.discard(address)

    async def _connect_to_known_peers(self):
        """Try to connect to all known peers (iterates a snapshot of the peer list)."""
        for peer in list(self.peers.all_peers):
            if not self._running:
                break
            await self.connect_to(peer.host, peer.port)
            await asyncio.sleep(0.5)  # Don't spam

    # ─── Connection loop ─────────────────────────────────────────────

    async def _connection_loop(self, conn: PeerConnection):
        """Read messages from a connection until it dies, then clean up its registry entries."""
        try:
            while self._running and conn.alive:
                msg = await conn.recv()
                if msg is None:
                    if not conn.alive:
                        break
                    continue  # read timeout: connection is idle, not dead
                await self._dispatch(conn, msg)
        except Exception as e:
            log.error("Connection loop error for %s: %s", conn.address, e)
        finally:
            await conn.close()
            # Remove *all* keys for this connection (dialed and announced address),
            # otherwise a dead entry lingers and the ping loop never redials it.
            self._forget_connection(conn)
            if conn.peer_info:
                replacement = self._connections.get(conn.peer_info.address)
                # Another live connection to the same peer may already exist.
                if replacement is None or not replacement.alive:
                    self.peers.mark_offline(conn.peer_info.address)
            log.info("Disconnected: %s", conn.address)

    async def _dispatch(self, conn: PeerConnection, msg: Message):
        """Route a message to the appropriate handler; handler errors are logged, not raised."""
        handler = self._message_handlers.get(msg.type)
        if handler:
            try:
                await handler(conn, msg)
            except Exception as e:
                log.error("Handler error for %s: %s", msg.type, e)
        else:
            log.debug("Unknown message type: %s", msg.type)

    # ─── Message handlers ────────────────────────────────────────────

    async def _handle_hello(self, conn: PeerConnection, msg: Message):
        """Handle hello/handshake from a peer.

        Records the peer and registers *conn* under the peer's announced
        address, which may differ from ``conn.address``. Then replies with
        WELCOME, our peer list, and VERSION_INFO.
        """
        payload = msg.payload
        # Use host from payload if it's a real IP, otherwise use TCP peer address
        announced_host = payload.get("host", "")
        if not announced_host or announced_host in ("0.0.0.0", "::"):
            announced_host = PeerInfo.parse_address(conn.address)[0] if conn.address != "unknown" else ""
        peer = PeerInfo(
            node_id=msg.sender_id,
            host=announced_host,
            port=payload.get("port", self.listen_port),
            name=payload.get("name", ""),
            version=payload.get("version", ""),
            last_seen=time.time(),
            online=True,
            capabilities=payload.get("capabilities", []),
        )
        conn.peer_info = peer
        actual_address = peer.address
        self._connections[actual_address] = conn
        self.peers.add(peer)
        self.peers.mark_online(actual_address)

        # Send welcome + our peer list
        await conn.send(Message(
            type=MsgType.WELCOME,
            sender_id=self.node.id,
            payload=self._our_info(),
        ))
        await conn.send(Message(
            type=MsgType.PEER_LIST,
            sender_id=self.node.id,
            payload={"peers": self.peers.to_list()},
        ))
        # Send version info
        await conn.send(Message(
            type=MsgType.VERSION_INFO,
            sender_id=self.node.id,
            payload={"version": self.node.version},
        ))

        log.info("Handshake complete: %s (%s)", peer.name, actual_address)

    async def _handle_ping(self, conn: PeerConnection, msg: Message):
        """Answer a PING with a PONG and refresh the peer's online status."""
        await conn.send(Message(
            type=MsgType.PONG,
            sender_id=self.node.id,
        ))
        if conn.peer_info:
            self.peers.mark_online(conn.peer_info.address)

    async def _handle_pong(self, conn: PeerConnection, msg: Message):
        """Refresh the peer's online status on PONG."""
        if conn.peer_info:
            self.peers.mark_online(conn.peer_info.address)

    async def _handle_peer_list(self, conn: PeerConnection, msg: Message):
        """Merge received peer list with ours and dial any newly learned peers."""
        remote_peers = msg.payload.get("peers", [])
        new_count = self.peers.merge(remote_peers)
        if new_count > 0:
            log.info("Merged %d new peers from %s", new_count, conn.address)
            # Try connecting to new peers
            self._spawn(self._connect_to_known_peers())

    async def _handle_version_info(self, conn: PeerConnection, msg: Message):
        """Forward version info to the "_version_check" handler, if one is registered."""
        handler = self._message_handlers.get("_version_check")
        if handler:
            await handler(conn, msg)

    # ─── Background tasks ────────────────────────────────────────────

    async def _ping_loop(self):
        """Periodically ping connected peers and redial the rest."""
        while self._running:
            await asyncio.sleep(PING_INTERVAL)
            try:
                await self._ping_once()
            except Exception:
                # Keep the loop alive; an escaped exception would end it silently.
                log.exception("Ping loop iteration failed")

    async def _ping_once(self):
        """One ping pass over a snapshot of all known peers."""
        for peer in list(self.peers.all_peers):
            if not self._running:
                break
            conn = self._connections.get(peer.address)
            if conn is not None and conn.alive:
                await conn.send(Message(
                    type=MsgType.PING,
                    sender_id=self.node.id,
                ))
                continue
            if conn is not None:
                # Registered but dead: mark offline, then try to reconnect below.
                self.peers.mark_offline(peer.address)
            await self.connect_to(peer.host, peer.port)

    async def _sync_loop(self):
        """Periodically share peer lists with connected peers."""
        while self._running:
            await asyncio.sleep(SYNC_INTERVAL)
            try:
                peer_list_msg = Message(
                    type=MsgType.PEER_LIST,
                    sender_id=self.node.id,
                    payload={"peers": self.peers.to_list()},
                )
                await self.broadcast(peer_list_msg)
            except Exception:
                log.exception("Peer-list sync failed")

    async def _cleanup_loop(self):
        """Periodically remove dead peers."""
        while self._running:
            await asyncio.sleep(CLEANUP_INTERVAL)
            try:
                removed = self.peers.cleanup_dead()
            except Exception:
                log.exception("Dead-peer cleanup failed")
                continue
            if removed:
                log.info("Cleaned up %d dead peers", len(removed))

    # ─── Helpers ─────────────────────────────────────────────────────

    async def _warm_public_ip(self) -> None:
        """Populate the public-IP cache on a worker thread so the event loop never blocks on it."""
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._resolve_public_ip)

    def _resolve_public_ip(self) -> str:
        """Try to determine our public IP. Prefers IPv6 (DS-Lite compatible).

        Blocking: performs up to three HTTP requests with a 5 s timeout each.
        The first result (or "0.0.0.0" if every lookup fails) is cached for
        the lifetime of this object.
        """
        if self._cached_ip is not None:
            return self._cached_ip
        import urllib.request
        # Try IPv6 first (works through DS-Lite/CGN), then IPv4
        for url in ("https://api6.ipify.org", "https://api4.ipify.org", "https://api.ipify.org"):
            try:
                req = urllib.request.Request(url, headers={"User-Agent": "HiveMind"})
                with urllib.request.urlopen(req, timeout=5) as resp:
                    ip = resp.read().decode().strip()
                    if ip:
                        self._cached_ip = ip
                        return ip
            except Exception:
                continue  # best effort: try the next endpoint
        self._cached_ip = "0.0.0.0"
        return "0.0.0.0"

    def _our_info(self) -> dict:
        """Our node info for handshake."""
        info = {
            "node_id": self.node.id,
            "name": self.node.name,
            "host": self._resolve_public_ip(),
            "port": self.listen_port,
            "version": self.node.version,
            "capabilities": self.node.plugins.loaded if hasattr(self.node, 'plugins') else [],
        }
        if hasattr(self.node, 'config'):
            info["specialization"] = self.node.config.node.specialization
            info["expertise_tags"] = self.node.config.node.expertise_tags
            if self.node.config.model.path:
                info["model_name"] = Path(self.node.config.model.path).stem
        return info

    async def broadcast(self, msg: Message) -> int:
        """Send a message once to every distinct live connection. Returns success count."""
        sent = 0
        for conn in self._unique_connections(alive_only=True):
            if await conn.send(msg):
                sent += 1
        return sent

    @property
    def connected_count(self) -> int:
        """Number of distinct live connections (a connection under two keys counts once)."""
        return len(self._unique_connections(alive_only=True))

    @property
    def status(self) -> dict:
        """Snapshot of listener, connection, and peer-list statistics."""
        return {
            "listening": self.listen_port,
            "connected": self.connected_count,
            "known_peers": len(self.peers),
            "online_peers": len(self.peers.online_peers),
            "peers": self.peers.status_summary(),
        }
