"""ClusterRunner orchestrates clustering strategies and combines results.

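This module provides the main entry point for running multiple clustering
strategies on transaction data and combining their results. Strategies run
either in cascade (transactions claimed by an earlier strategy are hidden from
later ones) or in merge mode (every strategy sees every transaction, and
clusters sharing a label are merged by weighted averaging).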
"""

import logging
from collections import defaultdict
from typing import Dict, List, Tuple

import pandas as pd

from statement_processor.analytics.clustering import ClusteringStrategy, TransactionCluster

# Module-level logger; ClusterRunner reports strategies that raise through it
# and then carries on with the remaining strategies.
logger = logging.getLogger(__name__)


class ClusterRunner:
    """Orchestrates clustering strategies and combines results.

    The ClusterRunner manages a collection of clustering strategies, each with
    an associated weight. When run() is called, it executes all strategies and
    either cascades (claimed transactions excluded from later strategies) or
    runs all strategies on all transactions and merges results.

    A strategy that raises is logged (with traceback) and skipped. Its output
    is committed only after it finishes successfully, so a strategy failing
    part-way never leaves half of its clusters or claims behind.

    Attributes:
        _strategies: List of (strategy, weight) tuples, in registration order.
        _cascade: If True, claimed transactions are excluded from later strategies.
    """

    def __init__(self, cascade: bool = True) -> None:
        """Initialize the cluster runner.

        Args:
            cascade: If True (default), transactions claimed by earlier strategies
                are excluded from later strategies. If False, all strategies see
                all transactions and results are merged by label.
        """
        self._strategies: List[Tuple[ClusteringStrategy, float]] = []
        self._cascade = cascade

    def register_strategy(
        self, strategy: ClusteringStrategy, weight: float = 1.0
    ) -> None:
        """Register a clustering strategy with optional weight.

        Strategies run in registration order, which matters in cascade mode:
        earlier strategies get first claim on transactions.

        Args:
            strategy: Strategy instance conforming to ClusteringStrategy interface.
            weight: Weight for this strategy when combining overlapping clusters
                in merge mode (cascade=False); cascade mode ignores it. A weight
                of 0 gives the strategy no influence when a positively weighted
                strategy scored the same transaction. Default is 1.0.

        Raises:
            ValueError: If weight is negative.
        """
        if weight < 0:
            raise ValueError(f"Weight must be non-negative, got {weight}")
        self._strategies.append((strategy, weight))

    def run(
        self, transactions: pd.DataFrame, min_confidence: float = 0.0
    ) -> List[TransactionCluster]:
        """Apply all strategies and return clusters.

        Behavior depends on cascade setting:
        - cascade=True: Strategies run sequentially; claimed transactions are
          excluded from later strategies (prevents double-matching).
        - cascade=False: All strategies see all transactions; results are
          merged by label using weighted averaging.

        Args:
            transactions: DataFrame of transactions with columns: date, description, amount.
            min_confidence: Minimum confidence threshold for cluster membership.

        Returns:
            List of TransactionCluster objects, filtered by min_confidence.
            Empty if there are no transactions or no registered strategies.
        """
        if transactions.empty or not self._strategies:
            return []

        if self._cascade:
            return self._run_cascade(transactions, min_confidence)
        return self._run_merge(transactions, min_confidence)

    def _run_cascade(
        self, transactions: pd.DataFrame, min_confidence: float
    ) -> List[TransactionCluster]:
        """Run strategies sequentially, excluding claimed transactions.

        A transaction becomes claimed when it appears in one of a strategy's
        clusters with a score >= min_confidence. Later strategies see only the
        unclaimed rows. Strategy weights are not used in this mode.
        """
        claimed_indices: set = set()
        all_clusters: List[TransactionCluster] = []

        for strategy, _weight in self._strategies:
            unclaimed_df = transactions[~transactions.index.isin(claimed_indices)]
            if unclaimed_df.empty:
                break

            # Collect this strategy's output locally and commit it only once the
            # strategy has finished without error. If strategy.cluster yields
            # lazily and raises mid-iteration, none of its clusters are kept and
            # none of its transactions end up claimed.
            strategy_clusters: List[TransactionCluster] = []
            strategy_claims: set = set()
            try:
                for cluster in strategy.cluster(unclaimed_df):
                    kept = self._filter_memberships(
                        cluster.memberships, min_confidence
                    )
                    if not kept:
                        continue
                    strategy_claims.update(kept)
                    strategy_clusters.append(
                        TransactionCluster(
                            memberships=kept,
                            label=cluster.label,
                            metadata=cluster.metadata,
                        )
                    )
            except Exception as error:
                self._log_strategy_failure(strategy, error)
                continue

            all_clusters.extend(strategy_clusters)
            claimed_indices |= strategy_claims

        return all_clusters

    def _run_merge(
        self, transactions: pd.DataFrame, min_confidence: float
    ) -> List[TransactionCluster]:
        """Run all strategies on all transactions and merge by label.

        NOTE: membership filtering is applied only when min_confidence > 0.
        With min_confidence <= 0 every merged score is kept as-is.
        """
        scored_clusters: List[Tuple[TransactionCluster, float]] = []

        for strategy, weight in self._strategies:
            try:
                # list() forces lazily-yielding strategies to finish inside the
                # try, so a mid-iteration failure discards all of that
                # strategy's clusters instead of keeping a prefix of them.
                clusters = list(strategy.cluster(transactions))
            except Exception as error:
                self._log_strategy_failure(strategy, error)
                continue
            scored_clusters.extend((cluster, weight) for cluster in clusters)

        merged = self._merge_clusters_by_label(scored_clusters)

        if min_confidence <= 0.0:
            return merged

        filtered: List[TransactionCluster] = []
        for cluster in merged:
            kept = self._filter_memberships(cluster.memberships, min_confidence)
            if kept:
                filtered.append(
                    TransactionCluster(
                        memberships=kept,
                        label=cluster.label,
                        metadata=cluster.metadata,
                    )
                )
        return filtered

    def run_by_label(
        self, transactions: pd.DataFrame, label: str, min_confidence: float = 0.0
    ) -> List[TransactionCluster]:
        """Return only clusters matching the specified label.

        Args:
            transactions: DataFrame of transactions.
            label: Cluster label to filter by.
            min_confidence: Minimum confidence threshold for cluster membership,
                forwarded to run(). Defaults to 0.0 (the previous behavior).

        Returns:
            List of TransactionCluster objects with matching label.
        """
        return [
            cluster
            for cluster in self.run(transactions, min_confidence)
            if cluster.label == label
        ]

    def _merge_clusters_by_label(
        self, all_clusters: List[Tuple[TransactionCluster, float]]
    ) -> List[TransactionCluster]:
        """Merge clusters with same label using weighted averaging.

        For each label, every transaction index's merged score is the
        weight-averaged score across all clusters with that label that contain
        it. If every cluster contributing to an index has weight 0, the plain
        (unweighted) mean is used instead of dividing by zero. The metadata of
        the highest-weighted cluster is kept; ties go to the first such
        cluster in the input order.
        """
        if not all_clusters:
            return []

        clusters_by_label: Dict[str, List[Tuple[TransactionCluster, float]]] = (
            defaultdict(list)
        )
        for cluster, weight in all_clusters:
            clusters_by_label[cluster.label].append((cluster, weight))

        merged_clusters: List[TransactionCluster] = []

        for label, label_clusters in clusters_by_label.items():
            weighted_sums: Dict[int, float] = defaultdict(float)
            total_weights: Dict[int, float] = defaultdict(float)
            plain_sums: Dict[int, float] = defaultdict(float)
            counts: Dict[int, int] = defaultdict(int)
            best_metadata: Dict = {}
            best_weight = -1.0

            for cluster, weight in label_clusters:
                for idx, membership in cluster.memberships.items():
                    weighted_sums[idx] += membership * weight
                    total_weights[idx] += weight
                    plain_sums[idx] += membership
                    counts[idx] += 1

                # Strictly greater, so on equal weights the first cluster wins.
                if weight > best_weight:
                    best_weight = weight
                    best_metadata = cluster.metadata.copy()

            merged_memberships: Dict[int, float] = {}
            for idx, weighted_sum in weighted_sums.items():
                total = total_weights[idx]
                if total > 0:
                    merged_memberships[idx] = weighted_sum / total
                else:
                    # Only zero-weight strategies scored this index (weight 0 is
                    # permitted by register_strategy). The weighted mean would
                    # be 0/0, so fall back to the unweighted mean.
                    merged_memberships[idx] = plain_sums[idx] / counts[idx]

            if merged_memberships:
                merged_clusters.append(
                    TransactionCluster(
                        memberships=merged_memberships,
                        label=label,
                        metadata=best_metadata,
                    )
                )

        return merged_clusters

    @staticmethod
    def _filter_memberships(
        memberships: Dict[int, float], min_confidence: float
    ) -> Dict[int, float]:
        """Return the memberships whose score is >= min_confidence, in order."""
        return {
            idx: score for idx, score in memberships.items() if score >= min_confidence
        }

    @staticmethod
    def _log_strategy_failure(strategy: ClusteringStrategy, error: Exception) -> None:
        """Log a failed strategy, including traceback; call from an except block."""
        logger.error(
            "Strategy '%s' failed with error: %s. Continuing with other strategies.",
            # Fall back to the class name so a strategy lacking `name` cannot
            # turn error reporting into a second AttributeError.
            getattr(strategy, "name", type(strategy).__name__),
            error,
            exc_info=True,
        )
