"""Peer list management — persistent address book with auto-cleanup."""
from __future__ import annotations

import json
import logging
import time
from pathlib import Path

from hivemind.network.protocol import PeerInfo

log = logging.getLogger(name=__name__)

# Age threshold in hours: 7 days * 24 h = 168.
DEAD_HOURS = 7 * 24


class PeerList:
    """Manages the list of known peers with persistence and cleanup.

    - Peers whose ``is_dead`` flag is set are removed by :meth:`cleanup_dead`
      (intended threshold: not seen for 168h — presumably enforced inside
      ``PeerInfo.is_dead``; confirm)
    - List is synced with other peers on connect (:meth:`merge`, :meth:`to_list`)
    - Manual additions are supported (:meth:`add_manual`)
    - Persisted to disk as JSON; writes go through a temporary file that is
      renamed over the target, so a crash mid-write cannot truncate the list

    Entries are keyed by ``PeerInfo.address`` (``"host:port"``).
    """

    def __init__(self, path: Path, own_id: str = ""):
        self.path = path
        self.own_id = own_id
        self._peers: dict[str, PeerInfo] = {}  # key = "host:port"
        self._load()

    # ------------------------------------------------------------------
    # Mutation
    # ------------------------------------------------------------------

    def _is_self(self, peer: PeerInfo) -> bool:
        """Return True if *peer* is this node.

        An empty ``own_id`` never matches; otherwise every peer without a
        node_id (e.g. manual entries) would be mistaken for ourselves.
        """
        return bool(self.own_id) and peer.node_id == self.own_id

    def _upsert(self, peer: PeerInfo) -> tuple[bool, bool]:
        """Insert or update *peer* in memory only (no disk write).

        Returns:
            ``(is_new, changed)`` — whether the address was previously
            unknown, and whether the in-memory list was modified at all.
        """
        if self._is_self(peer):
            return False, False  # Don't add ourselves

        key = peer.address
        existing = self._peers.get(key)
        if existing is None:
            peer.first_seen = time.time()
            self._peers[key] = peer
            log.info("New peer added: %s (%s)", peer.name or peer.node_id, key)
            return True, True

        if peer.last_seen <= existing.last_seen:
            # Stale or identical information: keep what we already have.
            return False, False

        # Fresher information: keep first_seen, take the newer values but
        # never let an empty incoming value wipe out a known one.
        existing.last_seen = peer.last_seen
        existing.online = peer.online
        existing.name = peer.name or existing.name
        existing.version = peer.version or existing.version
        existing.node_id = peer.node_id or existing.node_id
        existing.capabilities = peer.capabilities or existing.capabilities
        # Routing-relevant fields — take the freshest non-empty value
        if peer.specialization:
            existing.specialization = peer.specialization
        if peer.expertise_tags:
            existing.expertise_tags = peer.expertise_tags
        if peer.model_name:
            existing.model_name = peer.model_name
        return False, True

    def add(self, peer: PeerInfo) -> bool:
        """Add or update a peer, persisting the list only if it changed.

        A peer whose ``node_id`` equals ``own_id`` (when ``own_id`` is set)
        is ignored. For an already-known address the stored entry is updated
        only when ``peer.last_seen`` is newer; its ``first_seen`` is kept.

        Returns:
            True if the address was not known before, False otherwise.
        """
        is_new, changed = self._upsert(peer)
        if changed:
            self._save()
        return is_new

    def add_manual(self, host: str, port: int) -> PeerInfo:
        """Manually add a peer by address and return the stored entry.

        If the address is already known, the existing (possibly updated)
        entry is returned rather than the temporary object built here.
        """
        peer = PeerInfo(
            node_id="",
            host=host,
            port=port,
            name=f"manual-{host}",
            last_seen=time.time(),
        )
        self.add(peer)
        return self._peers.get(peer.address, peer)

    def remove(self, address: str) -> bool:
        """Remove a peer by address. Returns True if it was present."""
        if address not in self._peers:
            return False
        del self._peers[address]
        self._save()
        return True

    def mark_online(self, address: str) -> None:
        """Mark a peer as online (just responded to ping) and persist."""
        peer = self._peers.get(address)
        if peer:
            peer.online = True
            peer.last_seen = time.time()
            self._save()

    def mark_offline(self, address: str) -> None:
        """Mark a peer as offline (failed to connect/ping).

        Persists on an online→offline transition so the on-disk snapshot
        agrees with memory, mirroring :meth:`mark_online`.
        """
        peer = self._peers.get(address)
        if peer and peer.online:
            peer.online = False
            # Don't update last_seen — we want to track when it was LAST online
            self._save()

    def cleanup_dead(self) -> list[str]:
        """Remove peers flagged ``is_dead``. Returns the removed addresses."""
        dead = [addr for addr, p in self._peers.items() if p.is_dead]
        for addr in dead:
            peer = self._peers.pop(addr)
            log.info("Removing dead peer: %s (%s) — not seen for %.0fh",
                     peer.name or peer.node_id, addr, peer.hours_since_seen)
        if dead:
            self._save()
        return dead

    def merge(self, remote_peers: list[dict]) -> int:
        """Merge a peer list from another node. Returns count of new peers.

        Best-effort: malformed entries are logged and skipped. All changes
        are written to disk once at the end instead of once per entry.
        """
        new_count = 0
        changed = False
        for pdata in remote_peers:
            try:
                peer = PeerInfo.from_dict(pdata)
                is_new, was_changed = self._upsert(peer)
            except Exception as e:
                log.warning("Failed to merge peer: %s", e)
                continue
            if is_new:
                new_count += 1
            changed = changed or was_changed
        if changed:
            try:
                self._save()
            except OSError as e:
                # Keep merge best-effort: the in-memory list is updated and
                # the next successful save will persist it.
                log.warning("Failed to save merged peer list: %s", e)
        return new_count

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def get(self, address: str) -> PeerInfo | None:
        """Return the peer stored under *address*, or None."""
        return self._peers.get(address)

    def to_list(self) -> list[dict]:
        """Export all peers as list of dicts (for network sync)."""
        return [p.to_dict() for p in self._peers.values()]

    @property
    def online_peers(self) -> list[PeerInfo]:
        """Peers currently flagged online."""
        return [p for p in self._peers.values() if p.online]

    @property
    def all_peers(self) -> list[PeerInfo]:
        """Snapshot of all known peers."""
        return list(self._peers.values())

    @property
    def addresses(self) -> list[str]:
        """Snapshot of all known ``host:port`` keys."""
        return list(self._peers.keys())

    def __len__(self) -> int:
        return len(self._peers)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _load(self) -> None:
        """Load peer list from disk; unreadable data is logged, not raised.

        Individual malformed entries are skipped so one bad record does not
        discard the rest of the list.
        """
        if not self.path.exists():
            return
        try:
            with open(self.path, encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError
            log.warning("Failed to load peer list: %s", e)
            return
        if not isinstance(data, list):
            log.warning("Failed to load peer list: expected a JSON list, got %s",
                        type(data).__name__)
            return
        for pdata in data:
            try:
                peer = PeerInfo.from_dict(pdata)
            except Exception as e:
                log.warning("Skipping invalid peer entry: %s", e)
                continue
            self._peers[peer.address] = peer
        log.info("Peer list loaded: %d peers", len(self._peers))

    def _save(self) -> None:
        """Persist peer list to disk atomically (temp file + rename)."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(self.path.name + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self.to_list(), f, ensure_ascii=False, indent=2)
            # Path.replace overwrites the target; the old file stays intact
            # until the new one is fully written.
            tmp_path.replace(self.path)
        except BaseException:
            try:
                tmp_path.unlink()
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised
            raise

    # ------------------------------------------------------------------
    # Routing
    # ------------------------------------------------------------------

    def find_by_topics(self, topics: list[str],
                       min_score: float = 0.15) -> list[tuple[PeerInfo, float]]:
        """Find online peers matching *topics*.

        Online peers that declare a specialization or expertise tags are
        scored with ``hivemind.topic.score_peer``; those scoring at least
        *min_score* are returned sorted by score, descending.

        Online peers with no profile at all form a fallback pool, each with a
        neutral score of 0.5. That pool is returned *instead of* the scored
        list only when no profiled peer matched — the two are never mixed.

        Returns:
            List of ``(PeerInfo, score)`` tuples.
        """
        from hivemind.topic import score_peer

        scored: list[tuple[PeerInfo, float]] = []
        neutral: list[tuple[PeerInfo, float]] = []

        for peer in self.online_peers:
            if not (peer.specialization or peer.expertise_tags):
                neutral.append((peer, 0.5))  # no profile → neutral fallback
                continue
            s = score_peer(topics, peer.specialization, peer.expertise_tags)
            if s >= min_score:
                scored.append((peer, s))

        scored.sort(key=lambda x: x[1], reverse=True)

        # If we have at least one specialized match, skip neutrals
        return scored if scored else neutral

    def status_summary(self) -> str:
        """Human-readable status, e.g. ``"2/5 peers online"``."""
        total = len(self._peers)
        online = len(self.online_peers)
        return f"{online}/{total} peers online"
