"""Property-based tests for the clustering framework.

Tests validate correctness properties from the design document using Hypothesis.
"""

import pytest
from hypothesis import given, strategies as st, assume

from statement_processor.analytics.clustering import TransactionCluster, ClusteringStrategy


# Strategies for generating test data
# Membership confidence scores are drawn from the closed interval [0.0, 1.0].
valid_membership_score = st.floats(allow_nan=False, min_value=0.0, max_value=1.0)
# Non-negative integers usable as DataFrame row indices.
valid_index = st.integers(max_value=10000, min_value=0)
# Labels must contain at least one non-whitespace character.
non_empty_label = st.text(max_size=50, min_size=1).filter(lambda text: bool(text.strip()))


@st.composite
def valid_memberships(draw):
    """Draw a non-empty ``{row_index: score}`` mapping.

    Between 1 and 20 distinct indices are drawn from ``valid_index``. Each one
    is paired with a score drawn from ``valid_membership_score``.
    """
    keys = draw(st.lists(valid_index, min_size=1, max_size=20, unique=True))
    key_count = len(keys)
    # Exactly one score per index so the zip below never truncates.
    values = draw(st.lists(valid_membership_score, min_size=key_count, max_size=key_count))
    return {key: value for key, value in zip(keys, values)}


@st.composite
def valid_cluster(draw):
    """Draw a ``TransactionCluster`` assembled from valid components.

    Memberships come from ``valid_memberships`` and the label from
    ``non_empty_label``. Metadata is a dict of up to 5 entries with short
    non-empty string keys and int / str / non-NaN float values.
    """
    metadata_values = st.one_of(st.integers(), st.text(max_size=50), st.floats(allow_nan=False))
    metadata_strategy = st.dictionaries(
        keys=st.text(min_size=1, max_size=20),
        values=metadata_values,
        max_size=5,
    )
    # Draw order: memberships, then label, then metadata.
    cluster_memberships = draw(valid_memberships())
    cluster_label = draw(non_empty_label)
    cluster_metadata = draw(metadata_strategy)
    return TransactionCluster(
        memberships=cluster_memberships,
        label=cluster_label,
        metadata=cluster_metadata,
    )


class TestClusterStructureValidity:
    """**Feature: clustering-framework, Property 2: Cluster Structure Validity**

    *For any* TransactionCluster returned by the framework:
    - memberships SHALL be a non-empty dict mapping valid DataFrame indices to confidence scores
    - all membership confidence scores SHALL be between 0.0 and 1.0 inclusive
    - label SHALL be a non-empty string
    - metadata SHALL be a dictionary (may be empty)

    **Validates: Requirements 1.4, 2.1, 3.1, 3.2**
    """

    # Scores strictly outside [0.0, 1.0], on either side of the valid range.
    _out_of_range_score = (
        st.floats(min_value=-100, max_value=-0.001, allow_nan=False)
        | st.floats(min_value=1.001, max_value=100, allow_nan=False)
    )

    @given(cluster=valid_cluster())
    def test_memberships_is_non_empty_dict(self, cluster: TransactionCluster):
        """Memberships SHALL be a non-empty dict."""
        memberships = cluster.memberships
        assert isinstance(memberships, dict)
        assert len(memberships) != 0

    @given(cluster=valid_cluster())
    def test_membership_scores_in_valid_range(self, cluster: TransactionCluster):
        """All membership confidence scores SHALL be between 0.0 and 1.0 inclusive."""
        for row, confidence in cluster.memberships.items():
            assert isinstance(row, int)
            assert 0.0 <= confidence <= 1.0

    @given(cluster=valid_cluster())
    def test_label_is_non_empty_string(self, cluster: TransactionCluster):
        """Label SHALL be a non-empty string."""
        label = cluster.label
        assert isinstance(label, str)
        assert len(label) >= 1

    @given(cluster=valid_cluster())
    def test_metadata_is_dict(self, cluster: TransactionCluster):
        """Metadata SHALL be a dictionary (may be empty)."""
        metadata = cluster.metadata
        assert isinstance(metadata, dict)

    @given(cluster=valid_cluster())
    def test_indices_property_returns_all_keys(self, cluster: TransactionCluster):
        """The indices property SHALL return all transaction indices."""
        member_keys = set(cluster.memberships)
        assert set(cluster.indices) == member_keys

    @given(idx=valid_index, score=_out_of_range_score)
    def test_invalid_membership_score_raises_error(self, idx: int, score: float):
        """Invalid membership scores SHALL raise ValueError."""
        bad_memberships = {idx: score}
        with pytest.raises(ValueError, match="membership score must be between"):
            TransactionCluster(memberships=bad_memberships, label="test")


class TestMembershipThresholdFiltering:
    """**Feature: clustering-framework, Property 4: Membership Threshold Filtering**

    *For any* cluster and membership threshold, `indices_above_threshold(threshold)`
    SHALL return only indices with membership >= threshold.

    **Validates: Requirements 2.3**
    """

    @given(
        cluster=valid_cluster(),
        threshold=st.floats(min_value=0.0, max_value=1.0, allow_nan=False),
    )
    def test_indices_above_threshold_returns_correct_indices(
        self, cluster: TransactionCluster, threshold: float
    ):
        """indices_above_threshold SHALL return only indices with membership >= threshold."""
        selected = cluster.indices_above_threshold(threshold)
        memberships = cluster.memberships

        # Soundness: nothing below the threshold may be returned.
        for row in selected:
            assert memberships[row] >= threshold

        # Completeness: everything at or above the threshold must be returned.
        qualifying = {row for row, confidence in memberships.items() if confidence >= threshold}
        assert set(selected) == qualifying

    @given(cluster=valid_cluster())
    def test_threshold_zero_returns_all_indices(self, cluster: TransactionCluster):
        """Threshold of 0.0 SHALL return all indices."""
        every_row = set(cluster.memberships.keys())
        assert set(cluster.indices_above_threshold(0.0)) == every_row

    @given(cluster=valid_cluster())
    def test_threshold_one_returns_only_perfect_matches(self, cluster: TransactionCluster):
        """Threshold of 1.0 SHALL return only indices with membership == 1.0."""
        perfect_rows = {
            row for row, confidence in cluster.memberships.items() if confidence >= 1.0
        }
        assert set(cluster.indices_above_threshold(1.0)) == perfect_rows


# ============================================================================
# ClusterRunner Property Tests
# ============================================================================

import pandas as pd
from typing import List
from statement_processor.analytics.cluster_runner import ClusterRunner


class MockStrategy(ClusteringStrategy):
    """Test double that records its invocations and returns a fixed cluster list.

    Attributes:
        call_count: number of times ``cluster`` has been invoked.
        last_transactions: DataFrame passed to the most recent ``cluster``
            call, or ``None`` if it has never been called.
    """

    def __init__(self, strategy_name: str, clusters_to_return: List[TransactionCluster]):
        # Call-tracking state starts empty.
        self.call_count = 0
        self.last_transactions = None
        # Fixed configuration.
        self._name = strategy_name
        self._clusters_to_return = clusters_to_return

    @property
    def name(self) -> str:
        """Identifier supplied at construction."""
        return self._name

    def cluster(self, transactions: pd.DataFrame) -> List[TransactionCluster]:
        """Record the call and return the configured list (the same object, not a copy)."""
        self.last_transactions = transactions
        self.call_count += 1
        return self._clusters_to_return


class FailingStrategy(ClusteringStrategy):
    """Strategy whose ``cluster`` method unconditionally raises ``RuntimeError``."""

    def __init__(self, strategy_name: str = "failing"):
        self._name = strategy_name

    @property
    def name(self) -> str:
        """Identifier supplied at construction (``"failing"`` by default)."""
        return self._name

    def cluster(self, transactions: pd.DataFrame) -> List[TransactionCluster]:
        """Always raise, simulating a strategy that crashes during a run."""
        failure = RuntimeError("Strategy failed intentionally")
        raise failure


# Strategies for generating test data
# Finite strategy weights: ``valid_weight`` may be zero, ``positive_weight`` may not.
valid_weight = st.floats(
    min_value=0.0,
    max_value=100.0,
    allow_nan=False,
    allow_infinity=False,
)
positive_weight = st.floats(
    min_value=0.001,
    max_value=100.0,
    allow_nan=False,
    allow_infinity=False,
)


@st.composite
def valid_transactions_df(draw):
    """Draw a transactions DataFrame with ``date``, ``description`` and ``amount`` columns.

    The row count is 1-20. Dates are consecutive ``2024-01-DD`` strings starting
    at day 01. Descriptions are non-blank text of up to 30 characters, and
    amounts are floats in [0.01, 10000.0].
    """
    row_count = draw(st.integers(min_value=1, max_value=20))

    def fixed_length(elements):
        # Every column must have exactly ``row_count`` entries.
        return st.lists(elements, min_size=row_count, max_size=row_count)

    description_text = st.text(min_size=1, max_size=30).filter(lambda s: s.strip())
    amount_value = st.floats(min_value=0.01, max_value=10000.0, allow_nan=False)

    # Dict literal values are evaluated left to right: descriptions are drawn before amounts.
    columns = {
        "date": [f"2024-01-{day + 1:02d}" for day in range(row_count)],
        "description": draw(fixed_length(description_text)),
        "amount": draw(fixed_length(amount_value)),
    }
    return pd.DataFrame(columns)


@st.composite
def cluster_for_df(draw, df_size: int, label: str = None):
    """Draw a cluster whose indices are valid row positions of a ``df_size``-row DataFrame.

    Args:
        df_size: number of rows in the target DataFrame.
        label: label to use. A falsy value (``None`` or ``""``) means a random
            label is drawn from ``non_empty_label``.

    Returns:
        A ``TransactionCluster`` with 1 to ``min(5, df_size)`` distinct indices,
        or ``None`` when ``df_size`` is 0.
    """
    if df_size == 0:
        return None

    row_strategy = st.integers(min_value=0, max_value=df_size - 1)
    rows = draw(st.lists(row_strategy, min_size=1, max_size=min(5, df_size), unique=True))
    confidences = draw(st.lists(valid_membership_score, min_size=len(rows), max_size=len(rows)))
    chosen_label = label if label else draw(non_empty_label)
    return TransactionCluster(memberships=dict(zip(rows, confidences)), label=chosen_label)


class TestAllStrategiesApplied:
    """**Feature: clustering-framework, Property 1: All Strategies Applied**

    *For any* set of registered strategies and input transactions, the ClusterRunner
    SHALL call the cluster method on every registered strategy.

    **Validates: Requirements 1.1**
    """

    @given(df=valid_transactions_df(), num_strategies=st.integers(min_value=1, max_value=5))
    def test_all_registered_strategies_are_called(self, df: pd.DataFrame, num_strategies: int):
        """All registered strategies SHALL be called when run() is invoked."""
        mocks = [MockStrategy(f"strategy_{n}", []) for n in range(num_strategies)]
        runner = ClusterRunner()
        for mock in mocks:
            runner.register_strategy(mock)

        runner.run(df)

        # Each strategy must have been invoked exactly once.
        for mock in mocks:
            assert mock.call_count == 1, f"Strategy {mock.name} was called {mock.call_count} times"

    @given(df=valid_transactions_df())
    def test_strategies_receive_transactions_dataframe(self, df: pd.DataFrame):
        """Each strategy SHALL receive the transactions DataFrame."""
        spy = MockStrategy("test", [])
        runner = ClusterRunner()
        runner.register_strategy(spy)

        runner.run(df)

        received = spy.last_transactions
        assert received is not None
        assert len(received) == len(df)


class TestWeightedMembershipCombination:
    """**Feature: clustering-framework, Property 3: Weighted Membership Combination**

    *For any* transaction that appears in overlapping clusters from different strategies,
    the combined membership score SHALL equal the weighted average:
    sum(membership_i * weight_i) / sum(weight_i).

    **Validates: Requirements 2.2**
    """

    @staticmethod
    def _merge_mode_runner(
        first: TransactionCluster,
        first_weight: float,
        second: TransactionCluster,
        second_weight: float,
    ) -> "ClusterRunner":
        """Build a non-cascading runner with two weighted single-cluster mock strategies.

        ``cascade=False`` enables merge mode, in which overlapping memberships
        are combined by weighted average.
        """
        runner = ClusterRunner(cascade=False)
        runner.register_strategy(MockStrategy("s1", [first]), first_weight)
        runner.register_strategy(MockStrategy("s2", [second]), second_weight)
        return runner

    @given(
        weight1=positive_weight,
        weight2=positive_weight,
        score1=valid_membership_score,
        score2=valid_membership_score,
    )
    def test_weighted_average_for_overlapping_memberships(
        self, weight1: float, weight2: float, score1: float, score2: float
    ):
        """Overlapping memberships SHALL be combined using weighted average."""
        # Both strategies claim row 0 under the same label, with different scores.
        runner = self._merge_mode_runner(
            TransactionCluster(memberships={0: score1}, label="vendor"), weight1,
            TransactionCluster(memberships={0: score2}, label="vendor"), weight2,
        )
        transactions = pd.DataFrame({
            "date": ["2024-01-01"],
            "description": ["Test"],
            "amount": [100.0],
        })

        merged = [c for c in runner.run(transactions) if c.label == "vendor"]
        assert len(merged) == 1

        expected = (score1 * weight1 + score2 * weight2) / (weight1 + weight2)
        actual = merged[0].memberships[0]
        assert abs(actual - expected) < 1e-9, f"Expected {expected}, got {actual}"

    @given(
        weight1=positive_weight,
        weight2=positive_weight,
        score1=valid_membership_score,
        score2=valid_membership_score,
    )
    def test_non_overlapping_memberships_preserved(
        self, weight1: float, weight2: float, score1: float, score2: float
    ):
        """Non-overlapping memberships SHALL be preserved with their original scores."""
        # Strategy s1 claims row 0 and s2 claims row 1, so nothing overlaps.
        runner = self._merge_mode_runner(
            TransactionCluster(memberships={0: score1}, label="vendor"), weight1,
            TransactionCluster(memberships={1: score2}, label="vendor"), weight2,
        )
        transactions = pd.DataFrame({
            "date": ["2024-01-01", "2024-01-02"],
            "description": ["Test1", "Test2"],
            "amount": [100.0, 200.0],
        })

        merged = [c for c in runner.run(transactions) if c.label == "vendor"]
        assert len(merged) == 1

        # Untouched rows keep their original scores (tolerance absorbs float rounding).
        combined = merged[0].memberships
        assert abs(combined[0] - score1) < 1e-9
        assert abs(combined[1] - score2) < 1e-9


class TestLabelFiltering:
    """**Feature: clustering-framework, Property 5: Label Filtering**

    *For any* set of clusters and label filter, the filtered result SHALL contain
    only clusters with matching label.

    **Validates: Requirements 3.3**
    """

    @given(
        df=valid_transactions_df(),
        target_label=non_empty_label
    )
    def test_run_by_label_returns_only_matching_labels(
        self, df: pd.DataFrame, target_label: str
    ):
        """run_by_label SHALL return only clusters with matching label."""
        assume(len(df) >= 2)
        # If Hypothesis happens to draw the decoy label itself, the filter has
        # nothing to exclude and the property is not exercised.
        assume(target_label != "other_label")

        # Create clusters with different labels
        cluster1 = TransactionCluster(memberships={0: 1.0}, label=target_label)
        cluster2 = TransactionCluster(memberships={1: 1.0}, label="other_label")

        strategy = MockStrategy("test", [cluster1, cluster2])
        runner = ClusterRunner()
        runner.register_strategy(strategy)

        result = runner.run_by_label(df, target_label)

        # An empty result satisfies the label check below vacuously, so also
        # require that the matching cluster actually came through.
        assert len(result) > 0
        # All returned clusters should have the target label
        for cluster in result:
            assert cluster.label == target_label

    @given(df=valid_transactions_df())
    def test_run_by_label_excludes_non_matching(self, df: pd.DataFrame):
        """run_by_label SHALL exclude clusters with non-matching labels."""
        assume(len(df) >= 2)

        cluster1 = TransactionCluster(memberships={0: 1.0}, label="vendor")
        cluster2 = TransactionCluster(memberships={1: 1.0}, label="recurring")

        strategy = MockStrategy("test", [cluster1, cluster2])
        runner = ClusterRunner()
        runner.register_strategy(strategy)

        vendor_result = runner.run_by_label(df, "vendor")
        recurring_result = runner.run_by_label(df, "recurring")

        # Each label has a matching cluster, so neither filtered result may be
        # empty; otherwise the all(...) checks below would pass trivially.
        assert len(vendor_result) > 0
        assert len(recurring_result) > 0
        assert all(c.label == "vendor" for c in vendor_result)
        assert all(c.label == "recurring" for c in recurring_result)


class TestEdgeCases:
    """Tests for edge case handling in ClusterRunner.

    **Validates: Requirements 2.4, 4.4, 5.3**
    """

    @staticmethod
    def _single_row_df() -> pd.DataFrame:
        """Return a fresh one-row transactions frame for the tests below."""
        return pd.DataFrame({
            "date": ["2024-01-01"],
            "description": ["Test"],
            "amount": [100.0],
        })

    def test_empty_transactions_returns_empty_list(self):
        """Empty transactions DataFrame SHALL return empty list."""
        spy = MockStrategy("test", [TransactionCluster(memberships={0: 1.0}, label="vendor")])
        runner = ClusterRunner()
        runner.register_strategy(spy)

        outcome = runner.run(pd.DataFrame({"date": [], "description": [], "amount": []}))

        assert outcome == []
        # The runner should short-circuit before invoking any strategy.
        assert spy.call_count == 0

    def test_no_strategies_returns_empty_list(self):
        """No registered strategies SHALL return empty list."""
        outcome = ClusterRunner().run(self._single_row_df())
        assert outcome == []

    def test_negative_weight_raises_value_error(self):
        """Negative weight SHALL raise ValueError."""
        runner = ClusterRunner()
        unused = MockStrategy("test", [])

        with pytest.raises(ValueError, match="Weight must be non-negative"):
            runner.register_strategy(unused, weight=-1.0)

    def test_strategy_exception_continues_with_others(self):
        """Strategy exception SHALL log error and continue with other strategies."""
        # A failing strategy is registered ahead of a working one.
        broken = FailingStrategy("failing")
        healthy = MockStrategy("working", [TransactionCluster(memberships={0: 1.0}, label="vendor")])

        runner = ClusterRunner()
        for strategy in (broken, healthy):
            runner.register_strategy(strategy)

        # Must not raise; only the healthy strategy's output should come back.
        outcome = runner.run(self._single_row_df())

        assert len(outcome) == 1
        assert outcome[0].label == "vendor"
        assert healthy.call_count == 1


class TestStrategyErrorIsolation:
    """**Feature: clustering-framework, Property 6: Strategy Error Isolation**

    *For any* strategy that throws an exception during clustering, the ClusterRunner
    SHALL catch the error, log it, and continue processing other strategies without crashing.

    **Validates: Requirements 4.4**
    """

    @given(
        df=valid_transactions_df(),
        num_working=st.integers(min_value=1, max_value=3),
        num_failing=st.integers(min_value=1, max_value=3),
    )
    def test_failing_strategies_do_not_crash_runner(
        self, df: pd.DataFrame, num_working: int, num_failing: int
    ):
        """Failing strategies SHALL NOT crash the ClusterRunner."""
        # Non-cascading, so every strategy sees the full set of transactions.
        runner = ClusterRunner(cascade=False)

        for n in range(num_failing):
            runner.register_strategy(FailingStrategy(f"failing_{n}"))

        # Each working strategy emits one cluster under its own unique label.
        healthy = [
            MockStrategy(
                f"working_{n}",
                [TransactionCluster(memberships={0: 1.0}, label=f"label_{n}")],
            )
            for n in range(num_working)
        ]
        for strategy in healthy:
            runner.register_strategy(strategy)

        outcome = runner.run(df)  # must not raise

        for strategy in healthy:
            assert strategy.call_count == 1
        assert len(outcome) == num_working

    @given(df=valid_transactions_df())
    def test_all_failing_strategies_returns_empty(self, df: pd.DataFrame):
        """When all strategies fail, SHALL return empty list without crashing."""
        runner = ClusterRunner()
        for strategy_name in ("fail1", "fail2"):
            runner.register_strategy(FailingStrategy(strategy_name))

        assert runner.run(df) == []

    @given(
        df=valid_transactions_df(),
        score=valid_membership_score,
    )
    def test_results_from_working_strategies_preserved(
        self, df: pd.DataFrame, score: float
    ):
        """Results from working strategies SHALL be preserved when others fail."""
        runner = ClusterRunner()
        # The failing strategy is registered first, ahead of the working one.
        runner.register_strategy(FailingStrategy("failing"))

        produced = TransactionCluster(
            memberships={0: score},
            label="vendor",
            metadata={"source": "working"},
        )
        runner.register_strategy(MockStrategy("working", [produced]))

        outcome = runner.run(df)

        assert len(outcome) == 1
        survivor = outcome[0]
        assert survivor.label == "vendor"
        assert survivor.memberships[0] == score


class TestDefaultWeight:
    """**Feature: clustering-framework, Property 7: Default Weight**

    *For any* strategy registered without an explicit weight, the ClusterRunner
    SHALL use a default weight of 1.0.

    **Validates: Requirements 5.1**
    """

    @given(score1=valid_membership_score, score2=valid_membership_score)
    def test_default_weight_is_one(self, score1: float, score2: float):
        """Strategies registered without weight SHALL use default weight of 1.0.

        One strategy is registered with an explicit weight of 1.0 and the other
        with no weight. With a default of 1.0 the merged score is the plain mean
        of the two scores. Any other default skews it toward one side whenever
        the scores differ.
        """
        # The two scores are drawn independently. With identical scores,
        # (s*w1 + s*w2) / (w1 + w2) == s for every choice of weights, so the
        # test could not detect a wrong default.
        cluster1 = TransactionCluster(memberships={0: score1}, label="vendor")
        cluster2 = TransactionCluster(memberships={0: score2}, label="vendor")

        strategy1 = MockStrategy("explicit", [cluster1])
        strategy2 = MockStrategy("default", [cluster2])

        # Merge mode (cascade=False) is what applies the weighted average.
        runner = ClusterRunner(cascade=False)
        runner.register_strategy(strategy1, weight=1.0)  # Explicit
        runner.register_strategy(strategy2)  # Default

        df = pd.DataFrame({
            "date": ["2024-01-01"],
            "description": ["Test"],
            "amount": [100.0]
        })

        result = runner.run(df)

        # Equal weights of 1.0 reduce the weighted average to the arithmetic mean.
        expected = (score1 + score2) / 2.0
        vendor_clusters = [c for c in result if c.label == "vendor"]
        assert len(vendor_clusters) == 1
        actual = vendor_clusters[0].memberships[0]
        assert abs(actual - expected) < 1e-9, f"Expected {expected}, got {actual}"

    @given(
        explicit_weight=positive_weight,
        score1=valid_membership_score,
        score2=valid_membership_score
    )
    def test_default_weight_in_combination(
        self, explicit_weight: float, score1: float, score2: float
    ):
        """Default weight of 1.0 SHALL be used in weighted average calculations."""
        cluster1 = TransactionCluster(memberships={0: score1}, label="vendor")
        cluster2 = TransactionCluster(memberships={0: score2}, label="vendor")

        strategy1 = MockStrategy("explicit", [cluster1])
        strategy2 = MockStrategy("default", [cluster2])

        # Use cascade=False to enable merge mode for weighted averaging
        runner = ClusterRunner(cascade=False)
        runner.register_strategy(strategy1, weight=explicit_weight)
        runner.register_strategy(strategy2)  # Default weight = 1.0

        df = pd.DataFrame({
            "date": ["2024-01-01"],
            "description": ["Test"],
            "amount": [100.0]
        })

        result = runner.run(df)

        # Expected: weighted average with explicit_weight and 1.0
        expected = (score1 * explicit_weight + score2 * 1.0) / (explicit_weight + 1.0)
        vendor_clusters = [c for c in result if c.label == "vendor"]
        assert len(vendor_clusters) == 1
        actual = vendor_clusters[0].memberships[0]
        assert abs(actual - expected) < 1e-9
