"""
LangGraph SQL Agent (Custom Implementation)

This example demonstrates building a SQL agent directly using LangGraph primitives
for deeper customization. This gives more control over the agent's behavior compared
to the higher-level LangChain agent.

Based on: https://docs.langchain.com/oss/python/langgraph/sql-agent
"""

import asyncio
import pathlib
from typing import Literal

import requests
from langchain.chat_models import init_chat_model
from langchain.messages import AIMessage
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import InMemorySaver
from pixie import pixie_app


def setup_database():
    """Download (if needed) and open the Chinook sample SQLite database.

    The database file is cached as ``Chinook.db`` relative to the current
    working directory; if that file already exists, no download happens.

    Returns:
        An ``SQLDatabase`` connected to the local ``Chinook.db`` file.

    Raises:
        Exception: if the download responds with a status code other than 200.
        requests.RequestException: on network errors, including a stalled
            connection exceeding the request timeout.
    """
    url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
    local_path = pathlib.Path("Chinook.db")

    if local_path.exists():
        print(f"{local_path} already exists, skipping download.")
    else:
        print("Downloading Chinook database...")
        # Without a timeout, requests may wait forever on an unresponsive
        # server. The value bounds connecting and each socket read, not the
        # total download time.
        response = requests.get(url, timeout=60)
        if response.status_code != 200:
            raise Exception(
                f"Failed to download the file. Status code: {response.status_code}"
            )
        local_path.write_bytes(response.content)
        print(f"File downloaded and saved as {local_path}")

    # Build the URI from local_path so the file we download and the file we
    # open cannot drift apart.
    return SQLDatabase.from_uri(f"sqlite:///{local_path}")


def create_sql_graph(db: SQLDatabase, model):
    """Create a LangGraph-based SQL agent with an explicit, custom workflow.

    The compiled graph always runs
    ``list_tables -> call_get_schema -> get_schema -> generate_query``.
    Then, while ``generate_query`` emits tool calls, it loops through
    ``check_query -> run_query -> generate_query``. It finishes once
    ``generate_query`` replies without a tool call.

    Args:
        db: Database the SQL tools operate on; its ``dialect`` is embedded in
            the query-generation and query-checking prompts.
        model: Chat model used by the toolkit and by every LLM node; it must
            support ``bind_tools``.

    Returns:
        The compiled graph, checkpointed with an ``InMemorySaver``.

    Raises:
        StopIteration: if the toolkit does not provide one of the
            ``sql_db_schema``, ``sql_db_query`` or ``sql_db_list_tables`` tools.
    """

    # Get tools from toolkit
    toolkit = SQLDatabaseToolkit(db=db, llm=model)
    tools = toolkit.get_tools()

    # Extract specific tools
    get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")
    get_schema_node = ToolNode([get_schema_tool], name="get_schema")

    run_query_tool = next(tool for tool in tools if tool.name == "sql_db_query")
    run_query_node = ToolNode([run_query_tool], name="run_query")

    list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")

    # Bind tools once, at graph-construction time. The nodes below may run many
    # times per question (generate_query/check_query sit inside the query loop),
    # so re-binding on every call is wasted work.
    schema_llm = model.bind_tools([get_schema_tool], tool_choice="any")
    query_llm = model.bind_tools([run_query_tool])
    check_llm = model.bind_tools([run_query_tool], tool_choice="any")

    # Node: List tables (forced tool call)
    def list_tables(state: MessagesState):
        """Call ``sql_db_list_tables`` directly, without asking the LLM.

        Returns three messages: a synthetic AI tool call, the tool's result,
        and an AI message summarising the available tables.
        """
        tool_call = {
            "name": "sql_db_list_tables",
            "args": {},
            "id": "list_tables_call",
            "type": "tool_call",
        }
        tool_call_message = AIMessage(content="", tool_calls=[tool_call])
        tool_message = list_tables_tool.invoke(tool_call)
        response = AIMessage(f"Available tables: {tool_message.content}")
        return {"messages": [tool_call_message, tool_message, response]}

    # Node: Force model to call get_schema
    def call_get_schema(state: MessagesState):
        """Ask the LLM (forced via ``tool_choice="any"``) to call the schema tool."""
        response = schema_llm.invoke(state["messages"])
        return {"messages": [response]}

    # Node: Generate query
    generate_query_prompt = f"""
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {db.dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most 5 results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
"""

    def generate_query(state: MessagesState):
        """Let the LLM either emit a ``sql_db_query`` call or answer directly."""
        system_message = {"role": "system", "content": generate_query_prompt}
        response = query_llm.invoke([system_message] + state["messages"])
        return {"messages": [response]}

    # Node: Check query
    check_query_prompt = f"""
You are a SQL expert with a strong attention to detail.
Double check the {db.dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes,
just reproduce the original query.

You will call the appropriate tool to execute the query after running this check.
"""

    def check_query(state: MessagesState):
        """Have the LLM review (and possibly rewrite) the pending SQL query.

        Only the first tool call of the last message is reviewed.
        """
        system_message = {"role": "system", "content": check_query_prompt}
        last_message = state["messages"][-1]
        query = None
        # Only AIMessage has tool_calls
        if isinstance(last_message, AIMessage) and last_message.tool_calls:
            # .get() guards against a malformed tool call that lacks "query".
            query = last_message.tool_calls[0]["args"].get("query")
        if query is not None:
            user_message = {"role": "user", "content": query}
        else:
            # Fallback if there is no usable tool call
            user_message = {"role": "user", "content": "Please check the query"}
        response = check_llm.invoke([system_message, user_message])
        if isinstance(last_message, AIMessage):
            # Reuse the original message id so the messages reducer overwrites
            # the unchecked query message instead of appending a second one.
            response.id = last_message.id
        return {"messages": [response]}

    # Conditional edge: Continue or end
    def should_continue(state: MessagesState) -> Literal["__end__", "check_query"]:
        """Route to ``check_query`` if the last AI message requested a tool."""
        last_message = state["messages"][-1]
        # Check if last message is AIMessage and has tool calls
        if isinstance(last_message, AIMessage) and last_message.tool_calls:
            return "check_query"
        return "__end__"

    # Build graph
    builder = StateGraph(MessagesState)
    builder.add_node("list_tables", list_tables)
    builder.add_node("call_get_schema", call_get_schema)
    builder.add_node("get_schema", get_schema_node)
    builder.add_node("generate_query", generate_query)
    builder.add_node("check_query", check_query)
    builder.add_node("run_query", run_query_node)

    builder.add_edge(START, "list_tables")
    builder.add_edge("list_tables", "call_get_schema")
    builder.add_edge("call_get_schema", "get_schema")
    builder.add_edge("get_schema", "generate_query")
    builder.add_conditional_edges("generate_query", should_continue)
    builder.add_edge("check_query", "run_query")
    builder.add_edge("run_query", "generate_query")

    return builder.compile(checkpointer=InMemorySaver())


@pixie_app
async def langgraph_sql_agent(question: str) -> str:
    """Custom SQL agent built with LangGraph primitives.

    This agent has explicit control over the workflow:
    1. Lists all tables
    2. Gets schema for relevant tables
    3. Generates SQL query
    4. Checks query for errors
    5. Executes query
    6. Returns natural language answer

    Args:
        question: Natural language question about the database

    Returns:
        AI-generated answer based on SQL query results
    """
    # setup_database() may do a blocking HTTP download and file write; run it
    # in a worker thread so this coroutine does not stall the event loop.
    db = await asyncio.to_thread(setup_database)

    # Initialize model
    model = init_chat_model("gpt-4o-mini", temperature=0)

    # Create graph
    graph = create_sql_graph(db, model)

    # Use the async entry point: the synchronous invoke() would block the
    # event loop for the whole multi-step LLM + SQL run.
    result = await graph.ainvoke(
        {"messages": [{"role": "user", "content": question}]},  # type: ignore
        {"configurable": {"thread_id": "langgraph_sql"}},
    )

    # Return the final message
    return result["messages"][-1].content
