from typing import Annotated, Any, Literal, TypeAlias

from pydantic import (
    AliasChoices,
    AliasPath,
    BaseModel,
    ConfigDict,
    Field,
    TypeAdapter,
    computed_field,
)

from .helper import unix_ms


class _Base(BaseModel):
    """Common base for every model in this module.

    ``extra="allow"`` keeps payload keys that a model does not declare: they
    are stored on the instance instead of causing a validation error, so
    fields not modelled here are tolerated.
    """

    model_config = ConfigDict(extra="allow")


class TimeInfo(_Base):
    """Start and end timestamps bracketing an event."""

    start: int
    end: int

    @property
    def duration_ms(self) -> int:
        """Elapsed time in milliseconds, i.e. ``end - start``.

        The difference is returned as-is (no clamping), so clock skew between
        components can occasionally make it negative.
        """
        begin, finish = self.start, self.end
        return finish - begin


class CacheInfo(_Base):
    """Token cache usage information."""

    # Presumably the count of tokens read from the cache — confirm against
    # the CLI's output.
    read: int
    # Presumably the count of tokens written to the cache — confirm likewise.
    write: int


class TokenUsage(_Base):
    """Detailed token usage for a step.

    ``input`` and ``output`` are required; ``reasoning`` and ``cache`` fall back
    to defaults when the payload omits them.
    """

    input: int
    output: int
    # 0 when the payload carries no reasoning count.
    reasoning: int = 0
    # None when the payload carries no cache breakdown.
    cache: CacheInfo | None = None


class ToolStateInput(_Base):
    """Input parameters for a tool call. Fields are optional as they vary by tool.

    Only ``command`` and ``description`` are declared. Any other tool-specific
    keys are retained as extra attributes via the inherited ``extra="allow"``.
    """

    command: str | None = None
    description: str | None = None


class ToolState(_Base):
    """The complete state of a tool's execution."""

    # Free-form status string; the allowed values are not constrained here.
    status: str
    input: ToolStateInput
    # None when the payload carries no output.
    output: str | None = None
    title: str | None = None
    # default_factory gives each instance its own fresh empty dict.
    metadata: dict[str, Any] = Field(default_factory=dict)
    # None when the payload carries no timing information.
    time: TimeInfo | None = None


class BasePart(_Base):
    """Base model for all event parts, containing common identifiers.

    The identifiers are required on input but are hidden from ``repr`` and
    excluded from ``model_dump()`` / ``model_dump_json()`` output.
    """

    # camelCase names must match the payload keys exactly (no aliases are set).
    id: str = Field(repr=False, exclude=True)
    sessionID: str = Field(repr=False, exclude=True)
    messageID: str = Field(repr=False, exclude=True)


class StepStartPart(BasePart):
    """Part model for a 'step-start' event."""

    # The part tag is hyphenated; the enclosing event uses "step_start".
    type: Literal["step-start"]
    # Optional snapshot string (its meaning is not visible here); hidden from
    # repr and excluded from dumps.
    snapshot: str | None = Field(default=None, repr=False, exclude=True)


class TextPart(BasePart):
    """Part model for a 'text' event."""

    type: Literal["text"]
    text: str
    # Optional timing; hidden from repr and excluded from dumps.
    time: TimeInfo | None = Field(default=None, repr=False, exclude=True)


class ToolUsePart(BasePart):
    """Part model carried by a 'tool_use' event.

    Note that the part's own ``type`` tag is ``"tool"``, not ``"tool_use"``.
    """

    type: Literal["tool"]
    # Call identifier; hidden from repr and excluded from dumps.
    callID: str = Field(repr=False, exclude=True)
    # Name of the invoked tool.
    tool: str
    state: ToolState


class StepFinishPart(BasePart):
    """Part model for a 'step-finish' event."""

    # The part tag is hyphenated; the enclosing event uses "step_finish".
    type: Literal["step-finish"]
    # Optional snapshot string; hidden from repr and excluded from dumps.
    snapshot: str | None = Field(default=None, repr=False, exclude=True)
    cost: float
    tokens: TokenUsage


class BaseEvent(_Base):
    """Base structure shared by all events (excludes `type`).

    Both CLI events and the library-generated events (validation, retry,
    prompt, timeout) inherit from this. Each subclass declares its own
    ``type`` literal, which ``EventUnion`` uses as the discriminator for the
    CLI event types.
    """

    # Presumably an ordering sequence number; 0 when the payload omits it.
    seq: int = Field(default=0)
    # Filled by unix_ms() at construction time when absent from the payload
    # (presumably Unix epoch milliseconds, per the helper's name — confirm).
    timestamp: int = Field(default_factory=unix_ms)
    # Defaults to ""; hidden from repr and excluded from dumps.
    sessionID: str = Field(default="", repr=False, exclude=True)


class StepStartEvent(BaseEvent):
    """Event representing the beginning of a processing step.

    The ``part`` payload is available on the instance but is hidden from
    ``repr`` and excluded from serialization.
    """

    type: Literal["step_start"]
    part: StepStartPart = Field(repr=False, exclude=True)


class TextEvent(BaseEvent):
    """A chunk of text generated by InnerLoop.

    The raw ``part`` stays accessible on the instance but is hidden from
    ``repr`` and serialization; its text is surfaced as a computed field.
    """

    type: Literal["text"]
    part: TextPart = Field(repr=False, exclude=True)

    @computed_field
    def text(self) -> str:
        """The text carried by this chunk."""
        text_part = self.part
        return text_part.text


class ToolUseEvent(BaseEvent):
    """A tool invocation made by InnerLoop (for example, ``bash``).

    The raw ``part`` is hidden from ``repr`` and serialization. The tool's
    output, status and name are exposed as computed fields, in that order.
    """

    type: Literal["tool_use"]
    part: ToolUsePart = Field(repr=False, exclude=True)

    @computed_field
    def output(self) -> str:
        """Tool output, or an empty string when none was reported."""
        produced = self.part.state.output
        return produced if produced else ""

    @computed_field
    def status(self) -> str:
        """Status string taken from the tool state."""
        state = self.part.state
        return state.status

    @computed_field
    def tool(self) -> str:
        """Name of the invoked tool."""
        invocation = self.part
        return invocation.tool


class StepFinishEvent(BaseEvent):
    """Event representing the end of a processing step, with cost and token info.

    The raw ``part`` payload is hidden from ``repr`` and excluded from
    serialization. The step's ``cost`` and ``tokens`` are therefore exposed as
    computed fields, so they survive ``model_dump()`` / ``model_dump_json()``.
    This mirrors how ``TextEvent`` and ``ToolUseEvent`` surface their part
    data.
    """

    type: Literal["step_finish"]
    part: StepFinishPart = Field(repr=False, exclude=True)

    # `@property` is kept beneath `@computed_field` so static type checkers
    # still see `event.cost` / `event.tokens` as attributes, exactly as before.
    @computed_field  # type: ignore[prop-decorator]
    @property
    def cost(self) -> float:
        """Cost reported for this step (taken from ``part.cost``)."""
        return self.part.cost

    @computed_field  # type: ignore[prop-decorator]
    @property
    def tokens(self) -> TokenUsage:
        """Token usage reported for this step (taken from ``part.tokens``)."""
        return self.part.tokens


class ErrorEvent(BaseEvent):
    """Typed error event emitted by the CLI.

    Error events differ from other events by using a flat payload rather than a
    nested `part` model. Only common fields are represented here; providers may
    omit optional fields.

    ``message`` and ``code`` accept several payload shapes: either a top-level
    key, or a key nested under an ``error`` object. When none is present, the
    field default applies.
    """

    type: Literal["error"]
    # AliasChoices are tried in order and the first one present wins:
    # top-level "message", then error.message, then error.data.message.
    message: str = Field(
        default="Unknown error",
        validation_alias=AliasChoices(
            "message",
            AliasPath("error", "message"),
            AliasPath("error", "data", "message"),
        ),
    )
    # Top-level "code", else the nested error.name; None if neither is present.
    code: str | None = Field(
        default=None,
        validation_alias=AliasChoices("code", AliasPath("error", "name")),
    )
    # "error" when the payload has no severity.
    severity: str | None = Field(default="error")

    @property
    def error_message(self) -> str:
        """Alias for the top-level message field for consistency.

        Matches the ``error_message`` attribute name that
        ``ValidationErrorEvent`` declares as a field.
        """
        return self.message


class ValidationErrorEvent(BaseEvent):
    """Event emitted when structured output validation fails.

    This event is generated during streaming when the model's output
    cannot be parsed or validated against the specified response_format.
    """

    type: Literal["validation_error"]
    # Text describing why parsing/validation failed.
    error_message: str
    # Which attempt failed (the numbering base is not visible here).
    attempt: int
    # Producer-supplied flag: whether another attempt is planned.
    will_retry: bool


class ValidationSuccessEvent(BaseEvent):
    """Event emitted when structured output validation succeeds.

    This event contains the final parsed and validated structured output.
    """

    type: Literal["validation_success"]
    # Typed as Any because its shape depends on the requested response_format.
    output: Any
    # Which attempt succeeded.
    attempt: int


class RetryEvent(BaseEvent):
    """Event emitted when starting a new retry attempt after validation failure."""

    type: Literal["retry"]
    # The attempt that is about to start.
    attempt: int
    # Error text from the failed attempt that triggered this retry.
    previous_error: str


class PromptRenderedEvent(BaseEvent):
    """Event emitted when the prompt is constructed before sending to the LLM.

    This event contains the fully rendered prompt including:
    - User's original prompt
    - SlimSchema YAML (if structured output is requested)
    - Validation errors (if this is a retry attempt)
    - Permission instructions

    Useful for debugging, logging, and understanding exactly what prompt
    is sent to the LLM for each attempt.
    """

    type: Literal["prompt_rendered"]
    # The full rendered prompt text for this attempt.
    prompt: str
    attempt: int
    # SlimSchema YAML embedded in the prompt; presumably None when no
    # structured output was requested — confirm with the producer.
    schema_yaml: str | None = None


class TimeoutEvent(BaseEvent):
    """Event emitted when a request times out waiting for LLM response.

    This event is generated when the idle timeout is reached - meaning
    no new events have been received within the specified timeout period.
    The response will include all events received up to the timeout.
    """

    type: Literal["timeout"]
    # The idle timeout that elapsed, in seconds.
    timeout_seconds: float
    message: str = "Request timed out waiting for response"


# Events read from the CLI's output stream; each member is distinguished by
# its literal `type` value.
OpenCodeEvent: TypeAlias = (
    StepStartEvent | TextEvent | ToolUseEvent | StepFinishEvent | ErrorEvent
)

# Everything a stream consumer may observe: the CLI events above plus the
# events produced during streaming outside the CLI (validation results,
# retries, rendered prompts, timeouts — per their docstrings).
StreamEvent: TypeAlias = (
    OpenCodeEvent
    | ValidationErrorEvent
    | ValidationSuccessEvent
    | RetryEvent
    | PromptRenderedEvent
    | TimeoutEvent
)

# Tagged union over the CLI events, dispatched on the `type` key.
EventUnion = Annotated[OpenCodeEvent, Field(discriminator="type")]

# Built once at import time and reused. It covers only OpenCodeEvent members,
# so payloads whose `type` is e.g. "timeout" or "retry" are rejected here.
EventAdapter: TypeAdapter[OpenCodeEvent] = TypeAdapter(EventUnion)


# Public API of this module. `parse_event` is listed too, so star-imports
# re-export the parser alongside the models it produces.
__all__ = [
    "TimeInfo",
    "CacheInfo",
    "TokenUsage",
    "ToolStateInput",
    "ToolState",
    "BasePart",
    "StepStartPart",
    "TextPart",
    "ToolUsePart",
    "StepFinishPart",
    "BaseEvent",
    "StepStartEvent",
    "TextEvent",
    "ToolUseEvent",
    "StepFinishEvent",
    "ErrorEvent",
    "ValidationErrorEvent",
    "ValidationSuccessEvent",
    "RetryEvent",
    "PromptRenderedEvent",
    "TimeoutEvent",
    "OpenCodeEvent",
    "StreamEvent",
    "EventUnion",
    "EventAdapter",
    "parse_event",
]


def parse_event(raw: dict[str, Any]) -> OpenCodeEvent:
    """Validate one raw CLI event mapping into its typed event model.

    Dispatch is driven by the ``type`` discriminator of ``EventUnion``. No
    pre-normalization happens here on purpose: the CLI is expected to emit the
    canonical shape. Should variants appear later, they belong at the source
    or in a dedicated adapter outside this core model layer.

    Args:
        raw: Mapping for a single event.

    Returns:
        The matching ``OpenCodeEvent`` member.

    Raises:
        pydantic.ValidationError: if ``type`` is missing or is not one of the
            CLI event types, or if the payload fails validation.
    """
    event: OpenCodeEvent = EventAdapter.validate_python(raw)
    return event
