from __future__ import annotations

from typing import Any, Dict, List

import pytest

from python_ray.ray import Ray
from python_ray.settings import Settings
from python_ray.payloads import LogPayload
from python_ray.support.rate_limiter import RateLimiter


class DummyClient:
    """Stand-in for Ray's transport client that records requests instead of sending them."""

    def __init__(self) -> None:
        # Every request handed to ``send``, in arrival order.
        self.sent: List[Any] = list()

    def send(self, request: Any) -> None:  # pragma: no cover - simple collector
        """Record *request* so a test can inspect it later."""
        self.sent += [request]


def _reset_ray_class_state() -> None:
    """Put every piece of class-level ``Ray`` state used by these tests back to a clean baseline."""
    Ray._client = None
    Ray._counters = None
    Ray._limiters = None
    # A fresh disabled limiter on every reset, so no call history carries over between tests.
    Ray._rate_limiter = RateLimiter.disabled()
    Ray._enabled = None
    Ray._project_name = ""
    Ray._before_send_request = None
    Ray._stopwatches.clear()
    Ray._caught_exceptions.clear()


@pytest.fixture(autouse=True)
def reset_ray_class_state(monkeypatch):
    """Reset ``Ray`` class-level state before and after every test in this module.

    ``Ray`` keeps its client, counters, limiters, stopwatches and caught exceptions on the
    class itself, so without this reset one test's state would leak into the next.
    Setup and teardown share a single helper so the two phases cannot drift apart.
    ``monkeypatch`` is requested but not used directly in this fixture.
    """
    _reset_ray_class_state()
    yield
    _reset_ray_class_state()


def make_ray(monkeypatch) -> tuple[Ray, DummyClient]:
    """Create a ``Ray`` with empty settings and a recording client installed on the class.

    The ``DummyClient`` is set on ``Ray._client`` through ``monkeypatch``, which undoes
    the change at teardown. Returns ``(ray, client)`` so tests can inspect sent requests.
    """
    collector = DummyClient()
    ray_instance = Ray(Settings({}))
    monkeypatch.setattr(Ray, "_client", collector, raising=False)
    return ray_instance, collector


def test_send_wraps_values_in_payloads(monkeypatch):
    """Each value given to ``send`` becomes its own ``LogPayload`` within a single request."""
    ray_instance, client = make_ray(monkeypatch)

    ray_instance.send("hello", 123)

    assert len(client.sent) == 1
    payloads = client.sent[0].payloads
    assert len(payloads) == 2
    for payload in payloads:
        assert isinstance(payload, LogPayload)


def test_raw_bypasses_payload_factory(monkeypatch):
    """``raw`` emits one request that carries exactly one ``LogPayload``."""
    ray_instance, client = make_ray(monkeypatch)
    ray_instance.raw("hello")

    assert len(client.sent) == 1
    payloads = client.sent[0].payloads
    # Tuple-unpacking doubles as an "exactly one payload" check.
    (only_payload,) = payloads
    assert isinstance(only_payload, LogPayload)


def test_disabled_ray_does_not_send(monkeypatch):
    """Once ``disable()`` has been called, ``send`` must not reach the client."""
    ray_instance, client = make_ray(monkeypatch)

    ray_instance.disable()
    ray_instance.send("hello")

    assert not client.sent


def test_rate_limiter_blocks_and_notifies(monkeypatch):
    """Once the rate limit is exceeded, the payload is dropped and a single notice is sent.

    The limiter allows one call (``max_calls=1``) with no per-second cap. The second
    ``send`` must not deliver its value; instead exactly one extra request is expected,
    carrying only a ``"custom"`` payload (the rate-limit notification).
    """
    r, dummy = make_ray(monkeypatch)

    rl = RateLimiter(max_calls=1, max_per_second=None)
    # Assigned directly on the class; the autouse reset fixture restores a disabled limiter afterwards.
    Ray._rate_limiter = rl

    # First call passes
    r.send("first")
    assert len(dummy.sent) == 1

    # Second call exceeds max_calls, should trigger notification request
    r.send("second")

    # We expect two total requests: the original payload and a "Rate limit" custom payload
    assert len(dummy.sent) == 2
    meta_types = [p.get_type() for p in dummy.sent[1].payloads]
    assert meta_types == ["custom"]


def test_limit_and_once(monkeypatch):
    """``limit(n)`` caps sends and ``once`` sends a single payload, per the comments below.

    Both behaviours depend on how ``Ray`` fingerprints the calling origin, which is not
    visible from this file, so the ``limit`` half only asserts an upper bound.
    """
    r, dummy = make_ray(monkeypatch)

    # limit: at most 2 sends from the same origin
    r.limit(2)
    r.send("a")
    r.send("b")
    r.send("c")  # should be suppressed by limiter

    # we don't know exact origin fingerprint, but can assert at most 2
    # NOTE(review): ``<= 2`` also passes when nothing at all was sent; consider a lower bound
    # once the origin semantics of ``limit`` are confirmed.
    assert len(dummy.sent) <= 2

    # once: only a single payload from a given origin
    # A second make_ray re-points the class-level ``Ray._client`` at ``dummy2``.
    r2, dummy2 = make_ray(monkeypatch)
    r2.once("x")
    r2.send("y")  # should be suppressed
    assert len(dummy2.sent) == 1


def test_catch_and_throw_exceptions(monkeypatch):
    """Exceptions raised by callables passed to ``send`` are captured rather than propagated.

    ``catch()`` with no callback should emit an ``"exception"`` payload, and
    ``throw_exceptions()`` should re-raise a captured exception.
    """
    r, dummy = make_ray(monkeypatch)

    def bad_callable(ray: Ray) -> None:  # noqa: ARG001
        raise ValueError("boom")

    # Exception should be captured instead of raised
    r.send(bad_callable)

    # catch without callback sends exception payload
    r.catch()
    assert any(p.get_type() == "exception" for req in dummy.sent for p in req.payloads)

    # Now add another failing callable and rethrow
    # NOTE(review): ``_caught_exceptions`` lives on the ``Ray`` class (the reset fixture clears it),
    # so ``r2`` shares it with ``r``. This expects RuntimeError rather than the earlier ValueError,
    # which presumably relies on ``catch()`` consuming what it handled - confirm.
    r2, _ = make_ray(monkeypatch)

    def bad_callable2(ray: Ray) -> None:  # noqa: ARG001
        raise RuntimeError("oops")

    r2.send(bad_callable2)

    with pytest.raises(RuntimeError):
        r2.throw_exceptions()


def test_helper_ray_uses_settings_factory_and_sends(monkeypatch):
    """``helpers.ray`` loads settings via the factory, returns a ``Ray`` and forwards its values.

    ``SettingsFactory.create_from_file`` and ``Ray.send`` are both replaced with recorders,
    so the test runs without a real configuration file or network client.
    """
    from python_ray import helpers as helpers_module

    # Records whether the helper consulted the settings factory.
    called: dict[str, bool] = {"called": False}

    def fake_create_from_file() -> Settings:
        called["called"] = True
        return Settings({"host": "test-host", "port": 12345})

    # staticmethod so the replacement is called without an implicit cls/self argument.
    monkeypatch.setattr(
        helpers_module.SettingsFactory,
        "create_from_file",
        staticmethod(fake_create_from_file),
    )

    sent_args: dict[str, tuple] = {}

    def fake_send(self: Ray, *values):  # type: ignore[return-type]
        sent_args["values"] = values
        return self

    monkeypatch.setattr(Ray, "send", fake_send)

    r = helpers_module.ray("hello", 42)

    assert called["called"] is True
    assert isinstance(r, Ray)
    assert sent_args["values"] == ("hello", 42)


def test_helper_rd_calls_ray_and_exits(monkeypatch):
    """``rd`` forwards its arguments to ``ray`` and then exits with status code 0."""
    from python_ray import helpers

    forwarded: dict[str, tuple] = {}

    def recording_ray(*values):
        forwarded["values"] = values
        return object()

    monkeypatch.setattr(helpers, "ray", recording_ray)

    with pytest.raises(SystemExit) as excinfo:
        helpers.rd("bye", 123)

    assert excinfo.value.code == 0
    assert forwarded["values"] == ("bye", 123)


def test_object_with_primitives_uses_send(monkeypatch):
    """A primitive given to ``object`` is passed to ``send`` unchanged."""
    ray_instance, _ = make_ray(monkeypatch)
    recorded: dict[str, tuple] = {}

    def capture_send(self: Ray, *values):  # type: ignore[return-type]
        recorded["values"] = values
        return self

    monkeypatch.setattr(Ray, "send", capture_send)

    ray_instance.object(123)

    assert recorded["values"] == (123,)


def test_object_with_dataclass_uses_text_json(monkeypatch):
    """A dataclass instance is serialized to JSON and handed to ``text``."""
    import json
    from dataclasses import dataclass

    @dataclass
    class User:
        id: int
        name: str

    ray_instance, _ = make_ray(monkeypatch)
    seen: dict[str, str] = {}

    def capture_text(self: Ray, text: str):  # type: ignore[return-type]
        seen["text"] = text
        return self

    monkeypatch.setattr(Ray, "text", capture_text)

    ray_instance.object(User(id=1, name="Ada"))

    assert json.loads(seen["text"]) == {"id": 1, "name": "Ada"}


def test_object_with_model_dump_and_dict_and_dunder_dict(monkeypatch):
    """``object`` serializes through ``model_dump()``, ``dict()`` or public ``__dict__`` entries.

    Each object is expected to produce exactly one JSON string passed to ``text``, in call order.
    """
    import json

    class ModelDump:
        def model_dump(self) -> dict[str, int]:
            return {"x": 1}

    class DictModel:
        def dict(self) -> dict[str, int]:
            return {"y": 2}

    class DunderDict:
        def __init__(self) -> None:
            self.public = 3
            self._private = 4

    r, _ = make_ray(monkeypatch)

    captured: list[str] = []

    def fake_text(self: Ray, text: str):  # type: ignore[return-type]
        captured.append(text)
        return self

    monkeypatch.setattr(Ray, "text", fake_text)

    r.object(ModelDump())
    r.object(DictModel())
    r.object(DunderDict())

    decoded = [json.loads(t) for t in captured]
    # One exact comparison pins the number of ``text`` calls, their order and their contents.
    # For DunderDict only the public attribute may appear; ``_private`` must be excluded.
    assert decoded == [{"x": 1}, {"y": 2}, {"public": 3}]


def test_object_fallback_uses_repr(monkeypatch):
    """An object with no serialization hook falls back to its ``repr`` via ``send``."""

    class NoSpecial:
        pass

    ray_instance, _ = make_ray(monkeypatch)
    recorded: dict[str, tuple] = {}

    def capture_send(self: Ray, *values):  # type: ignore[return-type]
        recorded["values"] = values
        return self

    monkeypatch.setattr(Ray, "send", capture_send)

    ray_instance.object(NoSpecial())

    # Exactly one argument is expected: a string produced from repr(value).
    (sent_value,) = recorded["values"]
    assert isinstance(sent_value, str)
    assert "NoSpecial" in sent_value
