"""Semantic similarity utilities for NLQL.

This module provides functions for computing semantic similarity between texts
using embedding models.
"""

import math
from typing import Any

from nlql.registry.embedding import _global_embedding_registry
from nlql.text.units import TextUnit


def cosine_similarity(vec1: list[float], vec2: list[float]) -> float:
    """Compute cosine similarity between two vectors.

    Args:
        vec1: First vector
        vec2: Second vector

    Returns:
        Cosine similarity score in the closed range [-1, 1]. If either
        vector has zero magnitude the angle is undefined and 0.0 is
        returned instead of raising.

    Raises:
        ValueError: If vectors have different lengths or are empty
    """
    if len(vec1) != len(vec2):
        raise ValueError(f"Vectors must have same length: {len(vec1)} != {len(vec2)}")

    if len(vec1) == 0:
        raise ValueError("Vectors cannot be empty")

    dot_product = sum(a * b for a, b in zip(vec1, vec2))
    magnitude1 = math.sqrt(sum(a * a for a in vec1))
    magnitude2 = math.sqrt(sum(b * b for b in vec2))

    # A zero vector has no direction; report "no similarity" rather than
    # dividing by zero.
    if magnitude1 == 0 or magnitude2 == 0:
        return 0.0

    similarity = dot_product / (magnitude1 * magnitude2)
    # Floating-point rounding can push the ratio marginally outside [-1, 1]
    # (e.g. identical vectors [1, 1, 1] yield 1.0000000000000002), so clamp
    # to keep the documented range and avoid out-of-range downstream scores.
    return max(-1.0, min(1.0, similarity))


def compute_similarity_scores(
    units: list[TextUnit],
    query_text: str,
    embedding_provider: Any | None = None
) -> list[TextUnit]:
    """Compute similarity scores for text units against a query.

    This function:
    1. Embeds the query text
    2. Embeds the content of every unit (always re-embedded; no previously
       computed embeddings are reused)
    3. Computes cosine similarity between query and each unit
    4. Stores the normalized score in unit.metadata["similarity"]

    Args:
        units: List of text units to score
        query_text: Query text to compare against
        embedding_provider: Optional custom embedding provider: a callable
            that takes a list of strings and returns one vector per string.
            If None, uses the global registry.

    Returns:
        The same list of units with similarity scores added to metadata

    Raises:
        ValueError: If the provider returns a different number of embeddings
            than there are units (checked before any unit is modified), or
            if cosine_similarity rejects a query/unit vector pair.

    Note:
        This function modifies the units in-place by adding "similarity" to metadata.
        The similarity score is a float between 0 and 1 (cosine similarity normalized).
    """
    if not units:
        return units

    # Get embedding provider
    if embedding_provider is None:
        embedding_provider = _global_embedding_registry.get()

    # Embed query
    query_embedding = embedding_provider([query_text])[0]

    # Embed all unit contents. Materialize as a list so the count can be
    # validated even if the provider returns an iterator.
    unit_embeddings = list(embedding_provider([unit.content for unit in units]))

    # zip() would silently drop trailing units, leaving them without a
    # "similarity" key; fail loudly before mutating anything instead.
    if len(unit_embeddings) != len(units):
        raise ValueError(
            f"Embedding provider returned {len(unit_embeddings)} embeddings "
            f"for {len(units)} units"
        )

    # Compute similarities and store in metadata
    for unit, unit_embedding in zip(units, unit_embeddings):
        similarity = cosine_similarity(query_embedding, unit_embedding)

        # Map cosine similarity from [-1, 1] to [0, 1]; clamp to guard against
        # float rounding producing values marginally outside the range.
        normalized_similarity = min(1.0, max(0.0, (similarity + 1) / 2))

        # Store in metadata (user-visible field for WHERE/ORDER BY)
        unit.metadata["similarity"] = normalized_similarity

    return units


def extract_query_from_similar_to(where_condition: Any) -> str | None:
    """Extract query text from SIMILAR_TO operator in WHERE clause.
    
    This function traverses the WHERE AST to find SIMILAR_TO("query text")
    and extracts the query string.
    
    Args:
        where_condition: WHERE clause AST node
        
    Returns:
        Query text if SIMILAR_TO is found, None otherwise
        
    Example:
        WHERE SIMILAR_TO("AI agents") > 0.8
        Returns: "AI agents"
    """
    from nlql.ast.nodes import ComparisonExpr, LogicalExpr, OperatorCall
    
    if where_condition is None:
        return None
    
    # Check if this node is a SIMILAR_TO operator call
    if isinstance(where_condition, OperatorCall):
        if where_condition.operator == "SIMILAR_TO" and where_condition.args:
            # Extract the query text (first argument)
            from nlql.ast.nodes import Literal
            if isinstance(where_condition.args[0], Literal):
                return where_condition.args[0].value
    
    # Check if this is a comparison with SIMILAR_TO on one side
    elif isinstance(where_condition, ComparisonExpr):
        # Check left side
        query = extract_query_from_similar_to(where_condition.left)
        if query:
            return query
        # Check right side
        return extract_query_from_similar_to(where_condition.right)
    
    # Check if this is a logical expression (AND/OR)
    elif isinstance(where_condition, LogicalExpr):
        # Check all operands
        for operand in where_condition.operands:
            query = extract_query_from_similar_to(operand)
            if query:
                return query  # Return first found query
    
    return None

