# File: tests/test_agent_chains.py
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.language_models import BaseChatModel

# We need to mock the LLM responses
from langchain_core.messages import AIMessage

from commitai.agent import SummarizationMiddleware, TodoMiddleware, create_commit_agent
from commitai.chains import CommitState


@pytest.fixture
def mock_llm():
    """Provide a chat-model double spec'd to ``BaseChatModel``.

    Every ``invoke`` call on the double returns the same canned ``AIMessage``.
    Individual tests override ``invoke.return_value`` when they need a
    different reply.
    """
    canned_reply = AIMessage(content="Mocked LLM Response")
    fake_llm = MagicMock(spec=BaseChatModel)
    fake_llm.invoke.return_value = canned_reply
    return fake_llm


def test_summarization_middleware(mock_llm):
    """Run ``SummarizationMiddleware.process`` and check that it fills ``summary``.

    ``ChatPromptTemplate.from_messages`` is patched for the duration of the
    call. Any prompt the middleware builds through it is therefore a
    ``MagicMock``, so the assertion only checks that ``summary`` was populated,
    not its exact value.
    """
    middleware = SummarizationMiddleware(mock_llm)
    state: CommitState = {
        "diff": "some diff",
        "explanation": "fix stuff",
        "summary": None,
        "todos": None,
    }

    # Canned reply the LLM double gives if the middleware's chain reaches it.
    mock_llm.invoke.return_value = AIMessage(content="Summarized diff")

    with patch("langchain_core.prompts.ChatPromptTemplate.from_messages"):
        result_state = middleware.process(state)

    # NOTE(review): with the patch removed, this could likely be tightened to
    # `== "Summarized diff"` (StrOutputParser yields AIMessage.content) —
    # confirm against commitai.agent before strengthening.
    assert result_state["summary"] is not None


def test_todo_scanner_middleware():
    """``TodoMiddleware.process`` should populate ``todos`` for a diff containing a TODO."""
    # TodoMiddleware is constructed without an LLM, so no LLM double is needed;
    # configuring one here would have no effect on the middleware.
    middleware = TodoMiddleware()
    state: CommitState = {
        "diff": "+ TODO: fix this",
        "explanation": "",
        "summary": "",
        "todos": None,
    }

    result_state = middleware.process(state)
    assert result_state["todos"] is not None


def test_create_commit_agent(mock_llm):
    """Smoke-test the ``create_commit_agent`` factory with graph construction stubbed."""
    # Swap out create_react_agent in commitai.agent so no real graph is compiled.
    with patch("commitai.agent.create_react_agent", return_value=MagicMock()):
        built_agent = create_commit_agent(mock_llm)
    assert built_agent is not None


def test_agent_run(mock_llm):
    """Invoke the built agent against a stubbed graph and check the returned text.

    The stubbed graph's ``invoke`` returns a dict whose ``"messages"`` list holds
    a single message. The agent is expected to hand back that last message's
    ``content``.
    """
    final_msg = MagicMock(content="Final Commit Message")
    fake_graph = MagicMock()
    fake_graph.invoke.return_value = {"messages": [final_msg]}

    # Keep the invoke inside the patch in case the graph is created lazily.
    with patch("commitai.agent.create_react_agent", return_value=fake_graph):
        runnable = create_commit_agent(mock_llm)
        output = runnable.invoke({"diff": "diff", "explanation": "expl"})

    assert output == "Final Commit Message"
