"""Copyright (c) 2023 Nathaniel Starkman. All rights reserved.

array-api-jax-compat: Array-API JAX compatibility
"""

# pylint: disable=redefined-builtin


from __future__ import annotations

from typing import Any

import plum
from jax.experimental.array_api import __array_api_version__
from jaxtyping import ArrayLike, install_import_hook

with install_import_hook("array_api_jax_compat", None):
    from . import (
        _constants,
        _creation_functions,
        _data_type_functions,
        _elementwise_functions,
        _indexing_functions,
        _linear_algebra_functions,
        _manipulation_functions,
        _searching_functions,
        _set_functions,
        _sorting_functions,
        _statistical_functions,
        _utility_functions,
        fft,
        linalg,
    )
    from ._constants import *
    from ._creation_functions import *
    from ._data_type_functions import *
    from ._elementwise_functions import *
    from ._indexing_functions import *
    from ._linear_algebra_functions import *
    from ._manipulation_functions import *
    from ._searching_functions import *
    from ._set_functions import *
    from ._sorting_functions import *
    from ._statistical_functions import *
    from ._utility_functions import *
    from ._version import version as __version__

# Public API: package metadata and the ``fft`` / ``linalg`` submodules first,
# then every name re-exported by the private implementation modules, in the
# same order they are star-imported above.
__all__ = [
    "__version__",
    "__array_api_version__",
    "fft",
    "linalg",
    *_constants.__all__,
    *_creation_functions.__all__,
    *_data_type_functions.__all__,
    *_elementwise_functions.__all__,
    *_indexing_functions.__all__,
    *_linear_algebra_functions.__all__,
    *_manipulation_functions.__all__,
    *_searching_functions.__all__,
    *_set_functions.__all__,
    *_sorting_functions.__all__,
    *_statistical_functions.__all__,
    *_utility_functions.__all__,
]


# Simplify the display of ArrayLike.
# ``ArrayLike`` (from jaxtyping) is a type union; registering an alias for it
# with plum lets plum refer to it as "ArrayLike" instead of spelling out
# every member of the union. Union aliasing is switched on first, then the
# alias is registered. Both calls mutate plum's global state at import time.
plum.activate_union_aliases()
plum.set_union_alias(ArrayLike, "ArrayLike")


def __getattr__(name: str) -> Any:  # TODO: fuller annotation
    """Forward all other attribute accesses to Quaxified JAX."""
    import jax  # pylint: disable=C0415,W0621
    from quax import quaxify  # pylint: disable=C0415,W0621

    # TODO: detect if the attribute is a function or a module.
    # If it is a function, quaxify it. If it is a module, return a proxy object
    # that quaxifies all of its attributes.
    return quaxify(getattr(jax, name))
