#!/usr/bin/env python
"""
Convert strings to LaTeX strings in math environment used by matplotlib's
usetex

This module was written by Matthias Cuntz while at Department of Computational
Hydrosystems, Helmholtz Centre for Environmental Research - UFZ, Leipzig,
Germany, and continued while at Institut National de Recherche pour
l'Agriculture, l'Alimentation et l'Environnement (INRAE), Nancy, France.

:copyright: Copyright 2015- Matthias Cuntz, see AUTHORS.rst for details.
:license: MIT License, see LICENSE for details.

.. moduleauthor:: Matthias Cuntz

The following functions are provided:

.. autosummary::
   str2tex

History
   * Written Oct 2015 by Matthias Cuntz (mc (at) macu (dot) de)
   * Use raw strings for escaped characters, Nov 2021, Matthias Cuntz
   * Bug in space2linebreak in complex strings with spaces;
     do space2linebreak first, Nov 2021, Matthias Cuntz
   * Bug in escaping %, Nov 2021, Matthias Cuntz
   * Remove trailing $\\mathrm{}$, Nov 2021, Matthias Cuntz
   * Ported into pyjams, Nov 2021, Matthias Cuntz
   * Better handling of linebreaks in Matplotlib and LaTeX mode,
     Nov 2021, Matthias Cuntz
   * More consistent docstrings, Jan 2022, Matthias Cuntz
   * Use input2array, array2input, Jun 2022, Matthias Cuntz
   * Do not escape % if not usetex, Apr 2023, Matthias Cuntz
   * Add unicode symbol degree \u00B0, which gets replaced by ^\\circ
     if usetex==True, Apr 2023, Matthias Cuntz
   * flake8 compliant, Oct 2024, Matthias Cuntz

"""
from .helper import input2array, array2input


__all__ = ['str2tex']


def str2tex(strin, space2linebreak=False,
            bold=False, italic=False, usetex=False):
    """
    Convert strings to LaTeX strings in math environment used by matplotlib's
    usetex

    If `usetex` is True, text outside of '$...$' is embedded into
    '$\\mathrm{...}$' by default but can be embedded into '\\mathbf' and
    '\\mathit'. In this text, hyphens are set in text mode ('\\textrm',
    '\\textbf', '\\textit'), the characters '_', '^', '#' and '%' are
    escaped, the degree sign (U+00B0) is replaced by '^\\circ{}', and
    the two-character sequence '\\n' as well as '\\newline' are replaced
    by '\\newline' set outside of the math environment.
    Spaces are escaped but can be replaced by linebreaks.
    Empty '$\\mathrm{}$' groups produced along the way are removed unless
    the input string already contained such a group.

    If `usetex` is False, '\\%' is unescaped to '%', and the two-character
    sequence '\\n' as well as '\\newline' are replaced by newline characters.

    Parameters
    ----------
    strin : str or array-like of str
        string (array)
    space2linebreak : bool, optional
        Replace space (' ') by linebreak ('\\n') if True (default: False)
    bold : bool, optional
        Use '\\mathbf' instead of '\\mathrm' if True (default: False)
    italic : bool, optional
        Use '\\mathit' instead of '\\mathrm' if True (default: False)
    usetex : bool, optional
        Treat only linebreaks and escaped percent signs if False (default)

    Returns
    -------
    string : str
        string (array) that can be used in matplotlib independent of usetex.
        The output is passed through `array2input` together with `strin`,
        presumably so that it has the same type as the input.

    Raises
    ------
    ValueError
        If both `bold` and `italic` are True.

    Examples
    --------
    .. code-block:: python

       fig = plt.figure()
       tit = str2tex('A $S_{Ti}$ is great\\nbut use-less', usetex=usetex)
       fig.suptitle(tit)

    """
    import matplotlib.pyplot as plt

    # Input type and shape
    # use list because numpy array cannot change to raw strings
    istrin = list(input2array(strin, undef='', default=''))

    # font style
    if (bold + italic) > 1:
        raise ValueError('bold and italic are mutually exclusive.')
    if bold:
        font = 'bf'
    elif italic:
        font = 'it'
    else:
        font = 'rm'
    mtex = r'$\math' + font + '{'  # opens math environment with font
    ttex = r'$\text' + font + '{'  # same font in text mode, used for hyphens
    empty = mtex + '}$'            # empty group, e.g. $\mathrm{}$

    # helpers
    a0 = chr(0)  # ascii 0, put around linebreaks and removed at the end
    # no '\n' in LaTeX, use '\newline' outside of the math environment
    tex_break = '}$' + a0 + r'\newline' + a0 + mtex
    if usetex:
        linebreak = tex_break
    else:
        # '\n' has to be unicode string and not raw string in Matplotlib
        linebreak = a0 + '\n' + a0

    def _linebreaks(text):
        # Replace raw '\newline' if present, otherwise raw '\n'.
        # '\newline' has to be checked first because it contains '\n'.
        if r'\newline' in text:
            return text.replace(r'\newline', linebreak)
        return text.replace(r'\n', linebreak)

    def _wrap_text(text):
        # Embed text that is outside of $...$ into the math font and escape
        # all characters that are special in LaTeX. The order matters.
        out = mtex + text + '}$'
        # - not minus sign
        if '-' in out:
            out = out.replace('-', '}$' + ttex + '-}$' + mtex)
            # rm trailing $\mathrm{}$ left if text ended with a hyphen;
            # check for the whole group so that text ending in '{' is kept
            if out.endswith(empty):
                out = out[:-len(empty)]
        # \n not in tex mode but normal matplotlib
        out = _linebreaks(out)
        # escape _, ^, #
        out = out.replace('_', r'\_')
        out = out.replace('^', r'\^')
        out = out.replace('#', r'\#')
        # escape % only if there is no escaped % yet
        if r'\%' not in out:
            out = out.replace('%', r'\%')
        # replace unicode degree symbol (after escape of ^)
        out = out.replace('\u00B0', r'^\circ{}')
        if space2linebreak:
            out = out.replace(' ', tex_break)
        return out

    if usetex:
        for j, s in enumerate(istrin):
            if '$' in s:
                # keep empty groups if the user gave one on purpose
                cleanempty = empty not in s
                parts = s.split('$')
                # even indexes are outside $...$
                # -, _, ^, etc. only escaped if not between $
                parts[::2] = [_wrap_text(p) for p in parts[::2]]
                # reassemble string
                out = '$'.join(parts)
                if s[0] == '$':
                    # rm leading $\mathrm{}$ if string started with $
                    out = out[len(empty):]
            else:
                cleanempty = True
                out = _wrap_text(s)

            # rm $\mathrm{}$
            if cleanempty:
                out = out.replace(empty, '')
            # escape space
            out = out.replace(' ', r'\ ')
            # rm ascii character 0 around linebreaks introduced above
            istrin[j] = out.replace(a0, '')
    else:
        pdf_backend = (plt.get_backend() == 'pdf')
        # NOTE(review): with the pdf backend the LaTeX linebreak including
        # '}$' and the math font is inserted although the string is not
        # embedded in a math environment here - confirm that this is intended
        space_break = tex_break if pdf_backend else '\n'
        for j, s in enumerate(istrin):
            # do not escape %
            out = s.replace(r'\%', '%')
            # '\n' is Matplotlib but not LaTeX
            out = _linebreaks(out)
            if space2linebreak:
                out = out.replace(' ', space_break)
            # rm ascii character 0 around linebreaks; done last so that
            # the markers of space_break are removed as well
            out = out.replace(a0, '')
            if pdf_backend:  # pragma: no cover
                # pdf backend uses LaTeX
                if (r'\n' in out) and (r'\newline' not in out):
                    out = out.replace(r'\n', r'\newline')
            istrin[j] = out

    # Return right type
    out = array2input(istrin, strin, undef='')
    return out


if __name__ == '__main__':
    # Run the doctests of this module when it is executed as a script,
    # ignoring differences in whitespace.
    from doctest import NORMALIZE_WHITESPACE, testmod

    testmod(optionflags=NORMALIZE_WHITESPACE)
