"""Network protocol — message format for P2P communication."""
from __future__ import annotations

import json
import time
from dataclasses import dataclass, field, asdict
from enum import Enum
from typing import Any


class MsgType(str, Enum):
    """Message types in the HiveMind protocol.

    Mixing in ``str`` makes every member a real string, so a member compares
    equal to its raw value (e.g. ``MsgType.PING == "ping"``) and can be
    serialized by ``json`` without conversion.

    NOTE(review): the string values appear to be the on-the-wire identifiers
    (presumably carried in ``Message.type``); renaming a value would likely
    break interoperability with peers running older code — confirm before
    changing any of them.
    """
    # Handshake
    HELLO = "hello"              # Initial greeting with node info
    WELCOME = "welcome"          # Response to hello

    # Peer discovery
    PEER_LIST = "peer_list"      # Share/request peer list
    PING = "ping"                # Keepalive ping
    PONG = "pong"                # Keepalive response

    # Chat / inference
    QUERY = "query"              # Ask network for a response
    RESPONSE = "response"        # Response to query

    # Updates (note: some wire values are abbreviated, e.g. "update_avail")
    VERSION_INFO = "version_info"    # Announce version info
    UPDATE_AVAILABLE = "update_avail"  # New update available
    UPDATE_REQUEST = "update_req"    # Request update data
    UPDATE_DATA = "update_data"      # Update payload (signed)

    # Generic
    ERROR = "error"              # Error report


@dataclass
class PeerInfo:
    """Information about a peer node.

    ``first_seen`` and ``last_seen`` are Unix timestamps as returned by
    ``time.time()``. Leaving either at ``0.0`` (the default) fills it with the
    construction time.
    """
    node_id: str
    host: str
    port: int
    name: str = ""
    version: str = "0.1.0"
    last_seen: float = 0.0
    first_seen: float = 0.0
    online: bool = False
    capabilities: list[str] = field(default_factory=list)
    specialization: str = ""
    expertise_tags: list[str] = field(default_factory=list)
    model_name: str = ""
    model_size: str = ""
    fitness_score: float = 0.5

    def __post_init__(self):
        # 0.0 is the "unset" sentinel. Read the clock once so a freshly
        # created peer gets identical first_seen and last_seen values.
        now = time.time()
        if self.first_seen == 0.0:
            self.first_seen = now
        if self.last_seen == 0.0:
            self.last_seen = now

    @staticmethod
    def parse_address(address: str) -> tuple[str, int]:
        """Parse ``'host:port'`` or ``'[ipv6]:port'`` into ``(host, port)``.

        Surrounding whitespace is ignored. For the unbracketed form the split
        is on the *last* colon.

        Raises:
            ValueError: if the address is malformed (missing ``]``, missing
                ``:`` separator, non-integer port) or the port is outside
                0-65535.
        """
        address = address.strip()
        if address.startswith("["):
            host, bracket, rest = address[1:].partition("]")
            # Require exactly "]:" after the host; previously the character
            # after "]" was skipped blindly, so "[::1]8080" parsed as port 80.
            if not bracket or not rest.startswith(":"):
                raise ValueError(
                    f"invalid address {address!r}: expected '[host]:port'"
                )
            port_text = rest[1:]
        else:
            host, colon, port_text = address.rpartition(":")
            if not colon:
                raise ValueError(
                    f"invalid address {address!r}: expected 'host:port'"
                )
        port = int(port_text)  # raises ValueError on empty/non-numeric text
        if not 0 <= port <= 65535:
            raise ValueError(f"port out of range in {address!r}: {port}")
        return host, port

    @staticmethod
    def format_address(host: str, port: int) -> str:
        """Format host+port, bracketing IPv6 (any host containing ':')."""
        if ":" in host:
            return f"[{host}]:{port}"
        return f"{host}:{port}"

    @property
    def address(self) -> str:
        """This peer's address in the form accepted by ``parse_address``."""
        return self.format_address(self.host, self.port)

    @property
    def hours_since_seen(self) -> float:
        """Hours elapsed since ``last_seen``."""
        return (time.time() - self.last_seen) / 3600

    @property
    def is_dead(self) -> bool:
        """Dead if not seen for 168 hours (7 days)."""
        return self.hours_since_seen > 168

    def to_dict(self) -> dict:
        """Return all fields as a (deep-copied) plain dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> PeerInfo:
        """Build a PeerInfo from a dict, ignoring keys that are not fields.

        Dropping unknown keys lets newer peers add fields without breaking
        older readers.
        """
        known = cls.__dataclass_fields__.keys()
        filtered = {k: v for k, v in data.items() if k in known}
        return cls(**filtered)


@dataclass
class Message:
    """A protocol message.

    On the wire a message is a single line of UTF-8 JSON terminated by
    ``"\\n"``, with the keys ``type``, ``payload``, ``sender_id`` and ``ts``.
    """
    type: str
    payload: dict = field(default_factory=dict)
    sender_id: str = ""
    timestamp: float = 0.0  # 0.0 means "unset"; filled in by __post_init__

    def __post_init__(self):
        # 0.0 is the "unset" sentinel so callers may omit the timestamp.
        if self.timestamp == 0.0:
            self.timestamp = time.time()

    def encode(self) -> bytes:
        """Serialize to newline-terminated UTF-8 JSON for network transmission.

        ``json.dumps`` escapes newlines inside string values, so the trailing
        ``"\\n"`` is the only raw newline and can serve as a frame delimiter.
        """
        data = {
            "type": self.type,
            "payload": self.payload,
            "sender_id": self.sender_id,
            "ts": self.timestamp,
        }
        return (json.dumps(data, ensure_ascii=False) + "\n").encode("utf-8")

    @classmethod
    def decode(cls, data: bytes) -> Message:
        """Deserialize one message produced by :meth:`encode`.

        The input comes from the network, so optional fields are normalized.
        A missing or ``null`` ``payload`` becomes ``{}`` and a missing or
        ``null`` ``sender_id`` becomes ``""``. A missing or non-numeric
        ``ts`` becomes the current time.

        Raises:
            ValueError: on invalid UTF-8 or invalid JSON
                (``UnicodeDecodeError`` / ``json.JSONDecodeError``).
            TypeError: if the JSON value is not an object, or ``payload`` is
                present but not an object.
            KeyError: if the required ``type`` key is missing.
        """
        obj = json.loads(data.decode("utf-8").strip())
        if not isinstance(obj, dict):
            raise TypeError(
                f"message must be a JSON object, got {obj.__class__.__name__}"
            )

        payload = obj.get("payload")
        if payload is None:
            payload = {}
        elif not isinstance(payload, dict):
            raise TypeError(
                f"message payload must be a JSON object, "
                f"got {payload.__class__.__name__}"
            )

        # JSON null would otherwise leak None into a str-typed field.
        sender_id = obj.get("sender_id") or ""

        # Accept only real numbers (bool is an int subclass, so exclude it);
        # anything else would break timestamp arithmetic downstream.
        ts = obj.get("ts")
        if isinstance(ts, (int, float)) and not isinstance(ts, bool):
            timestamp = float(ts)
        else:
            timestamp = time.time()

        return cls(
            type=obj["type"],
            payload=payload,
            sender_id=sender_id,
            timestamp=timestamp,
        )
