from __future__ import annotations

import json
import os
import re
from dataclasses import dataclass
from typing import Any, TypeVar

import ijson
from jsonschema import Draft7Validator
from pydantic import BaseModel, TypeAdapter, ValidationError

from .events import ErrorEvent, OpenCodeEvent, TimeInfo
from .helper import unix_ms
from .logging_config import get_logger
from .permissions import Permission
from .request import Request
from .response import Response
from .usage import compute_usage

# Type variable bounded to pydantic models.
T = TypeVar("T", bound=BaseModel)
# Module-level logger obtained from the project logging helper.
logger = get_logger("structured")

# Matches a <json>...</json> block. Tags are matched case-insensitively, the
# payload may span several lines (DOTALL), and the lazy group stops at the
# first closing tag; whitespace next to the tags stays outside the group.
JSON_TAG_RE = re.compile(
    r"<json>\s*(.*?)\s*</json>", re.DOTALL | re.IGNORECASE
)


@dataclass
class ValidationResult:
    """Outcome of validating a JSON output file against a JSON Schema."""

    # True when the file satisfied the schema, False otherwise.
    valid: bool
    # Human-readable error messages; the validators in this module pass an
    # empty list whenever ``valid`` is True.
    errors: list[str]


def _get_json_schema(
    response_format: type[BaseModel] | TypeAdapter[Any],
) -> tuple[dict[str, Any], bool]:
    """Produce the JSON Schema for a response format and flag array schemas.

    Args:
        response_format: Pydantic BaseModel class or TypeAdapter instance

    Returns:
        Tuple of (json_schema_dict, is_array_schema). BaseModel classes are
        always reported as non-array. For a TypeAdapter the flag is True when
        the schema's top-level "type" is "array"; if schema generation fails,
        a generic object schema is returned with the flag set to False.
    """
    if not isinstance(response_format, TypeAdapter):
        # BaseModel class: its schema is treated as an object schema
        return response_format.model_json_schema(), False

    try:
        adapter_schema = response_format.json_schema()
        return adapter_schema, adapter_schema.get("type") == "array"
    except Exception as e:
        # Best effort: log the problem and fall back to a generic object schema
        logger.warning(f"Failed to get JSON schema from TypeAdapter: {e}")
        return {"type": "object"}, False


def _format_validation_error(error: Any) -> str:
    """Render a schema validation error as a single line for retry prompts.

    Args:
        error: Object exposing ``path`` (iterable of path segments) and
            ``message``, such as a jsonschema.ValidationError instance

    Returns:
        String of the form "Error at 'path.to.field': message"; the location
        is "root" when the error has an empty path
    """
    if error.path:
        location = ".".join(map(str, error.path))
    else:
        location = "root"

    return f"Error at '{location}': {error.message}"


def _validate_object_whole(
    output_file: str, schema: dict[str, Any]
) -> ValidationResult:
    """Validate entire object file (loads into memory).

    Args:
        output_file: Path to JSON output file
        schema: JSON Schema dictionary

    Returns:
        ValidationResult with validation status and errors. Unreadable files
        and files that are not decodable/parsable JSON are reported as
        failed results rather than raised.
    """
    try:
        # Read bytes and let the json module detect the Unicode encoding, so
        # the result does not depend on the platform's default text encoding.
        with open(output_file, "rb") as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # UnicodeDecodeError is not a JSONDecodeError; without handling it a
        # badly encoded file would escape as an exception.
        return ValidationResult(
            valid=False, errors=[f"Invalid JSON in output file: {e}"]
        )
    except OSError as e:
        return ValidationResult(
            valid=False, errors=[f"Failed to read output file: {e}"]
        )

    # Validate against schema, formatting each error for LLM retry prompts
    validator = Draft7Validator(schema)
    formatted_errors = [
        _format_validation_error(e) for e in validator.iter_errors(data)
    ]

    if not formatted_errors:
        return ValidationResult(valid=True, errors=[])
    return ValidationResult(valid=False, errors=formatted_errors)


def _validate_array_streaming(
    output_file: str, schema: dict[str, Any]
) -> ValidationResult:
    """Stream-validate array items without loading entire file.

    Uses ijson to incrementally parse and validate each array item, so only
    one parsed item is held at a time. When the schema sets ``uniqueItems``,
    a canonical JSON string of every item seen so far is retained as well.

    Args:
        output_file: Path to JSON output file
        schema: JSON Schema dictionary for the array; its 'items' entry must
            be a schema object (an empty object accepts any item)

    Returns:
        ValidationResult with validation status and errors. Parse, read and
        unexpected validation failures are reported as failed results rather
        than raised. A file that yields no items is reported as
        "Array is empty or malformed".
    """
    from decimal import Decimal  # local import: only used for canonical JSON

    # Only a missing (or non-object) 'items' is an error; ``{}`` is a valid
    # schema that accepts anything and must not be rejected as "missing".
    item_schema = schema.get("items")
    if not isinstance(item_schema, dict):
        return ValidationResult(
            valid=False, errors=["Array schema missing 'items' definition"]
        )

    # References to shared definitions can appear anywhere inside the item
    # schema (e.g. under anyOf / properties), not just as a top-level $ref.
    # Always carry the parent's definition tables along so they resolve.
    item_root = dict(item_schema)
    for defs_key in ("$defs", "definitions"):
        if defs_key in schema:
            item_root[defs_key] = schema[defs_key]
    item_validator = Draft7Validator(item_root)

    def _canon_default(obj: Any) -> Any:
        # ijson's default number handling yields Decimal for non-integer
        # values, which json.dumps cannot encode on its own.
        if isinstance(obj, Decimal):
            return float(obj)
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    all_errors: list[str] = []
    item_count = 0
    # Array-level constraints
    min_items = schema.get("minItems")
    max_items = schema.get("maxItems")
    unique_items = bool(schema.get("uniqueItems"))
    seen_canon: set[str] | None = set() if unique_items else None

    try:
        with open(output_file, "rb") as f:
            # Stream each item in top-level array
            for idx, item in enumerate(ijson.items(f, "item")):
                item_count += 1

                # Validate this one item, prefixing paths with the array index
                for err in item_validator.iter_errors(item):
                    path = f"[{idx}]"
                    if err.path:
                        path += "." + ".".join(str(p) for p in err.path)
                    all_errors.append(f"Error at '{path}': {err.message}")

                # Enforce uniqueness as we stream (deep-equality via canonical JSON)
                if seen_canon is not None:
                    try:
                        canon = json.dumps(
                            item,
                            sort_keys=True,
                            separators=(",", ":"),
                            default=_canon_default,
                        )
                    except Exception:
                        # Best-effort fallback; str() is key-order dependent,
                        # so it may miss some duplicates.
                        canon = str(item)
                    if canon in seen_canon:
                        all_errors.append(
                            f"Error at '[{idx}]': duplicate item violates uniqueItems"
                        )
                    else:
                        seen_canon.add(canon)
    except (json.JSONDecodeError, ijson.JSONError) as e:
        return ValidationResult(
            valid=False, errors=[f"Invalid JSON in output file: {e}"]
        )
    except OSError as e:
        return ValidationResult(
            valid=False, errors=[f"Failed to read output file: {e}"]
        )
    except Exception as e:
        return ValidationResult(valid=False, errors=[f"Validation error: {e}"])

    # Enforce array-level size constraints after streaming
    if isinstance(min_items, int) and item_count < min_items:
        all_errors.append(
            f"Array has {item_count} item(s), which is fewer than minItems={min_items}"
        )
    if isinstance(max_items, int) and item_count > max_items:
        all_errors.append(
            f"Array has {item_count} item(s), which exceeds maxItems={max_items}"
        )

    if all_errors:
        return ValidationResult(valid=False, errors=all_errors)

    # No items at all: either an empty array or a top-level value that is
    # not an array (which produces no "item" events).
    if item_count == 0:
        return ValidationResult(
            valid=False, errors=["Array is empty or malformed"]
        )

    logger.debug(f"Successfully validated {item_count} array items")
    return ValidationResult(valid=True, errors=[])


def _validate_json_file_on_disk(
    output_file: str, schema_file: str, is_array: bool = False
) -> ValidationResult:
    """Validate a JSON output file against a JSON Schema stored on disk.

    Supports two modes:
    - Object mode (is_array=False): Load and validate entire object
    - Array mode (is_array=True): Stream-validate each item (memory efficient)

    Args:
        output_file: Path to the JSON output file
        schema_file: Path to the JSON Schema file
        is_array: True for array schemas (enables streaming validation)

    Returns:
        ValidationResult with validation status and formatted errors. A
        missing output file or an unloadable schema file is reported as a
        failed result rather than raised.
    """
    # Check if output file exists
    if not os.path.exists(output_file):
        return ValidationResult(
            valid=False,
            errors=["Output file not found. LLM must write to output file."],
        )

    # Load schema (small, always fits in memory). OSError covers a missing
    # file as well as permission and other I/O failures; ValueError covers
    # both malformed JSON and undecodable bytes.
    try:
        with open(schema_file, encoding="utf-8") as f:
            schema = json.load(f)
    except (OSError, ValueError) as e:
        return ValidationResult(
            valid=False, errors=[f"Failed to load schema file: {e}"]
        )

    # Choose validation mode
    if is_array:
        return _validate_array_streaming(output_file, schema)
    return _validate_object_whole(output_file, schema)


def _compose_disk_prompt(
    prompt: str,
    output_file: str,
    json_schema: dict[str, Any],
    attempt: int,
) -> str:
    """Append file-output and schema instructions to the user prompt.

    Args:
        prompt: Original user prompt
        output_file: Path where LLM should write output
        json_schema: JSON Schema dict, embedded pretty-printed in the prompt
        attempt: Current attempt number (1-indexed); any value other than 1
            produces the retry wording

    Returns:
        The prompt followed by the instruction text
    """
    rendered_schema = json.dumps(json_schema, indent=2)

    if attempt != 1:
        # Retry wording: announce the attempt number and restate the schema
        retry_suffix = (
            f"\n\n[Retry attempt {attempt}]\n"
            f"Your previous output failed validation. Please write corrected JSON to: {output_file}\n"
            f"The JSON must match this schema:\n\n"
            f"<schema>\n{rendered_schema}\n</schema>\n"
        )
        return prompt + retry_suffix

    first_suffix = (
        f"\n\nYou must write your response as valid JSON to: {output_file}\n"
        f"The JSON must match this schema:\n\n"
        f"<schema>\n{rendered_schema}\n</schema>\n\n"
        f"Write the JSON directly to the file. Do not return JSON in your response text."
    )
    return prompt + first_suffix


def _compose_disk_retry(
    base_prompt: str,
    validation_errors: list[str],
    json_schema: dict[str, Any],
) -> str:
    """Extend a disk-mode prompt with the previous attempt's validation errors.

    Args:
        base_prompt: Original prompt with schema instructions
        validation_errors: List of formatted validation errors, rendered as
            an indented bullet list
        json_schema: JSON Schema dict (accepted but not used here)

    Returns:
        Retry prompt with error details
    """
    bullet_lines = [f"  - {err}" for err in validation_errors]
    errors_str = "\n".join(bullet_lines)

    sections = [
        f"{base_prompt}\n\n",
        f"Your previous JSON failed validation with these errors:\n{errors_str}\n\n",
        "Use the validation errors as the source of truth. ",
        "Correct the JSON and write it to the output file.",
    ]
    return "".join(sections)


def _extract_json_snippet(text: str) -> str:
    """Return the payload of the first <json>...</json> block in ``text``.

    Only the tag protocol is supported, which keeps behavior predictable;
    when the tags are absent the caller is expected to reprompt.

    Raises:
        ValueError: If no <json>...</json> block is present.
    """
    found = JSON_TAG_RE.search(text)
    if found is None:
        raise ValueError("No <json>...</json> block found in the response.")
    payload = found.group(1)
    return payload.strip()


def _format_name(fmt: type[BaseModel] | TypeAdapter[Any]) -> str:
    """Return a short display name for a response format (used in logs)."""
    if not isinstance(fmt, TypeAdapter):
        # Classes expose __name__; anything else falls back to its str()
        return getattr(fmt, "__name__", str(fmt))
    return "TypeAdapter"


def _format_schema(fmt: type[BaseModel] | TypeAdapter[Any]) -> dict[str, Any]:
    """Return the JSON Schema dict for a BaseModel class or TypeAdapter."""
    return (
        fmt.json_schema()
        if isinstance(fmt, TypeAdapter)
        else fmt.model_json_schema()
    )


def _validate_json(fmt: type[BaseModel] | TypeAdapter[Any], text: str) -> Any:
    """Parse and validate JSON text with the given model class or adapter.

    Raises:
        ValueError: If validation fails; the message is the string form of
            the underlying ValidationError, which is chained as the cause.
    """
    is_adapter = isinstance(fmt, TypeAdapter)
    try:
        return fmt.validate_json(text) if is_adapter else fmt.model_validate_json(text)
    except ValidationError as e:
        raise ValueError(str(e)) from e


def _dump_json(value: Any) -> str:
    """Serialize a validated value to indented JSON text.

    Args:
        value: A BaseModel instance, or any value json.dumps can encode,
            optionally containing BaseModel instances (for example the
            list produced by a TypeAdapter over ``list[Model]``)

    Returns:
        JSON text indented by two spaces

    Raises:
        TypeError: If the value contains an object that is neither JSON
            encodable nor a BaseModel.
    """
    if isinstance(value, BaseModel):
        return value.model_dump_json(indent=2)

    def _encode_model(obj: Any) -> Any:
        # json.dumps cannot encode models nested inside containers; convert
        # them to JSON-compatible data instead of failing with TypeError.
        if isinstance(obj, BaseModel):
            return obj.model_dump(mode="json")
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    return json.dumps(value, indent=2, default=_encode_model)


def _extract_and_parse_json(
    text: str, resp_format: type[BaseModel] | TypeAdapter[Any]
) -> Any:
    """Pull the <json> block out of ``text`` and validate it against the format."""
    return _validate_json(resp_format, _extract_json_snippet(text))


def _resolve_output_path(
    workdir: str | None, session: str | None, override: str | None
) -> str:
    """Compute the absolute structured-output path and create its directory.

    Args:
        workdir: Working directory; the current directory is used when falsy
        session: Session id used to build the default file name; "default"
            is used when falsy
        override: Explicit output path; when falsy a timestamped file under
            ``<workdir>/.innerloop`` is chosen

    Returns:
        Absolute path of the output file (its parent directory is created)

    Raises:
        ValueError: If the resolved path is not located under the workdir.
    """
    wd = workdir or os.getcwd()

    if not override:
        session_id = session or "default"
        # Replace every non-alphanumeric character so the id is filename-safe
        sanitized = "".join(ch if ch.isalnum() else "_" for ch in session_id)
        target = os.path.join(
            wd, ".innerloop", f"structured_{sanitized}_{unix_ms()}.json"
        )
    else:
        target = override

    resolved = os.path.abspath(target)
    root = os.path.abspath(wd)
    # Compare against the workdir with a trailing separator so that sibling
    # directories sharing a name prefix are not accepted.
    root_prefix = root if root.endswith(os.sep) else root + os.sep

    if not resolved.startswith(root_prefix):
        raise ValueError(
            f"Structured output file must be under workdir. File: {resolved}, Workdir: {root_prefix.rstrip(os.sep)}"
        )

    os.makedirs(os.path.dirname(resolved), exist_ok=True)
    return resolved


def _can_write(perms: Permission) -> bool:
    """Report whether the permissions allow some way of writing files.

    True when ``edit`` equals ``ALLOW``, when ``bash`` equals ``ALLOW``, or
    when ``bash`` is a dict with at least one value equal to ``ALLOW``.
    """
    allow = perms.ALLOW
    if perms.edit == allow or perms.bash == allow:
        return True

    bash_rules = perms.bash
    if not isinstance(bash_rules, dict):
        return False
    return any(rule == allow for rule in bash_rules.values())


def _compose_prompt(
    prompt: str,
    resp_format: type[BaseModel] | TypeAdapter[Any],
    abs_path: str,
    can_write: bool,
) -> str:
    """Build the memory-mode prompt: task, output instructions and schema.

    Args:
        prompt: Original user prompt
        resp_format: BaseModel class or TypeAdapter whose schema is embedded
        abs_path: Output file path mentioned when writing is possible
        can_write: Selects the "write to file" wording over the
            "<json> tags only" wording

    Returns:
        The prompt followed by the instructions and the pretty-printed schema
    """
    schema = json.dumps(_format_schema(resp_format), indent=2)

    write_instr = (
        f"Write a valid JSON matching this schema to: {abs_path}.\n"
        "If you cannot write files, return ONLY JSON inside <json>...</json> tags."
    )
    tags_instr = (
        "Return ONLY JSON inside <json>...</json> tags matching the schema.\n"
        "No extra text."
    )
    instr = write_instr if can_write else tags_instr

    return f"{prompt}\n\n{instr}\nSchema:\n<schema>\n{schema}\n</schema>"


def _compose_retry(base_prompt: str, err: str, model_cls: Any) -> str:
    """Build a neutral retry prompt driven by the validation error.

    The original task context (``base_prompt``) is kept, the exact validation
    error is surfaced so the model can correct itself, and the model is told
    to answer with JSON in <json> tags only. No test-specific hints or field
    names are added. ``model_cls`` is accepted but not used.
    """
    follow_up = "\n".join(
        [
            "Your previous JSON failed validation.",
            f"Error: {err}.",
            "Use the validation error as the source of truth.",
            "Return ONLY corrected JSON in <json>...</json> tags that validates.",
        ]
    )
    return f"{base_prompt}\n\n{follow_up}"


async def _attempt(
    request: Request,
    *,
    prompt: str,
    abs_path: str,
    resp_format: type[BaseModel] | TypeAdapter[Any],
    session: str | None,
    timeout: float | None = None,
) -> tuple[Response[Any], Any | None, str | None]:
    """Run one memory-mode invocation and validate its structured output.

    The request is re-issued with ``prompt`` and without a response format.
    If a file exists at ``abs_path`` afterwards, its content is validated;
    otherwise JSON is extracted from <json> tags in the response text and,
    when valid, written to ``abs_path``.

    Args:
        request: Source of model, permission, providers, mcp and workdir
        prompt: Fully composed prompt for this attempt
        abs_path: Absolute path of the structured output file
        resp_format: BaseModel class or TypeAdapter used for validation
        session: Session id to continue, if any
        timeout: Passed through to the invocation

    Returns:
        Tuple of (response, validated_output, error). On success the error
        is None; on validation failure the output is None and the error is
        the failure message.
    """
    from .invoke import async_invoke  # local import to avoid cycles

    resp = await async_invoke(
        Request(
            model=request.model,
            prompt=prompt,
            permission=request.permission,
            providers=request.providers,
            mcp=request.mcp,
            response_format=None,
            session=session,
            workdir=request.workdir,
        ),
        timeout=timeout,
    )

    # File mode if file exists
    # NOTE(review): a file already present before this attempt is validated
    # as-is and takes precedence over the response text; confirm intended.
    if os.path.exists(abs_path):
        logger.debug(f"Validating structured output from file: {abs_path}")
        try:
            # JSON files are read as UTF-8 explicitly so the result does not
            # depend on the platform's default text encoding.
            with open(abs_path, encoding="utf-8") as f:
                content = f.read()
            out = _validate_json(resp_format, content)
            logger.debug("File mode validation successful")
            return resp, out, None
        except (ValidationError, ValueError, OSError) as e:
            logger.debug(f"File mode validation failed: {str(e)}")
            return resp, None, str(e)

    # Fallback: extract from textual output
    logger.debug(
        "File not found, attempting to extract JSON from response text"
    )
    try:
        out = _extract_and_parse_json(str(resp.output), resp_format)
        # Write UTF-8 explicitly: model JSON may contain non-ASCII text that
        # a non-UTF-8 default encoding would fail to encode, and that failure
        # (a ValueError) would be misreported as a validation error.
        with open(abs_path, "w", encoding="utf-8") as f:
            f.write(_dump_json(out))
        logger.debug("Tag extraction and validation successful")
        return resp, out, None
    except (ValidationError, ValueError) as e:
        logger.debug(f"Tag extraction validation failed: {str(e)}")
        return resp, None, str(e)


async def _disk_validation_mode(
    request: Request, *, max_retries: int = 3, timeout: float | None = None
) -> Response[Any]:
    """Disk-based validation: write schema to disk, validate on disk, cleanup.

    Supports both BaseModel (objects) and TypeAdapter (arrays).

    Flow:
    1. Detect if response_format is BaseModel or TypeAdapter
    2. Generate JSON Schema accordingly
    3. Write schema to temp file next to output
    4. Tell LLM to write output to file
    5. Validate on disk (streaming for arrays, full for objects)
    6. Retry on validation errors
    7. Cleanup schema file
    8. Return Response with output=None (path only)

    Args:
        request: The invocation request with response_format set
        max_retries: Maximum number of attempts (the first attempt counts)
        timeout: Passed through to every underlying invocation

    Returns:
        Response with output=None and structured_output_file set

    Raises:
        ValueError: If request.response_format is None.
        RuntimeError: If no attempt produced a valid output file.
    """
    from .invoke import async_invoke  # local import to avoid cycles

    if request.response_format is None:
        raise ValueError(
            "Disk validation mode requires response_format to be set"
        )

    # Resolve output file path
    abs_path = _resolve_output_path(
        request.workdir, request.session, request.output_file
    )

    # Generate JSON Schema and detect if array
    json_schema, is_array = _get_json_schema(request.response_format)

    # Write schema to temp file next to output
    schema_file = abs_path + ".schema.json"
    with open(schema_file, "w") as f:
        json.dump(json_schema, f, indent=2)

    logger.info(
        f"Starting disk validation mode (max_retries={max_retries}, "
        f"is_array={is_array}, model={request.model})"
    )
    logger.debug(f"Output file: {abs_path}")
    logger.debug(f"Schema file: {schema_file}")

    attempts = 0
    # Events from every attempt plus one ErrorEvent per failed validation
    total_events: list[OpenCodeEvent] = []
    # Session id is carried from one attempt to the next
    final_session: str | None = request.session
    wall_start = unix_ms()
    validation_result: ValidationResult | None = None

    try:
        while attempts < max_retries:
            attempts += 1
            logger.info(f"Attempt {attempts}/{max_retries}")

            # Compose prompt with schema instructions
            if attempts == 1:
                prompt = _compose_disk_prompt(
                    request.prompt, abs_path, json_schema, attempts
                )
            else:
                # Use errors from previous attempt
                prev_errors = (
                    validation_result.errors
                    if validation_result
                    else ["Unknown validation error"]
                )
                prompt = _compose_disk_retry(
                    _compose_disk_prompt(
                        request.prompt, abs_path, json_schema, attempts
                    ),
                    prev_errors,
                    json_schema,
                )

            if attempts > 1 and validation_result:
                logger.warning(
                    f"Retrying disk validation (attempt {attempts}/{max_retries}) - "
                    f"Previous errors: {validation_result.errors[:3]}"  # Log first 3 errors
                )

            # Run without structured output (just write to file)
            resp = await async_invoke(
                Request(
                    model=request.model,
                    prompt=prompt,
                    permission=request.permission,
                    providers=request.providers,
                    mcp=request.mcp,
                    session=final_session,
                    workdir=request.workdir,
                ),
                timeout=timeout,
            )
            total_events.extend(resp.events)
            final_session = resp.session_id

            # Validate on disk; the parsed output is never kept by this function
            # NOTE(review): a file already present at abs_path before the
            # first attempt would be validated as-is; confirm intended.
            validation_result = _validate_json_file_on_disk(
                abs_path, schema_file, is_array=is_array
            )

            if validation_result.valid:
                logger.info(
                    f"Disk validation successful on attempt {attempts}"
                )
                break

            # Validation failed - log errors and prepare for retry
            logger.warning(
                f"Disk validation failed on attempt {attempts}: "
                f"{validation_result.errors[:3]}"  # Log first 3 errors
            )
            # Record the failure (first 5 errors) in the event stream
            total_events.append(
                ErrorEvent(
                    timestamp=unix_ms(),
                    sessionID=final_session or "",
                    type="error",
                    message="; ".join(validation_result.errors[:5]),
                    code=None,
                    severity="error",
                )
            )

        # Check if validation succeeded
        # (validation_result is None only when the loop never ran)
        if validation_result is None or not validation_result.valid:
            final_errors = (
                validation_result.errors
                if validation_result
                else ["Validation never completed"]
            )
            logger.error(
                f"Disk validation failed after {attempts} attempts. "
                f"Final errors: {final_errors[:3]}"
            )
            raise RuntimeError(
                f"Disk validation failed after {attempts} attempts. "
                f"Expected file: {abs_path}\nErrors: {final_errors}"
            )

        # Success! Return without loading output
        wall_end = unix_ms()
        logger.info(
            f"Disk validation completed successfully after {attempts} attempt(s)"
        )

        out_resp = Response(
            session_id=final_session or "",
            input=request.prompt,
            output=None,  # CRITICAL: Never load into memory
            structured_output_file=abs_path,
            validated=True,
            validation_errors=None,
            events=total_events,
            attempts=attempts,
            time=TimeInfo(start=wall_start, end=wall_end),
        )
        out_resp.usage = compute_usage(total_events)
        return out_resp

    finally:
        # ALWAYS cleanup schema file (on success and on failure alike)
        if os.path.exists(schema_file):
            try:
                os.remove(schema_file)
                logger.debug(f"Cleaned up schema file: {schema_file}")
            except OSError as e:
                # Best effort: a leftover schema file is only logged
                logger.warning(f"Failed to cleanup schema file: {e}")

async def invoke_structured(
    request: Request,
    *,
    max_retries: int = 3,
    timeout: float | None = None,
) -> Response[Any]:
    """Structured output invocation with automatic mode selection.

    Routes to either:
    - Disk validation mode (if write permissions enabled) - memory efficient
    - Memory validation mode (if write permissions disabled) - backward compatible

    When the request has no response_format, it is passed straight to the
    plain invocation.

    Args:
        request: Invocation request with optional response_format
        max_retries: Maximum number of attempts (the first attempt counts)
        timeout: Passed through to every underlying invocation

    Returns:
        Response with structured output (in memory or on disk)

    Raises:
        RuntimeError: If no attempt produced output that validates.
    """
    if request.response_format is None:
        from .invoke import async_invoke  # local import

        return await async_invoke(request, timeout=timeout)

    # Determine validation mode based on write permissions
    can_write = _can_write(request.permission)

    if can_write:
        # NEW: Disk-only validation mode (memory efficient)
        logger.info("Using disk validation mode (write permissions enabled)")
        return await _disk_validation_mode(
            request, max_retries=max_retries, timeout=timeout
        )

    # EXISTING: Memory validation mode (backward compatible)
    logger.info("Using memory validation mode (write permissions disabled)")

    abs_path = _resolve_output_path(
        request.workdir, request.session, request.output_file
    )
    # can_write is always False on this path, so the prompt asks for <json> tags
    base_prompt = _compose_prompt(
        request.prompt, request.response_format, abs_path, can_write
    )

    logger.info(
        f"Starting structured invocation (max_retries={max_retries}, model={request.model})"
    )
    logger.debug(f"Output file: {abs_path}")
    logger.debug(f"Response format: {_format_name(request.response_format)}")

    attempts = 0
    # Events from every attempt plus one ErrorEvent per failed validation
    total_events: list[OpenCodeEvent] = []
    # Session id is carried from one attempt to the next
    final_session: str | None = request.session
    final_output: BaseModel | None = None
    wall_start = unix_ms()
    # Placeholder message, replaced by the first real validation error
    last_err = "invalid"

    while attempts < max_retries:
        attempts += 1
        logger.info(f"Attempt {attempts}/{max_retries}")

        # First attempt uses the base prompt; retries append the last error
        prompt = (
            base_prompt
            if attempts == 1
            else _compose_retry(base_prompt, last_err, request.response_format)
        )

        if attempts > 1:
            logger.warning(
                f"Retrying structured output validation (attempt {attempts}/{max_retries}) - "
                f"Previous validation error: {last_err}"
            )

        resp, out, err = await _attempt(
            request,
            prompt=prompt,
            abs_path=abs_path,
            resp_format=request.response_format,
            session=final_session,
            timeout=timeout,
        )
        total_events.extend(resp.events)
        final_session = resp.session_id

        # A non-None output means the attempt validated successfully
        if out is not None:
            final_output = out
            logger.info(
                f"Structured output validation successful on attempt {attempts}"
            )
            break

        last_err = err or last_err
        logger.warning(f"Validation failed on attempt {attempts}: {last_err}")

        # Record the failure in the event stream
        total_events.append(
            ErrorEvent(
                timestamp=unix_ms(),
                sessionID=final_session or "",
                type="error",
                message=last_err,
                code=None,
                severity="error",
            )
        )

    # All attempts exhausted (or none ran) without a validated output
    if final_output is None:
        logger.error(
            f"Structured output validation failed after {attempts} attempts. "
            f"Final error: {last_err}"
        )
        raise RuntimeError(
            f"Structured output validation failed after {attempts} attempts. Expected file: {abs_path}\nError: {last_err}"
        )

    wall_end = unix_ms()
    logger.info(
        f"Structured invocation completed successfully after {attempts} attempt(s)"
    )
    logger.debug(
        f"Total events: {len(total_events)}, Session: {final_session}"
    )

    out_resp = Response(
        session_id=final_session or "",
        input=request.prompt,
        output=final_output,
        structured_output_file=abs_path,
        events=total_events,
        attempts=attempts,
        time=TimeInfo(start=wall_start, end=wall_end),
    )
    out_resp.usage = compute_usage(total_events)
    return out_resp
