"""Tests for custom operator registration and usage."""

import pytest
import re

from nlql import NLQL
from nlql.adapters import MemoryAdapter
from nlql.registry.operators import register_operator, get_operator
from nlql.errors import NLQLRegistryError


@pytest.fixture
def adapter():
    """Build a MemoryAdapter pre-loaded with four sample texts.

    Each text is added together with a single ``category`` metadata entry.
    """
    # (text, category) pairs, added in this order.
    samples = (
        ("Email: alice@example.com for more information", "contact"),
        ("Visit our website at https://example.com", "web"),
        ("The product code is ABC-12345-XYZ", "product"),
        ("Temperature: 25.5°C, Humidity: 60%", "weather"),
    )

    store = MemoryAdapter()
    for text, category in samples:
        store.add_text(text, {"category": category})
    return store


def test_register_custom_operator_basic():
    """Register STARTS_WITH and call it through the registry lookup."""

    def starts_with(text: str, prefix: str) -> bool:
        """Return True when ``text`` begins with ``prefix``."""
        return text.startswith(prefix)

    # Explicit call form of the decorator.
    register_operator("STARTS_WITH")(starts_with)

    # The lookup must yield something callable with the operator's semantics.
    registered = get_operator("STARTS_WITH")
    assert registered is not None

    sample = "hello world"
    assert registered(sample, "hello") is True
    assert registered(sample, "world") is False


def test_custom_operator_in_where_clause(adapter):
    """Run a query whose WHERE clause calls the custom STARTS_WITH operator."""

    def starts_with(text: str, prefix: str) -> bool:
        """Return True when ``text`` begins with ``prefix``."""
        return text.startswith(prefix)

    register_operator("STARTS_WITH")(starts_with)

    statement = """
    SELECT CHUNK
    WHERE STARTS_WITH(content, "Email")
    """

    engine = NLQL(adapter=adapter)
    matches = engine.execute(statement)

    # Only the contact text in the fixture begins with "Email".
    assert len(matches) == 1
    first = matches[0]
    assert "alice@example.com" in first.content


def test_custom_operator_ends_with(adapter):
    """Run a query whose WHERE clause calls the custom ENDS_WITH operator."""

    def ends_with(text: str, suffix: str) -> bool:
        """Return True when ``text`` finishes with ``suffix``."""
        return text.endswith(suffix)

    register_operator("ENDS_WITH")(ends_with)

    statement = """
    SELECT CHUNK
    WHERE ENDS_WITH(content, "information")
    """

    engine = NLQL(adapter=adapter)
    matches = engine.execute(statement)

    # Only the contact text in the fixture ends with "information".
    assert len(matches) == 1
    first = matches[0]
    assert "alice@example.com" in first.content


def test_custom_operator_regex(adapter):
    """Run a query that filters with a custom REGEX operator."""

    def regex_match(text: str, pattern: str) -> bool:
        """Return True when ``pattern`` is found anywhere in ``text``."""
        found = re.search(pattern, text)
        return found is not None

    register_operator("REGEX")(regex_match)

    # The pattern targets email addresses.
    statement = """
    SELECT CHUNK
    WHERE REGEX(content, "[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}")
    """

    engine = NLQL(adapter=adapter)
    matches = engine.execute(statement)

    # Only the contact text in the fixture holds an email address.
    assert len(matches) == 1
    first = matches[0]
    assert "alice@example.com" in first.content


def test_custom_operator_has_number(adapter):
    """Test HAS_NUMBER custom operator.

    Registers a single-argument operator and checks that the query returns
    exactly the fixture texts that contain digits.
    """

    @register_operator("HAS_NUMBER")
    def has_number(text: str) -> bool:
        """Check if text contains any digit."""
        return any(char.isdigit() for char in text)

    nlql = NLQL(adapter=adapter)

    query = """
    SELECT CHUNK
    WHERE HAS_NUMBER(content)
    """

    results = nlql.execute(query)

    # Should return 2 results (product code and temperature have digits)
    assert len(results) == 2

    # Checking the count alone would accept any two documents, so also verify
    # which texts came back. Order is not assumed.
    contents = [result.content for result in results]
    assert any("ABC-12345-XYZ" in content for content in contents)
    assert any("Temperature" in content for content in contents)


def test_custom_operator_with_metadata(adapter):
    """Run a query that passes a META(...) value to a custom operator."""

    def category_is(metadata_value: str, expected: str) -> bool:
        """Return True when the metadata value equals ``expected``."""
        return metadata_value == expected

    register_operator("CATEGORY_IS")(category_is)

    statement = """
    SELECT CHUNK
    WHERE CATEGORY_IS(META("category"), "product")
    """

    engine = NLQL(adapter=adapter)
    matches = engine.execute(statement)

    # Only one fixture text is stored with category "product".
    assert len(matches) == 1
    first = matches[0]
    assert "ABC-12345-XYZ" in first.content


def test_custom_operator_returns_number(adapter):
    """Run a query that compares a custom operator's integer result."""

    def digit_count(text: str) -> int:
        """Return how many characters of ``text`` are digits."""
        digits = [char for char in text if char.isdigit()]
        return len(digits)

    register_operator("DIGIT_COUNT")(digit_count)

    statement = """
    SELECT CHUNK
    WHERE DIGIT_COUNT(content) >= 5 AND META("category") == "product"
    """

    engine = NLQL(adapter=adapter)
    matches = engine.execute(statement)

    # The product text has five digits and carries category "product".
    assert len(matches) == 1
    first = matches[0]
    assert "ABC-12345" in first.content


def test_register_operator_lowercase_name():
    """Registering a lowercase name must raise NLQLRegistryError."""
    from nlql.registry.operators import _global_operator_registry as registry

    expect_rejection = pytest.raises(NLQLRegistryError, match="must be uppercase")
    with expect_rejection:
        registry.register("lowercase", lambda x: x)


def test_custom_operator_combined_with_builtin(adapter):
    """Test combining custom operator with built-in operators.

    ORs the custom STARTS_WITH operator with CONTAINS and checks that the
    query returns exactly the two fixture texts that satisfy either side.
    """

    @register_operator("STARTS_WITH")
    def starts_with(text: str, prefix: str) -> bool:
        """Check if text starts with prefix."""
        return text.startswith(prefix)

    nlql = NLQL(adapter=adapter)

    query = """
    SELECT CHUNK
    WHERE STARTS_WITH(content, "Email") OR CONTAINS(content, "website")
    """

    results = nlql.execute(query)

    # Should return 2 results
    assert len(results) == 2

    # Checking the count alone would accept any two documents, so also verify
    # that each branch of the OR contributed its text. Order is not assumed.
    contents = [result.content for result in results]
    assert any("alice@example.com" in content for content in contents)
    assert any("https://example.com" in content for content in contents)


def test_custom_operator_in_complex_expression(adapter):
    """Run a query mixing two custom operators with AND, OR and parentheses."""

    def has_number(text: str) -> bool:
        """Return True when ``text`` contains at least one digit."""
        for char in text:
            if char.isdigit():
                return True
        return False

    def starts_with(text: str, prefix: str) -> bool:
        """Return True when ``text`` begins with ``prefix``."""
        return text.startswith(prefix)

    # Register in the same order as the decorator form would.
    register_operator("HAS_NUMBER")(has_number)
    register_operator("STARTS_WITH")(starts_with)

    statement = """
    SELECT CHUNK
    WHERE 
        HAS_NUMBER(content)
        AND (STARTS_WITH(content, "The") OR STARTS_WITH(content, "Visit"))
    """

    engine = NLQL(adapter=adapter)
    matches = engine.execute(statement)

    # Only the product text both has digits and starts with "The";
    # the "Visit ..." text has no digits.
    assert len(matches) == 1
    first = matches[0]
    assert "product code" in first.content

