"""Confidence scoring for local model responses.

Determines whether a response is good enough or should be supplemented
by the P2P network.

Thresholds:
  >= 0.8  → Use local response directly
  0.3-0.8 → Use local + ask network, pick best
  < 0.3   → Discard local, only use network
"""
from __future__ import annotations

import logging
import math
import re
from typing import Any

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Routing thresholds, compared with ``>=`` in ``should_ask_network``:
#   confidence >= CONFIDENCE_HIGH                  -> "local"
#   CONFIDENCE_LOW <= confidence < CONFIDENCE_HIGH -> "both"
#   confidence < CONFIDENCE_LOW                    -> "network"
CONFIDENCE_HIGH = 0.8
CONFIDENCE_LOW = 0.3


class ConfidenceScorer:
    """Score response confidence using multiple heuristics.

    The final score is a weighted average of four signals: response
    quality, domain match against the configured expertise, token
    log-probabilities (when supplied) and the absence of self-doubt
    phrases.

    Args:
        expertise_tags: Optional topic tags; stored lower-cased.
        specialization: Optional free-text specialization; stored
            lower-cased.
    """

    # Precompiled patterns shared by all instances.
    _WORD_RE = re.compile(r'\w+')
    # NOTE(review): this character class also matches any digit, period or
    # hyphen, so ordinary prose counts as "structured". Kept unchanged to
    # preserve existing scores relative to the routing thresholds.
    _STRUCTURE_RE = re.compile(r'[\n\-\*\d\.]')

    # Lower-case phrases (German and English) treated as signs of self-doubt.
    _DOUBT_PHRASES = (
        "ich bin mir nicht sicher",
        "i'm not sure",
        "i don't know",
        "ich weiß nicht",
        "könnte falsch sein",
        "might be wrong",
        "nicht mein fachgebiet",
        "not my area",
        "keine ahnung",
        "unsicher",
        "vermutlich",
        "possibly",
        "perhaps",
        "maybe",
        "vielleicht",
    )

    # math.exp overflows a float for arguments above ~709.78; clamping the
    # exponent keeps extreme logprobs from raising OverflowError.
    _MAX_EXPONENT = 700.0

    def __init__(self, expertise_tags: list[str] | None = None,
                 specialization: str = ""):
        self.expertise_tags = [t.lower() for t in (expertise_tags or [])]
        self.specialization = specialization.lower()

    def score(self, query: str, response: str,
              logprobs: list[float] | None = None) -> float:
        """Compute confidence score (0.0 - 1.0) for a response.

        Combines multiple signals as a weighted average:
        - Response quality heuristics (weight 0.3)
        - Domain match against expertise tags (weight 0.25)
        - Logprobs-based confidence (weight 0.35, only if ``logprobs``
          is non-empty; otherwise quality gets an extra 0.15)
        - Self-doubt detection (weight 0.1)

        Args:
            query: The user query the response answers.
            response: The local model's response text.
            logprobs: Optional per-token log-probabilities.

        Returns:
            The weighted average, clamped to ``[0.0, 1.0]``.
        """
        quality = self._quality_score(response)

        scores: list[tuple[float, float]] = [  # (score, weight)
            (quality, 0.3),
            (self._domain_score(query), 0.25),
        ]

        if logprobs:
            scores.append((self._logprobs_score(logprobs), 0.35))
        else:
            # Without logprobs, lean more heavily on the quality signal.
            scores.append((quality, 0.15))

        scores.append((self._doubt_score(response), 0.1))

        # Weights are fixed positive constants, so the sum is never zero.
        total_weight = sum(w for _, w in scores)
        weighted = sum(s * w for s, w in scores) / total_weight

        return max(0.0, min(1.0, weighted))

    def _quality_score(self, response: str) -> float:
        """Heuristic quality based on response characteristics.

        Empty or whitespace-only responses score 0.0. Responses shorter
        than 20 / 50 / 100 characters (after stripping) score 0.2 / 0.4 /
        0.6. Longer responses start at 0.7 and gain 0.1 each for matching
        the structure pattern, for containing at least two periods or one
        exclamation mark, and for exceeding 500 characters (capped at 1.0).
        """
        if not response or not response.strip():
            return 0.0

        length = len(response.strip())

        # Very short responses are often low quality
        if length < 20:
            return 0.2
        if length < 50:
            return 0.4
        if length < 100:
            return 0.6

        has_structure = bool(self._STRUCTURE_RE.search(response))
        has_sentences = response.count('.') >= 2 or response.count('!') >= 1

        score = 0.7
        if has_structure:
            score += 0.1
        if has_sentences:
            score += 0.1
        if length > 500:
            score += 0.1

        return min(1.0, score)

    def _domain_score(self, query: str) -> float:
        """How well does the query match our expertise?

        Returns 0.5 (neutral) when neither tags nor a specialization are
        configured. Otherwise each tag that shares a word with the query
        (or occurs in it as a substring) counts 1, a specialization that
        shares a word with the query counts 2, and the matched fraction is
        scaled linearly from 0.3 (no match) to 1.0 (everything matches).
        """
        if not self.expertise_tags and not self.specialization:
            return 0.5  # Neutral — no specialization declared

        query_lower = query.lower()
        words = set(self._WORD_RE.findall(query_lower))

        matches = 0
        for tag in self.expertise_tags:
            tag_words = set(self._WORD_RE.findall(tag))
            if tag_words & words or tag in query_lower:
                matches += 1

        # The guard above guarantees at least one tag or a specialization,
        # so ``total`` is always positive here.
        total = len(self.expertise_tags)
        if self.specialization:
            total += 2
            spec_words = set(self._WORD_RE.findall(self.specialization))
            if spec_words & words:
                matches += 2

        # Scale: 0 matches → 0.3, all match → 1.0
        ratio = matches / total
        return 0.3 + 0.7 * ratio

    def _logprobs_score(self, logprobs: list[float]) -> float:
        """Convert logprobs to confidence score.

        Lower perplexity (higher logprobs) = more confident. The mean
        log-probability is passed through a logistic curve centred at
        -2.0 with slope 2.0.

        Returns 0.5 for an empty list or when the mean is NaN (the signal
        is unusable). Arbitrarily negative logprobs yield a score near
        0.0 instead of raising ``OverflowError``.
        """
        if not logprobs:
            return 0.5

        # Average log probability
        avg_lp = sum(logprobs) / len(logprobs)

        # NaN would slip through min()/max() clamping and come out as 1.0;
        # treat it as "no information" instead.
        if math.isnan(avg_lp):
            return 0.5

        # avg_lp near 0 = very confident, near -inf = very uncertain
        # Typical range: -0.5 (confident) to -5.0 (uncertain)
        # Clamp the exponent so math.exp cannot overflow for extreme inputs.
        exponent = min(-2.0 * (avg_lp + 2.0), self._MAX_EXPONENT)
        score = 1.0 / (1.0 + math.exp(exponent))

        return max(0.0, min(1.0, score))

    def _doubt_score(self, response: str) -> float:
        """Detect self-doubt phrases in the response.

        Counts how many distinct doubt phrases occur as case-insensitive
        substrings of ``response``.

        Returns:
            0.9 for no doubt phrase, 0.5 for exactly one, 0.2 for two or
            more.
        """
        response_lower = response.lower()
        doubt_count = sum(
            1 for phrase in self._DOUBT_PHRASES if phrase in response_lower
        )

        if doubt_count == 0:
            return 0.9
        if doubt_count == 1:
            return 0.5
        return 0.2


def should_ask_network(confidence: float) -> str:
    """Map a confidence score to a routing decision.

    The score is checked against the module thresholds from highest to
    lowest; the first threshold it meets or exceeds determines the route.

    Returns:
        'local'   — confidence >= CONFIDENCE_HIGH, use local only
        'both'    — CONFIDENCE_LOW <= confidence < CONFIDENCE_HIGH,
                    ask network too
        'network' — confidence below CONFIDENCE_LOW, prefer network
    """
    routes = (
        (CONFIDENCE_HIGH, "local"),
        (CONFIDENCE_LOW, "both"),
    )
    for floor, route in routes:
        if confidence >= floor:
            return route
    # Nothing met a threshold, so fall back to the network.
    return "network"
