"""Tests for custom function registration and usage."""

import pytest
from datetime import datetime, timedelta

from nlql import NLQL
from nlql.adapters import MemoryAdapter
from nlql.registry.functions import register_function, get_function
from nlql.errors import NLQLRegistryError


@pytest.fixture
def adapter():
    """Build a MemoryAdapter preloaded with three authored, dated texts.

    Documents, in insertion order:
      1. the pangram by Alice (2024-01-15)            -- 9 words, 43 characters
      2. the Python sentence by Bob (2024-02-20)      -- 6 words, 41 characters
      3. the ML sentence by Charlie (2024-03-10)      -- 8 words, 53 characters
    """
    store = MemoryAdapter()

    corpus = (
        (
            "The quick brown fox jumps over the lazy dog",
            {"author": "Alice", "date": "2024-01-15"},
        ),
        (
            "Python is a powerful programming language",
            {"author": "Bob", "date": "2024-02-20"},
        ),
        (
            "Machine learning enables computers to learn from data",
            {"author": "Charlie", "date": "2024-03-10"},
        ),
    )
    for text, meta in corpus:
        store.add_text(text, meta)

    return store


def test_register_custom_function_basic():
    """Registering via the decorator makes the function retrievable by name."""

    @register_function("WORD_COUNT")
    def _count_words(text: str) -> int:
        """Return the number of whitespace-separated tokens in ``text``."""
        return len(text.split())

    registered = get_function("WORD_COUNT")

    # The lookup must succeed and yield a callable that behaves like the original.
    assert registered is not None
    assert registered("hello world") == 2


def test_custom_function_in_where_clause(adapter):
    """A registered function can be called inside a WHERE predicate."""

    @register_function("NUMWORDS")
    def _num_words(text: str) -> int:
        """Return how many whitespace-separated words ``text`` contains."""
        return len(text.split())

    engine = NLQL(adapter=adapter)

    # Fixture word counts are 9, 6 and 8, so "> 8" keeps only the pangram.
    statement = """
    SELECT CHUNK
    WHERE NUMWORDS(content) > 8
    """
    matches = engine.execute(statement)

    assert len(matches) == 1
    first = matches[0]
    assert "quick brown fox" in first.content


def test_custom_function_with_multiple_args(adapter):
    """A custom function may receive more than one argument from a query."""

    @register_function("OCCURRENCES")
    def _occurrences(text: str, substring: str) -> int:
        """Case-insensitively count non-overlapping hits of ``substring`` in ``text``."""
        haystack = text.lower()
        needle = substring.lower()
        return haystack.count(needle)

    engine = NLQL(adapter=adapter)

    # Only the pangram contains "the" (case-insensitively); the others do not.
    statement = """
    SELECT CHUNK
    WHERE OCCURRENCES(content, "the") > 0
    """
    matches = engine.execute(statement)

    assert len(matches) == 1
    first = matches[0]
    assert "quick brown fox" in first.content


def test_custom_function_returns_string(adapter):
    """A string-returning custom function can be compared against a literal."""

    @register_function("UPPERCASE")
    def _to_upper(text: str) -> str:
        """Return ``text`` upper-cased."""
        return text.upper()

    engine = NLQL(adapter=adapter)

    statement = """
    SELECT CHUNK
    WHERE UPPERCASE(META("author")) == "ALICE"
    """
    matches = engine.execute(statement)

    # Only Alice's document survives the upper-cased author filter.
    assert len(matches) == 1
    first = matches[0]
    assert first.metadata["author"] == "Alice"


def test_custom_function_returns_boolean(adapter):
    """A boolean custom function can serve as the entire WHERE predicate."""

    @register_function("TOOLONG")
    def _exceeds_fifty_chars(text: str) -> bool:
        """Report whether ``text`` has strictly more than 50 characters."""
        return 50 < len(text)

    engine = NLQL(adapter=adapter)

    # Fixture lengths are 43, 41 and 53 characters: only the ML sentence qualifies.
    statement = """
    SELECT CHUNK
    WHERE TOOLONG(content)
    """
    matches = engine.execute(statement)

    assert len(matches) == 1
    first = matches[0]
    assert "Machine learning" in first.content


def test_custom_function_with_date_operations(adapter):
    """Test a custom function that performs date arithmetic.

    The query language is not exercised here. The registered callable is
    invoked directly, and the test checks that it returns a datetime exactly
    ``days`` days before "now". Timestamps taken immediately before and after
    the call bracket the result, so the check stays exact while the clock moves.
    """

    @register_function("DAYS_AGO")
    def days_ago(days: int) -> datetime:
        """Get datetime N days ago."""
        return datetime.now() - timedelta(days=days)

    func = get_function("DAYS_AGO")
    assert func is not None

    offset = timedelta(days=7)
    before = datetime.now()
    result = func(7)
    after = datetime.now()

    assert isinstance(result, datetime)
    # A bare "result < now" would also accept e.g. datetime.min; pin the offset.
    assert before - offset <= result <= after - offset


def test_register_function_empty_name():
    """The registry rejects an empty function name with NLQLRegistryError."""
    from nlql.registry.functions import _global_function_registry as registry

    def _identity(value):
        # Never invoked: registration must fail on the name alone.
        return value

    with pytest.raises(NLQLRegistryError, match="cannot be empty"):
        registry.register("", _identity)


def test_custom_function_overrides_builtin():
    """Test that custom function can override built-in function.

    The function registry is global to the process. The previous ``LENGTH``
    implementation is therefore re-registered afterwards; otherwise every
    later test that uses ``LENGTH`` would silently see the doubled value,
    and results would depend on test ordering.
    """
    original = get_function("LENGTH")

    try:

        @register_function("LENGTH")
        def custom_length(text: str) -> int:
            """Custom length that returns double the actual length."""
            return len(text) * 2

        func = get_function("LENGTH")
        assert func("hello") == 10  # 5 * 2
    finally:
        # Put the original implementation back so the override does not leak.
        # NOTE(review): if no LENGTH existed beforehand there is nothing to
        # restore, and no unregister API is visible here.
        if original is not None:
            register_function("LENGTH")(original)


def test_multiple_custom_functions_in_query(adapter):
    """Test using multiple custom functions in one query."""

    @register_function("NUMWORDS")
    def word_count(text: str) -> int:
        """Count whitespace-separated words in text."""
        return len(text.split())

    @register_function("NUMCHARS")
    def char_count(text: str) -> int:
        """Count characters in text."""
        return len(text)

    nlql = NLQL(adapter=adapter)

    query = """
    SELECT CHUNK
    WHERE NUMWORDS(content) > 5 AND NUMCHARS(content) < 50
    """

    results = nlql.execute(query)

    # Fixture documents: first = 9 words / 43 chars, second = 6 words / 41 chars,
    # third = 8 words / 53 chars. Only the first two satisfy both conditions.
    contents = [result.content for result in results]
    assert len(contents) == 2
    # Check each expected document independently of result ordering.
    assert any("quick brown fox" in content for content in contents)
    assert any("Python" in content for content in contents)
    assert not any("Machine learning" in content for content in contents)

