"""Tests for instance-level registry support."""

import pytest

from nlql import NLQL
from nlql.adapters import MemoryAdapter


@pytest.fixture
def adapter():
    """Build a MemoryAdapter pre-loaded with three sample texts.

    Each text is added together with a metadata dict carrying an ``id``
    (1, 2 and 3, in insertion order).
    """
    samples = (
        ("The quick brown fox jumps over the lazy dog", {"id": 1}),
        ("Python is a great programming language", {"id": 2}),
        ("Machine learning is transforming AI", {"id": 3}),
    )
    store = MemoryAdapter()
    for text, metadata in samples:
        store.add_text(text, metadata)
    return store


def test_instance_level_function_registration(adapter):
    """Test that functions can be registered to specific NLQL instances.

    Two instances share one adapter, and each registers its own ``CUSTOM``
    function. Every query must be evaluated with the implementation that was
    registered on the instance executing it.
    """

    # Create two NLQL instances
    nlql1 = NLQL(adapter=adapter)
    nlql2 = NLQL(adapter=adapter)

    # Register different implementations of CUSTOM function to each instance
    @nlql1.register_function("CUSTOM")
    def custom_func1(text: str) -> int:
        """Count words."""
        return len(text.split())

    @nlql2.register_function("CUSTOM")
    def custom_func2(text: str) -> int:
        """Count characters."""
        return len(text)

    # nlql1 counts words: the fixture texts have 9, 6 and 5 words, so only
    # the fox sentence exceeds 8.
    results1 = nlql1.execute("SELECT CHUNK WHERE CUSTOM(content) > 8")
    assert len(results1) == 1
    assert "quick brown fox" in results1[0].content

    # nlql2 counts characters: the fixture texts are 43, 38 and 35 characters
    # long, so "> 35" keeps the first two and drops the 35-character one.
    results2 = nlql2.execute("SELECT CHUNK WHERE CUSTOM(content) > 35")
    assert len(results2) == 2
    # Check *which* chunks matched without relying on result ordering.
    assert any("quick brown fox" in result.content for result in results2)
    assert any("Python" in result.content for result in results2)


def test_instance_level_operator_registration(adapter):
    """Verify that each NLQL instance keeps its own ``CUSTOM_OP`` operator.

    One instance registers a prefix check and the other a suffix check under
    the same operator name; each query must use the operator of the instance
    that runs it.
    """
    prefix_engine = NLQL(adapter=adapter)
    suffix_engine = NLQL(adapter=adapter)

    @prefix_engine.register_operator("CUSTOM_OP")
    def starts_with(text: str, keyword: str) -> bool:
        """Return True when ``text`` begins with ``keyword``."""
        return text.startswith(keyword)

    @suffix_engine.register_operator("CUSTOM_OP")
    def ends_with(text: str, keyword: str) -> bool:
        """Return True when ``text`` finishes with ``keyword``."""
        return text.endswith(keyword)

    # Prefix semantics: exactly one chunk is expected for "Python".
    prefix_hits = prefix_engine.execute(
        'SELECT CHUNK WHERE CUSTOM_OP(content, "Python")'
    )
    assert len(prefix_hits) == 1
    first_prefix_hit = prefix_hits[0]
    assert "Python" in first_prefix_hit.content

    # Suffix semantics: exactly one chunk is expected for "dog".
    suffix_hits = suffix_engine.execute(
        'SELECT CHUNK WHERE CUSTOM_OP(content, "dog")'
    )
    assert len(suffix_hits) == 1
    first_suffix_hit = suffix_hits[0]
    assert "lazy dog" in first_suffix_hit.content


def test_instance_level_embedding_provider(adapter):
    """Check that each NLQL instance accepts its own embedding provider.

    Two instances register different providers and both run the same
    ``SIMILAR_TO`` query.
    """
    first_engine = NLQL(adapter=adapter)
    second_engine = NLQL(adapter=adapter)

    @first_engine.register_embedding_provider
    def embedding1(texts: list[str]) -> list[list[float]]:
        """Embed each text as ``[word_count / 10, 0.5, 0.5]``."""
        vectors = []
        for text in texts:
            vectors.append([len(text.split()) / 10.0, 0.5, 0.5])
        return vectors

    @second_engine.register_embedding_provider
    def embedding2(texts: list[str]) -> list[list[float]]:
        """Embed each text as ``[char_count / 50, 0.5, 0.5]``."""
        vectors = []
        for text in texts:
            vectors.append([len(text) / 50.0, 0.5, 0.5])
        return vectors

    # The same SIMILAR_TO query is run on both instances.
    query = 'SELECT CHUNK WHERE SIMILAR_TO("test") > 0.3'
    first_hits = first_engine.execute(query)
    second_hits = second_engine.execute(query)

    # NOTE(review): len() is never negative, so these assertions cannot fail.
    # As written, the test only verifies that execute() completes and returns
    # something with a length. Consider asserting concrete matches once the
    # similarity calculation's expected output is known.
    assert len(first_hits) >= 0
    assert len(second_hits) >= 0


def test_instance_registry_does_not_affect_global(adapter):
    """Test that instance-level registrations don't affect global registry.

    A function registered on one NLQL instance must be usable in that
    instance's queries while remaining absent from the global registry.
    """

    from nlql.registry.functions import get_function

    # Create NLQL instance and register a function
    nlql = NLQL(adapter=adapter)

    @nlql.register_function("INSTANCE_ONLY")
    def instance_func(text: str) -> int:
        """Count characters."""
        return len(text)

    # Function should work in the instance. The fixture texts are 43, 38 and
    # 35 characters long, so only the fox sentence exceeds 40. Asserting the
    # exact match proves the function really ran (a ">= 0" length check could
    # never fail).
    results = nlql.execute("SELECT CHUNK WHERE INSTANCE_ONLY(content) > 40")
    assert len(results) == 1
    assert "quick brown fox" in results[0].content

    # But should NOT be in global registry
    assert get_function("INSTANCE_ONLY") is None


def test_instance_registry_precedence_over_global(adapter):
    """Verify that an instance-level function shadows a global one.

    The same name is registered globally (constant result) and on an NLQL
    instance (word count); the query outcome reveals which one was used.
    """
    from nlql.registry.functions import register_function

    # NOTE(review): this global registration is not removed by the test, so
    # "PRECEDENCE_TEST" presumably remains in the global registry for any
    # tests that run afterwards - confirm whether a cleanup hook is available.
    @register_function("PRECEDENCE_TEST")
    def constant_hundred(text: str) -> int:
        """Ignore the text and always answer 100."""
        return 100

    engine = NLQL(adapter=adapter)

    @engine.register_function("PRECEDENCE_TEST")
    def word_count(text: str) -> int:
        """Answer with the number of whitespace-separated words."""
        words = text.split()
        return len(words)

    # With the word-count version only one chunk is expected to exceed 8;
    # with the constant-100 version every chunk would pass the filter.
    matches = engine.execute("SELECT CHUNK WHERE PRECEDENCE_TEST(content) > 8")
    assert len(matches) == 1
    sole_match = matches[0]
    assert "quick brown fox" in sole_match.content


def test_multiple_instances_independent(adapter):
    """Verify that three instances registering ``MULTI`` stay independent.

    Instance *k* registers a function that always returns *k*; querying
    instance *k* for ``MULTI(content) == k`` must therefore match every chunk.
    """

    def constant(value: int):
        """Build a one-argument function that ignores its text and returns ``value``."""

        def func(text: str) -> int:
            return value

        return func

    engines = [NLQL(adapter=adapter) for _ in range(3)]
    queries = (
        "SELECT CHUNK WHERE MULTI(content) == 1",
        "SELECT CHUNK WHERE MULTI(content) == 2",
        "SELECT CHUNK WHERE MULTI(content) == 3",
    )

    # Register the constants 1, 2, 3 on the first, second and third instance.
    for value, engine in enumerate(engines, start=1):
        engine.register_function("MULTI")(constant(value))

    # Run every query on its own instance before checking any outcome.
    outcomes = [engine.execute(query) for engine, query in zip(engines, queries)]

    # Each function always returns the value its query compares against, so
    # all three chunks are expected from every instance.
    for rows in outcomes:
        assert len(rows) == 3

