import json
from collections import OrderedDict
from typing import Union

from requests import Response, Session
from requests_toolbelt.multipart.decoder import MultipartDecoder
from urllib3.fields import RequestField
from urllib3.filepost import encode_multipart_formdata

"""
Defines classes to simplify usage of the documents REST endpoint defined at
https://docs.marklogic.com/REST/client/management. 
"""


class Metadata:
    """
    Container for the five kinds of metadata that MarkLogic associates with a
    document. It is used both for the metadata of a single document and for default
    metadata applied when writing many documents. Permissions are expressed as a
    plain dictionary rather than the array of dictionaries used by the REST API,
    which makes them simpler to define.

    :param collections: array of collection URIs.
    :param permissions: dict with keys of role names and values of arrays of
    capabilities such as "read", "update", and "execute".
    :param quality: document quality, used for scoring in searches.
    :param metadata_values: dict with string keys and string values.
    :param properties: dict with string keys and values of any type.
    """

    def __init__(
        self,
        collections: list[str] = None,
        permissions: dict = None,
        quality: int = None,
        metadata_values: dict = None,
        properties: dict = None,
    ):
        # Values are stored exactly as given; None means "not specified". Targets
        # are assigned left to right, so the attribute order is collections,
        # permissions, quality, metadata_values, properties.
        (
            self.collections,
            self.permissions,
            self.quality,
            self.metadata_values,
            self.properties,
        ) = (collections, permissions, quality, metadata_values, properties)


def metadata_to_dict(metadata: Metadata) -> dict:
    """
    Returns a dictionary with a structure matching what the /v1/documents endpoint
    requires.

    Only metadata that has actually been specified is included: empty or None
    permissions, collections, properties, and metadata values are omitted, and
    quality is omitted only when it is None.

    :param metadata: the object whose permissions, collections, quality,
    properties, and metadata_values attributes are read.
    :return: a dict using the endpoint's key names ("permissions", "collections",
    "quality", "properties", "metadataValues"); empty if nothing was specified.
    """
    md = {}
    if metadata.permissions:
        # The endpoint expects an array of role/capabilities objects rather than
        # the role -> capabilities dict that this library exposes.
        md["permissions"] = [
            {"role-name": role, "capabilities": capabilities}
            for role, capabilities in metadata.permissions.items()
        ]
    if metadata.collections:
        md["collections"] = metadata.collections
    # Compare against None rather than relying on truthiness so that an explicit
    # quality of 0 is sent instead of being silently dropped.
    if metadata.quality is not None:
        md["quality"] = metadata.quality
    if metadata.properties:
        md["properties"] = metadata.properties
    if metadata.metadata_values:
        md["metadataValues"] = metadata.metadata_values
    return md


def dict_to_metadata(metadata: dict, target_metadata: Metadata) -> None:
    """
    Populates the given Metadata instance based on the metadata dictionary as returned
    by the /v1/documents REST endpoint.

    Every metadata attribute of the target is overwritten; a category missing from
    the dictionary results in the corresponding attribute being set to None.

    :param metadata: dict using the endpoint's key names ("collections",
    "quality", "metadataValues", "properties", "permissions").
    :param target_metadata: the object to populate; it is modified in place.
    """
    target_metadata.collections = metadata.get("collections")
    target_metadata.quality = metadata.get("quality")
    target_metadata.metadata_values = metadata.get("metadataValues")
    target_metadata.properties = metadata.get("properties")

    permissions = metadata.get("permissions")
    if permissions:
        # Convert the endpoint's array of role/capabilities objects into the
        # role -> capabilities dict used by Metadata.
        target_metadata.permissions = {
            perm["role-name"]: perm["capabilities"] for perm in permissions
        }
    else:
        # Reset the real attribute so stale permissions are not left behind.
        target_metadata.permissions = None

class Document(Metadata):
    """
    Represents a document, either as read from MarkLogic or as a document to be
    written to MarkLogic.
    """

    def __init__(
        self,
        uri: str = None,
        content=None,
        collections: list[str] = None,
        permissions: dict = None,
        quality: int = None,
        metadata_values: dict = None,
        properties: dict = None,
        content_type: str = None,
        version_id: str = None,
        extension: str = None,
        directory: str = None,
        repair: str = None,
        extract: str = None,
        temporal_document: str = None,
    ):
        """
        :param uri: the URI of the document; can be None when relying on MarkLogic to
        generate a URI.
        :param content: the content of the document.
        :param collections: see definition in parent class.
        :param permissions: see definition in parent class.
        :param quality: see definition in parent class.
        :param metadata_values: see definition in parent class.
        :param properties: see definition in parent class.
        :param content_type: the MIME type of the document; use when MarkLogic cannot
        determine the MIME type based on the URI.
        :param version_id: affects updates when optimistic locking is enabled; see
        https://docs.marklogic.com/REST/POST/v1/documents for more information.
        :param extension: specifies a suffix for a URI generated by MarkLogic; only used
        when writing a document.
        :param directory: specifies a prefix for a URI generated by MarkLogic; only used
        when writing a document.
        :param repair: for an XML document, the level of XML repair to perform; can be
        "full" or "none", with "none" being the default; only used when writing a
        document.
        :param extract: sent as the "extract" parameter of the content disposition;
        see https://docs.marklogic.com/REST/POST/v1/documents for more information;
        only used when writing a document.
        :param temporal_document: the logical document URI for a document written to a
        temporal collection; requires that a "temporal-collection" parameter be
        included in the request; only used when writing a document.
        """
        super().__init__(collections, permissions, quality, metadata_values, properties)
        self.uri = uri
        self.content = content
        self.content_type = content_type
        self.version_id = version_id

        # The following are all specific to writing a document.
        self.extension = extension
        self.directory = directory
        self.repair = repair
        self.extract = extract
        self.temporal_document = temporal_document

    @property
    def metadata(self):
        """
        Returns a dict containing the 5 attributes that comprise the metadata of a
        document in MarkLogic. The keys are the attribute names of this class, not
        the key names used by the REST endpoint.
        """
        return {
            "permissions": self.permissions,
            "collections": self.collections,
            "quality": self.quality,
            "metadata_values": self.metadata_values,
            "properties": self.properties,
        }

    def __repr__(self):
        # Print all instance attributes for easy inspection.
        return repr(self.__dict__)

    def to_request_field(self) -> Union[RequestField, None]:
        """
        Returns a multipart request field representing the document to be written,
        or None if this document has no content.
        """
        if self.content is None:
            return None
        data = self.content
        # A dict is serialized to JSON; isinstance is used so that dict subclasses
        # such as OrderedDict are serialized as well. Anything else is sent as-is.
        if isinstance(data, dict):
            data = json.dumps(data)
        field = RequestField(name=self.uri, data=data, filename=self.uri)
        field.make_multipart(
            content_disposition=self._make_content_disposition(),
            content_type=self.content_type,
        )
        return field

    def to_metadata_request_field(self) -> Union[RequestField, None]:
        """
        Returns a multipart request field if any metadata has been set on this
        document; returns None otherwise.
        """
        metadata = metadata_to_dict(self)
        if not metadata:
            return None

        field = RequestField(
            name=self.uri, data=json.dumps(metadata), filename=self.uri
        )
        field.make_multipart(
            content_disposition=f"attachment; filename={self.uri}; category=metadata",
            content_type="application/json",
        )
        return field

    def _make_content_disposition(self) -> str:
        """
        Returns a content disposition suitable for use when writing documents via
        https://docs.marklogic.com/REST/POST/v1/documents . See that page for more
        information on each part of the disposition.
        """
        # With a URI the part is an "attachment"; without one it is "inline" and
        # the extension/directory values are included for the generated URI.
        if self.uri:
            parts = ["attachment"]
        else:
            parts = ["inline"]
            if self.extension:
                parts.append(f"extension={self.extension}")
            if self.directory:
                parts.append(f"directory={self.directory}")

        if self.repair:
            parts.append(f"repair={self.repair}")

        if self.extract:
            parts.append(f"extract={self.extract}")

        if self.version_id:
            parts.append(f"versionId={self.version_id}")

        if self.temporal_document:
            parts.append(f"temporal-document={self.temporal_document}")

        return ";".join(parts)


class DefaultMetadata(Metadata):
    """
    Metadata that is not tied to a single document; it is used as the default
    metadata when many documents are written in one request.
    """

    def __init__(
        self,
        collections: list[str] = None,
        permissions: dict = None,
        quality: int = None,
        metadata_values: dict = None,
        properties: dict = None,
    ):
        super().__init__(
            collections=collections,
            permissions=permissions,
            quality=quality,
            metadata_values=metadata_values,
            properties=properties,
        )

    def to_metadata_request_field(self) -> RequestField:
        """
        Builds the multipart request field carrying this default metadata for a
        request that writes many documents. None is returned when no metadata
        values have been set.
        """
        payload = metadata_to_dict(self)
        if not payload:
            return None

        # Default metadata has no URI, so the part has neither a name nor a filename.
        request_field = RequestField(
            name=None, data=json.dumps(payload), filename=None
        )
        request_field.make_multipart(
            content_disposition="inline; category=metadata",
            content_type="application/json",
        )
        return request_field


def _extract_values_from_header(part) -> dict:
    """
    Returns a dict containing values about the document content or metadata.
    """
    encoding = part.encoding
    disposition = part.headers["Content-Disposition".encode(encoding)].decode(encoding)
    disposition_values = {}
    for item in disposition.split(";"):
        tokens = item.split("=")
        # The first item will be "attachment" and can be ignored.
        if len(tokens) == 2:
            disposition_values[tokens[0].strip()] = tokens[1]

    content_type = None
    if part.headers.get("Content-Type".encode(encoding)):
        content_type = part.headers["Content-Type".encode(encoding)].decode(encoding)

    uri = disposition_values["filename"]
    if uri.startswith('"'):
        uri = uri[1:]
    if uri.endswith('"'):
        uri = uri[:-1]

    return {
        "uri": uri,
        "category": disposition_values["category"],
        "content_type": content_type,
        "version_id": disposition_values.get("versionId"),
    }


def multipart_response_to_documents(response: Response) -> list[Document]:
    """
    Converts a multipart response into a list of Documents, with one Document per
    URI found across the parts. The response is expected to have the structure
    described at https://docs.marklogic.com/REST/GET/v1/documents for requests made
    with an Accept header of "multipart/mixed".

    Content parts and metadata parts that share a URI are merged into the same
    Document, and Documents are returned in the order their URIs first appear.
    """
    documents_by_uri = OrderedDict()

    for part in MultipartDecoder.from_response(response).parts:
        info = _extract_values_from_header(part)
        uri = info["uri"]
        existing = documents_by_uri.get(uri)

        if info["category"] != "content":
            # Any category other than content is handled as a metadata part.
            document = existing if existing else Document(uri, None)
            documents_by_uri[uri] = document
            dict_to_metadata(json.loads(part.content), document)
            continue

        content_type = info.get("content_type")
        body = part.content
        if content_type == "application/json":
            body = json.loads(body)
        elif content_type in ["application/xml", "text/xml", "text/plain"]:
            body = body.decode(part.encoding)
        # Any other content type is left as the raw part content.

        version_id = info.get("version_id")
        if existing:
            # A metadata part for this URI was already seen; add the content to it.
            existing.content = body
            existing.content_type = content_type
            existing.version_id = version_id
        else:
            documents_by_uri[uri] = Document(
                uri, body, content_type=content_type, version_id=version_id
            )

    return list(documents_by_uri.values())


class DocumentManager:
    """
    Provides methods to simplify interacting with REST endpoints that either accept
    or return documents. Primarily involves endpoints defined at
    https://docs.marklogic.com/REST/client/management , but also includes support for
    the search endpoint at https://docs.marklogic.com/REST/POST/v1/search which can
    return documents as well.
    """

    def __init__(self, session: Session):
        """
        :param session: the session used to send every request; it is called with
        paths such as "/v1/documents" rather than full URLs.
        """
        self._session = session

    @staticmethod
    def _copy_mapping(mapping):
        """
        Returns a shallow copy of the given headers/params mapping, or an empty dict
        when none was given, so that the caller's object is never modified.
        """
        if not mapping:
            return {}
        return mapping.copy() if hasattr(mapping, "copy") else dict(mapping)

    def write(
        self, parts: Union[Document, list[Union[DefaultMetadata, Document]]], **kwargs
    ) -> Response:
        """
        Write one or many documents at a time via a POST to the endpoint defined at
        https://docs.marklogic.com/REST/POST/v1/documents .

        :param parts: a part can define either a document to be written, which can
        include metadata, or a set of default metadata to be applied to each document
        after it that does not define its own metadata. See
        https://docs.marklogic.com/guide/rest-dev/bulk#id_16015 for more information on
        how the REST endpoint uses metadata.
        :param kwargs: passed through to the session's post method. Any "headers"
        given are copied and sent with the multipart "Content-Type" set, and with
        "Accept" defaulting to "application/json".
        :return: the response from the endpoint.
        """
        if isinstance(parts, Document):
            parts = [parts]

        fields = []
        for part in parts:
            # Both kinds of part may yield no metadata field at all; a None must
            # never reach the multipart encoder.
            metadata_field = part.to_metadata_request_field()
            if metadata_field is not None:
                fields.append(metadata_field)
            if isinstance(part, DefaultMetadata):
                continue
            content_field = part.to_request_field()
            if content_field is not None:
                fields.append(content_field)

        data, content_type = encode_multipart_formdata(fields)

        headers = self._copy_mapping(kwargs.pop("headers", None))
        # Keep the generated boundary but replace the media type with
        # "multipart/mixed".
        headers["Content-Type"] = "".join(
            ("multipart/mixed",) + content_type.partition(";")[1:]
        )
        if not headers.get("Accept"):
            headers["Accept"] = "application/json"

        return self._session.post("/v1/documents", data=data, headers=headers, **kwargs)

    def read(
        self, uris: Union[str, list[str]], categories: list[str] = None, **kwargs
    ) -> Union[list[Document], Response]:
        """
        Read one or many documents via a GET to the endpoint defined at
        https://docs.marklogic.com/REST/GET/v1/documents . If a 200 is not returned
        by that endpoint, then the Response is returned instead.

        :param uris: list of URIs or a single URI to read.
        :param categories: optional list of the categories of data to return for each
        URI. By default, only content will be returned for each URI. See the endpoint
        documentation for further information.
        :param kwargs: passed through to the session's get method. Any "params" and
        "headers" given are copied before the values required by this method are
        added to them.
        """
        params = self._copy_mapping(kwargs.pop("params", None))
        params["uri"] = uris if isinstance(uris, list) else [uris]
        params["format"] = "json"  # This refers to the metadata format.
        if categories:
            params["category"] = categories

        headers = self._copy_mapping(kwargs.pop("headers", None))
        headers["Accept"] = "multipart/mixed"
        response = self._session.get(
            "/v1/documents", params=params, headers=headers, **kwargs
        )

        return (
            multipart_response_to_documents(response)
            if response.status_code == 200
            else response
        )

    def search(
        self,
        q: str = None,
        query: Union[dict, str] = None,
        categories: list[str] = None,
        start: int = None,
        page_length: int = None,
        options: str = None,
        collections: list[str] = None,
        **kwargs,
    ) -> Union[list[Document], Response]:
        """
        Leverages the support in the search endpoint defined at
        https://docs.marklogic.com/REST/POST/v1/search for returning a list of
        documents instead of a search response. Parameters that are commonly used for
        that endpoint are included as arguments to this method for ease of use. If a
        200 is not returned by that endpoint, then the Response is returned instead.

        :param q: optional search string.
        :param query: JSON or XML query matching one of the types supported by the
        search endpoint. The "Content-type" header will be set based on whether this
        is a dict, a string of JSON, or a string of XML.
        :param categories: optional list of the categories of data to return for each
        URI. By default, only content will be returned for each URI. See the endpoint
        documentation for further information.
        :param start: index of the first result to return.
        :param page_length: maximum number of documents to return.
        :param options: name of a query options instance to use.
        :param collections: restrict results to documents in these collections.
        :param kwargs: passed through to the session's post method. Any "params" and
        "headers" given are copied before the values required by this method are
        added to them.
        """
        params = self._copy_mapping(kwargs.pop("params", None))
        params["format"] = "json"  # This refers to the metadata format.
        if categories:
            params["category"] = categories
        if collections:
            params["collection"] = collections
        if q:
            params["q"] = q
        if start:
            params["start"] = start
        if page_length:
            params["pageLength"] = page_length
        if options:
            params["options"] = options

        headers = self._copy_mapping(kwargs.pop("headers", None))
        headers["Accept"] = "multipart/mixed"

        request_args = {"headers": headers, "params": params}
        if query:
            if isinstance(query, dict):
                request_args["data"] = json.dumps(query)
                headers["Content-type"] = "application/json"
            else:
                request_args["data"] = query
                # A string that does not parse as JSON is assumed to be XML.
                try:
                    json.loads(query)
                except (TypeError, ValueError):
                    headers["Content-type"] = "application/xml"
                else:
                    headers["Content-type"] = "application/json"

        # "data" is only included when a query was given, so a "data" value
        # supplied via kwargs is still passed through when there is no query.
        response = self._session.post("/v1/search", **request_args, **kwargs)

        return (
            multipart_response_to_documents(response)
            if response.status_code == 200
            else response
        )
