"""Text splitting utilities for different granularities."""

from collections.abc import Callable

from nlql.text.units import Sentence, TextUnit

# Type alias for splitter functions: a callable that takes the text as a
# single string and returns the pieces it was split into as a list of strings.
SplitterFunc = Callable[[str], list[str]]


def default_sentence_splitter(text: str) -> list[str]:
    """Break text into sentences with a minimal punctuation heuristic.

    A boundary is any run of whitespace that directly follows ".", "!" or
    "?". Nothing smarter is attempted (abbreviations such as "e.g." are
    split too), so callers with language- or domain-specific needs should
    supply their own splitter instead.

    Args:
        text: Input text to split

    Returns:
        Sentence strings in their original order, each stripped of
        surrounding whitespace; empty pieces are dropped
    """
    import re

    # The lookbehind keeps the terminator attached to the sentence it ends;
    # only the whitespace after it is consumed by the split.
    boundary = r"(?<=[.!?])\s+"
    pieces = re.split(boundary, text.strip())

    # Trim every piece, then discard the ones that end up empty
    # (e.g. when the input itself is empty or only whitespace).
    trimmed = map(str.strip, pieces)
    return list(filter(None, trimmed))


def split_into_sentences(text: str, splitter: SplitterFunc | None = None) -> list[Sentence]:
    """Turn raw text into an ordered list of Sentence units.

    Args:
        text: Text to break into sentences
        splitter: Callable mapping the text to a list of sentence strings;
            when None, default_sentence_splitter is used

    Returns:
        One Sentence per string produced by the splitter, in the same
        order, each carrying its zero-based position as sentence_index
    """
    split_fn = default_sentence_splitter if splitter is None else splitter
    pieces = split_fn(text)

    return [
        Sentence(content=piece, sentence_index=position)
        for position, piece in enumerate(pieces)
    ]


def create_span(
    units: list[TextUnit], target_index: int, window_size: int
) -> tuple[list[TextUnit], TextUnit, list[TextUnit]]:
    """Create a span with context window around a target unit.

    The window is truncated at the boundaries of ``units``, so a target near
    either end simply gets fewer context units on that side.

    Args:
        units: List of text units
        target_index: Index of the target unit. Negative values count from
            the end of ``units``, as with ordinary list indexing.
        window_size: Maximum number of units to include before and after the
            target. Values below zero are treated as zero.

    Returns:
        Tuple of (context_before, target, context_after)

    Raises:
        IndexError: If ``target_index`` does not refer to an element of
            ``units`` (including when ``units`` is empty).
    """
    unit_count = len(units)

    # Resolve a negative index to its absolute position first. Using the raw
    # negative value in the slice arithmetic below would silently select the
    # wrong neighbours (e.g. units[0:-1] as "context before" the last unit).
    resolved = target_index + unit_count if target_index < 0 else target_index
    if not 0 <= resolved < unit_count:
        raise IndexError(
            f"target_index {target_index} is out of range for {unit_count} units"
        )

    # A negative window could push the slice end below zero, which slicing
    # would interpret as an offset from the end of the list; clamp instead.
    window = max(0, window_size)

    start_idx = max(0, resolved - window)
    end_idx = min(unit_count, resolved + window + 1)

    context_before = units[start_idx:resolved]
    target = units[resolved]
    context_after = units[resolved + 1 : end_idx]

    return context_before, target, context_after

