"""Tests for custom embedding provider registration and usage."""

import pytest

from nlql import NLQL
from nlql.adapters import MemoryAdapter
from nlql.registry.embedding import register_embedding_provider, get_embedding_provider


@pytest.fixture
def adapter():
    """Build a MemoryAdapter seeded with three short, categorized texts."""
    seed_documents = (
        (
            "Machine learning is a subset of artificial intelligence",
            {"category": "AI"},
        ),
        (
            "Python is a popular programming language",
            {"category": "Programming"},
        ),
        (
            "Natural language processing enables computers to understand text",
            {"category": "NLP"},
        ),
    )

    memory_adapter = MemoryAdapter()
    # Insert in the listed order; each entry is (text, metadata).
    for text, metadata in seed_documents:
        memory_adapter.add_text(text, metadata)
    return memory_adapter


def test_register_custom_embedding_provider_basic():
    """Register a plain function as provider and get the same object back."""

    def mock_embedding_provider(texts: list[str]) -> list[list[float]]:
        """Return simple mock embeddings based on text length."""
        return [
            [size / 100.0, 1.0 - size / 100.0, 0.5]
            for size in map(len, texts)
        ]

    register_embedding_provider(mock_embedding_provider)

    # The registry must hand back the exact function object we registered.
    provider = get_embedding_provider()
    assert provider is mock_embedding_provider

    # Two inputs -> two vectors, each three-dimensional.
    vectors = provider(["hello", "world"])
    assert len(vectors) == 2
    assert len(vectors[0]) == 3


def test_custom_embedding_provider_with_query():
    """Test that a SIMILAR_TO query runs through the registered custom provider.

    The provider records every batch of texts it is asked to embed, so the
    test can verify that query execution actually invoked it. It also checks
    that the returned similarity scores respect ``ORDER BY SIMILARITY DESC``.
    """
    # Each entry is one batch of texts passed to the provider.
    calls: list[list[str]] = []

    def simple_embedding_provider(texts: list[str]) -> list[list[float]]:
        """Create embeddings based on word count and character count."""
        calls.append(list(texts))
        embeddings = []
        for text in texts:
            words = len(text.split())
            chars = len(text)
            # Scale counts down so the short test texts land roughly in [0, 1];
            # longer texts may exceed 1, which is fine for this test.
            embeddings.append([words / 20.0, chars / 100.0, 0.5])
        return embeddings

    register_embedding_provider(simple_embedding_provider)

    adapter = MemoryAdapter()
    adapter.add_text("Short text", {"id": 1})
    adapter.add_text("This is a much longer text with many more words", {"id": 2})

    nlql = NLQL(adapter=adapter)

    # Query with SIMILAR_TO should use custom provider
    query = """
    SELECT CHUNK
    WHERE SIMILAR_TO("medium length query text") > 0.5
    ORDER BY SIMILARITY DESC
    """

    results = nlql.execute(query)

    # The SIMILAR_TO target must have been embedded by our provider;
    # the original `len(results) >= 0` check was always true.
    assert calls, "custom embedding provider was never called"

    # All results should carry a float similarity score.
    similarities = []
    for result in results:
        assert "similarity" in result.metadata
        assert isinstance(result.metadata["similarity"], float)
        similarities.append(result.metadata["similarity"])

    # ORDER BY SIMILARITY DESC: scores must be non-increasing.
    assert similarities == sorted(similarities, reverse=True)


def test_custom_embedding_provider_dimension_consistency():
    """Test that a registered custom provider returns consistent dimensions.

    The provider is fetched through ``get_embedding_provider()`` so the test
    exercises the registry rather than only the local function.
    """
    import hashlib

    def fixed_dim_embedding_provider(texts: list[str]) -> list[list[float]]:
        """Return fixed 5-dimensional embeddings.

        The first three components come from a SHA-256 digest of the text.
        Unlike the built-in ``hash()``, which is salted per process for
        ``str`` (PYTHONHASHSEED), the digest is stable across runs, so the
        embeddings really are deterministic. The fourth component depends on
        text length, so texts of different lengths (up to 100 characters)
        never map to the same vector.
        """
        embeddings = []
        for text in texts:
            digest = hashlib.sha256(text.encode("utf-8")).digest()
            embeddings.append([
                digest[0] / 255.0,
                digest[1] / 255.0,
                digest[2] / 255.0,
                min(len(text), 100) / 100.0,
                0.5,
            ])
        return embeddings

    register_embedding_provider(fixed_dim_embedding_provider)
    provider = get_embedding_provider()
    assert provider is fixed_dim_embedding_provider

    # Test multiple texts
    texts = ["short", "medium length text", "a very long text with many words and characters"]
    embeddings = provider(texts)

    # One vector per input, all with the same dimension
    assert len(embeddings) == len(texts)
    assert all(len(emb) == 5 for emb in embeddings)

    # Embeddings should be different for different texts
    assert embeddings[0] != embeddings[1]
    assert embeddings[1] != embeddings[2]

    # Deterministic: embedding the same texts again gives identical vectors
    assert provider(texts) == embeddings


def test_custom_embedding_provider_with_character_features():
    """Test a registered provider built from per-letter frequency features."""

    def char_based_embedding_provider(texts: list[str]) -> list[list[float]]:
        """Create 10-dim embeddings from relative frequencies of letters a-j.

        Frequencies are normalized over all ASCII letters a-z in the text,
        so the truncated vector sums to at most 1.0. A text with no ASCII
        letters produces an all-zero vector.
        """
        embeddings = []
        for text in texts:
            # Count ASCII letter frequencies (a-z); other characters are ignored.
            char_counts = [0.0] * 26
            for char in text.lower():
                if 'a' <= char <= 'z':
                    char_counts[ord(char) - ord('a')] += 1

            # Normalize (skipped when there are no letters, avoiding 0/0)
            total = sum(char_counts)
            if total > 0:
                char_counts = [count / total for count in char_counts]

            # Keep only the first 10 letters (a-j) as embedding dimensions
            embeddings.append(char_counts[:10])
        return embeddings

    register_embedding_provider(char_based_embedding_provider)
    provider = get_embedding_provider()
    assert provider is char_based_embedding_provider

    embeddings = provider(["hello world", "123 !?"])

    assert len(embeddings) == 2
    assert all(len(emb) == 10 for emb in embeddings)

    # "hello world" has 10 letters; among a-j only d, e and h occur, once each.
    expected = [0.0, 0.0, 0.0, 0.1, 0.1, 0.0, 0.0, 0.1, 0.0, 0.0]
    assert embeddings[0] == pytest.approx(expected)
    assert 0 <= sum(embeddings[0]) <= 1.0  # Normalized

    # No letters at all -> zero vector, no division by zero
    assert embeddings[1] == [0.0] * 10


def test_register_embedding_provider_as_decorator():
    """Check that bare-decorator registration stores the decorated function."""

    # Applied without parentheses: the decorator receives the function itself.
    @register_embedding_provider
    def decorator_embedding_provider(texts: list[str]) -> list[list[float]]:
        """Embedding provider registered using decorator syntax."""
        vectors = []
        for _text in texts:
            vectors.append([1.0, 2.0, 3.0])
        return vectors

    # The registry must expose whatever name the decorator bound.
    provider = get_embedding_provider()
    assert provider is decorator_embedding_provider

    # One input -> one constant vector.
    output = provider(["test"])
    assert len(output) == 1
    assert output[0] == [1.0, 2.0, 3.0]

