"""Statement validation for verifying parsed transactions against statement totals.

This module validates that parsed transactions match the summary data from the statement,
catching potential parsing errors or missing transactions.

Uses Decimal for precise currency comparisons to avoid floating point errors.
"""

import logging
import math
from dataclasses import dataclass, field
from decimal import ROUND_HALF_UP, Decimal
from typing import List, Optional

from statement_processor.models import Statement, StatementMetadata

logger = logging.getLogger(__name__)


class ValidationError(Exception):
    """Signals that a statement failed validation while the validator is strict.

    The exception message is a header naming the file, followed by one
    indented bullet line per error.

    Attributes:
        source_file: The statement file whose validation failed
        errors: The validation error messages, in the order they were recorded
    """

    def __init__(self, source_file: str, errors: List[str]):
        self.source_file = source_file
        self.errors = errors
        header = f"Validation failed for {source_file}:\n"
        bullets = [f"  - {problem}" for problem in errors]
        super().__init__(header + "\n".join(bullets))


def _to_decimal(value: float) -> Decimal:
    """Convert float to Decimal with 2 decimal places for currency precision."""
    return Decimal(str(value)).quantize(Decimal("0.01"))


@dataclass()
class ValidationResult:
    """Outcome of running the statement validation checks.

    Attributes:
        is_valid: Whether every validation check passed
        errors: Messages describing failed checks (e.g. a transaction sum mismatch)
        warnings: Informational, non-fatal messages
        transaction_sum: Sum of the parsed transactions that were checked
        expected_total: Total the statement metadata says to expect
    """

    # Optimistic default: a fresh result is valid until a check fails.
    is_valid: bool = field(default=True)
    # default_factory gives every instance its own list (no shared mutable default).
    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)
    transaction_sum: float = field(default=0.0)
    expected_total: float = field(default=0.0)


class StatementValidator:
    """Validates parsed statement data against summary information.

    Performs sanity checks to ensure parsed transactions match the
    statement's summary data, helping catch parsing errors.
    """

    # Keys in ``metadata.extra`` searched, in priority order, for the expected
    # total of charges. The first key holding a usable value wins.
    _EXPECTED_TOTAL_FIELDS = (
        "purchases",
        "new_charges",
        "total_charges",
        "total_new_charges",
    )

    def __init__(self, tolerance: float = 0.0, strict: bool = True):
        """Initialize validator with tolerance for comparisons.

        Args:
            tolerance: Absolute tolerance for amount comparisons (default $0.00).
                It is rounded to cents before comparison.
            strict: If True, raise ValidationError on validation failures
        """
        self._tolerance = tolerance
        self._strict = strict

    def validate(self, statement: Statement) -> ValidationResult:
        """Run validation checks on a parsed statement.

        Sums the statement's positive transaction amounts (charges) and
        compares the sum, rounded to cents, against the expected total found
        in the statement metadata. If the metadata has no usable expected
        total, the check is skipped and a valid result is returned.

        Args:
            statement: Parsed Statement object to validate

        Returns:
            ValidationResult with validation status and any errors/warnings

        Raises:
            ValidationError: If strict mode is enabled and validation fails
        """
        result = ValidationResult()
        metadata = statement.metadata
        transactions = statement.transactions

        # Without an expected total there is nothing to compare against.
        expected_total = self._get_expected_total(metadata)
        if expected_total is None:
            logger.debug(
                "Skipping validation for %s: no expected total in metadata",
                statement.source_file,
            )
            return result

        # Only positive amounts (charges) count toward the expected total.
        # Starting from Decimal("0.00") keeps an empty sum a two-place Decimal.
        # A bare int 0 would otherwise render as "$0" in the error message.
        transaction_sum = sum(
            (_to_decimal(tx.amount) for tx in transactions if tx.amount > 0),
            Decimal("0.00"),
        )

        result.transaction_sum = float(transaction_sum)
        result.expected_total = expected_total

        # Compare in Decimal to avoid float representation error.
        expected_decimal = _to_decimal(expected_total)
        difference = abs(transaction_sum - expected_decimal)

        if difference > _to_decimal(self._tolerance):
            error = (
                f"Transaction sum mismatch for {statement.source_file}: "
                f"parsed ${transaction_sum}, expected ${expected_decimal} "
                f"(difference: ${difference})"
            )
            result.errors.append(error)
            result.is_valid = False
            logger.error(error)

            if self._strict:
                raise ValidationError(statement.source_file, result.errors)

        return result

    def _get_expected_total(self, metadata: StatementMetadata) -> Optional[float]:
        """Extract expected transaction total from metadata.

        Looks for common fields in ``metadata.extra`` that contain the
        expected total, in this order:
        - extra['purchases'] - total purchases amount
        - extra['new_charges'] - new charges for the period
        - extra['total_charges'] - total charges
        - extra['total_new_charges'] - total new charges

        Falsy values (missing, None, "", 0) are treated as "not provided".
        Values that cannot be converted to float are skipped, as are
        non-finite ones (NaN, infinity).

        Args:
            metadata: Statement metadata

        Returns:
            Expected total amount, or None if not available
        """
        extra = metadata.extra

        for field_name in self._EXPECTED_TOTAL_FIELDS:
            raw = extra.get(field_name)
            if not raw:
                continue
            try:
                value = float(raw)
            except (ValueError, TypeError):
                continue
            # NaN/inf cannot be quantized or compared as currency.
            # Without this check they would later raise decimal.InvalidOperation.
            if math.isfinite(value):
                return value

        return None
