import pandas as pd
import pytest

from polytope_feature.polytope import Polytope, Request
from polytope_feature.shapes import Box, Select, Span

# import geopandas as gpd
# import matplotlib.pyplot as plt


class TestSlicingFDBDatacube:
    """Slicing tests that run Polytope against a GribJump-backed datacube.

    Both tests carry the ``fdb`` marker and import ``pygribjump`` lazily, so
    importing this module does not by itself require that package.
    """

    # Absolute tolerance for coordinate comparisons. The expected coordinates
    # below are written with 12 decimal places, so this is tight enough to tell
    # neighbouring grid points apart while not depending on bit-exact floats.
    COORD_ABS_TOL = 1e-9

    def setup_method(self, method):
        """Build the Polytope options shared by every test in this class.

        ``method`` is supplied by pytest and is not used here.
        """
        self.options = {
            "axis_config": [
                {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]},
                {"axis_name": "number", "transformations": [{"name": "type_change", "type": "int"}]},
                {
                    "axis_name": "date",
                    "transformations": [{"name": "merge", "other_axis": "time", "linkers": ["T", "00"]}],
                },
                {
                    "axis_name": "values",
                    "transformations": [
                        {"name": "mapper", "type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]}
                    ],
                },
                {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]},
                {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]},
            ],
            "compressed_axes_config": [
                "longitude",
                "latitude",
                "levtype",
                "step",
                "date",
                "domain",
                "expver",
                "param",
                "class",
                "stream",
                "type",
            ],
            "pre_path": {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper"},
        }

    def _common_selectors(self):
        """Return the non-spatial ``Select`` shapes used by both tests, in request order."""
        return [
            Select("step", [0]),
            Select("levtype", ["sfc"]),
            Select("date", [pd.Timestamp("20230625T120000")]),
            Select("domain", ["g"]),
            Select("expver", ["0001"]),
            Select("param", ["167"]),
            Select("class", ["od"]),
            Select("stream", ["oper"]),
            Select("type", ["an"]),
        ]

    def _retrieve(self, request):
        """Run ``request`` through a Polytope built on a fresh GribJump datacube.

        The datacube and the Polytope instance are kept on ``self`` (as
        ``fdbdatacube`` and ``API``) so they remain inspectable after the call.
        The retrieved result is pretty-printed and returned.
        """
        # Imported here rather than at module level so that only the
        # ``fdb``-marked tests need pygribjump to be installed.
        import pygribjump as gj

        self.fdbdatacube = gj.GribJump()
        self.API = Polytope(
            datacube=self.fdbdatacube,
            options=self.options,
        )
        result = self.API.retrieve(request)
        result.pprint()
        return result

    @pytest.mark.fdb
    def test_fdb_datacube(self):
        """Retrieve a 0.2 x 0.2 degree latitude/longitude box anchored at (0, 0)."""
        request = Request(
            *self._common_selectors(),
            Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]),
        )
        result = self._retrieve(request)

        leaves = result.leaves
        assert len(leaves) == 3
        assert len(leaves[0].result) == 3

        # Coordinates are computed floats: compare with a tolerance instead of
        # relying on bit-exact equality.
        expected_second_longitudes = (0.070093457944, 0.070148090413, 0.070202808112)
        for leaf, expected in zip(leaves, expected_second_longitudes):
            assert leaf.flatten()["longitude"][1] == pytest.approx(expected, abs=self.COORD_ABS_TOL)

        assert leaves[0].result[0] == pytest.approx(297.9250183105469, rel=1e-6)

        # NOTE: for visual debugging, the latitude/longitude of every leaf
        # (available from ``leaf.flatten()``) can be scatter-plotted, e.g. with
        # matplotlib.

    @pytest.mark.fdb
    def test_fdb_datacube_select_grid(self):
        """Retrieve a single latitude row restricted to a short longitude span."""
        request = Request(
            *self._common_selectors(),
            Select("latitude", [0.035149384216]),
            Span("longitude", 0, 0.070093457944),
        )
        result = self._retrieve(request)

        leaves = result.leaves
        assert len(leaves) == 1
        assert len(leaves[0].result) == 2
        # Tolerant comparison of the whole longitude sequence (length is checked too).
        assert leaves[0].flatten()["longitude"] == pytest.approx(
            (0.0, 0.070093457944), abs=self.COORD_ABS_TOL
        )
