"""Simple codegen for typetree classes

- Supports nested types
- Supports inheritance
- Automatically resolves import order and dependencies
- Generated classes support nested types, and will be initialized with the correct types even with nested dicts

NOTE:
- Cannot resolve namespace conflicts if the same class name is defined in multiple namespaces
- Fields whose type cannot be resolved are annotated as `object` and marked with `# XXX: Fallback of {org_type}`
- Circular inheritance is checked and raises RecursionError
- The output order (imports, classes) is deterministic for a given input: imports are sorted lexicographically, and classes are emitted in dependency (topological) order
- The output mirrors the namespace structure of the TypeTree dump: each namespace becomes a Python package (a directory containing an `__init__.py`)

USAGE:

    python typetree_codegen.py --json <typetree_dump.json> [--outdir <output_dir>]
    python typetree_codegen.py --files <game_assembly_dir> [--outdir <output_dir>]
"""

# TypeTree / C# primitive type name -> Python annotation used by the generated code.
# Adapted from https://github.com/K0lb3/UnityPy/blob/master/generators/ClassesGenerator.py
# (the C#-style names and the UnityEngine value types are extras on top of it).
BASE_TYPE_MAP = {
    # Textual types
    **dict.fromkeys(("char", "string", "String"), "str"),
    # Integral types: C-style names, Unity's sized aliases, Type*/FileSize and C# names
    **dict.fromkeys(
        (
            "short",
            "int",
            "long long",
            "unsigned short",
            "unsigned int",
            "unsigned long long",
            "UInt8",
            "UInt16",
            "UInt32",
            "UInt64",
            "SInt8",
            "SInt16",
            "SInt32",
            "SInt64",
            "Type*",
            "FileSize",
            "Byte",
            "Int32",
        ),
        "int",
    ),
    # Floating point types
    **dict.fromkeys(("float", "double", "Single"), "float"),
    "bool": "bool",
    # Raw binary blobs
    **dict.fromkeys(("TypelessData", "Byte[]"), "bytes"),
    # UnityEngine value types -> classes imported from UnityPy.classes.math by the generated code
    "Color": "ColorRGBA",
    "Vector2": "Vector2f",
    "Vector3": "Vector3f",
    "Vector4": "Vector4f",
    "Quaternion": "Quaternionf",
}
# XXX: Can't use attrs here since subclassing MonoBehavior and such - though defined by the typetree dump
# seem to be only valid if the class isn't a property of another class
# In which case the MonoBehavior attributes are inherited by the parent class and does not
# initialize the property class
# XXX: Need some boilerplate to handle this
HEADER = "\n".join(
    [
        "# fmt: off",
        "# Auto-generated by https://github.com/mos9527/UnityPyTypetreeCodegen",
        "" "from typing import List, Union, Optional, TypeVar",
        "from UnityPy.files.ObjectReader import ObjectReader",
        "from UnityPy.classes import *",
        "from UnityPy.classes.math import (ColorRGBA, Matrix3x4f, Matrix4x4f, Quaternionf, Vector2f, Vector3f, Vector4f, float3, float4,)",
        '''T = TypeVar("T")
UTTCG_Classes = dict()
def UTTCGen(fullname: str, typetree: dict):
    """dataclass-like decorator for typetree classess with nested type support
    
    limitations:
    - the behavior is similar to slotted dataclasses where shared attributes are inherited
      but allows ommiting init of the parent if kwargs are not sufficient
    - generally supports nested types, however untested and could be slow	
    - and ofc, zero type checking and safeguards :/	
    """    
    REFERENCED_ARGS = {'object_reader'}
    def __inner(clazz: T) -> T:
        # Allow these to be propogated to the props
        def __init__(self, **d):        
            def reduce_init(clazz, **d):
                types : dict = clazz.__annotations__
                for k, sub in types.items():
                    if type(sub) == str:
                        sub = eval(sub) # attrs turns these into strings...why?
                    while sub.__name__ == "Optional":
                        sub = sub.__args__[0]  # Reduce Optional[T] -> T
                    reduce_arg = getattr(sub, "__args__", [None])[0]
                    if k in REFERENCED_ARGS: # Directly refcounted
                        reduce_arg = sub = lambda x: x                         
                    if isinstance(d[k], list):
                        if hasattr(reduce_arg, "__annotations__"):
                            setattr(self, k, [reduce_arg(**x) for x in d[k]])
                        else:
                            setattr(self, k, [reduce_arg(x) for x in d[k]])
                    elif isinstance(d[k], dict) and hasattr(sub, "__annotations__"):
                        setattr(self, k, sub(**d[k]))
                    else:
                        if isinstance(d[k], dict):
                            setattr(self, k, sub(**d[k]))
                        else:
                            setattr(self, k, sub(d[k]))
            def reduce_base(clazz, **d):	
                for __base__ in clazz.__bases__:
                    if hasattr(__base__, "__annotations__"):
                        types : dict = __base__.__annotations__
                        args = {k:d[k] for k in types if k in d}
                        if len(args) == len(types):
                            super(clazz, self).__init__(**args)
                            reduce_init(__base__, **d)                       
                    reduce_base(__base__, **d)
            reduce_base(clazz, **d)               
            reduce_init(clazz, **d)            
        def __repr__(self) -> str:
            return f"{clazz.__name__}({', '.join([f'{k}={getattr(self, k)!r}' for k in self.__annotations__])})"
        def __save(self):
            self.object_reader.save_typetree(self, self.__typetree__)
        clazz.__init__ = __init__
        clazz.__repr__ = __repr__
        clazz.__typetree__ = typetree
        clazz.save = __save
        UTTCG_Classes[fullname] = clazz
        return clazz
    return __inner

# Helper functions

def UTTCGen_AsInstance(src: MonoBehaviour | ObjectReader, fullname: str = None) -> T:
    """Instantiate a class from the typetree definition and the raw data.

    In most cases, this is the function you want to use.
    It will read the typetree data from the MonoBehaviour instance and instantiate the class with the data.

    Args:
        src (MonoBehaviour | ObjectReader): The MonoBehaviour instance or ObjectReader to read from.
        fullname (str): The full name of the class to read. If None, it will be read from the MonoBehaviour instance's `m_Script` property.

    Returns:
        T: An instance of the class defined by the typetree.
    """
    if not fullname and isinstance(src, MonoBehaviour):
        script = src.m_Script.read()
        fullname = script.m_ClassName
        if script.m_Namespace:
            fullname = f"{script.m_Namespace}.{fullname}"    
    clazz = UTTCG_Classes.get(fullname, None)
    assert clazz is not None, f"Class definition for {fullname} not found"    
    if isinstance(src, MonoBehaviour):
        src = src.object_reader
    raw_def = src.read_typetree(clazz.__typetree__, check_read=False)
    instance = clazz(object_reader=src, **raw_def)
    return instance
''',
    ]
)
from collections import defaultdict
import argparse, json


def translate_name(m_Name: str, **kwargs):
    """Turn a TypeTree field/class name into a valid Python identifier.

    A la https://github.com/K0lb3/UnityPy/blob/b811a2942297b5d8107e9a10249df80a87492282/UnityPy/helpers/TypeTreeNode.py#L361.
    With extra handling for templated/generic types and reserved keywords:

    - ``<>`` becomes ``__generic_`` (e.g. ``<>c`` -> ``__generic_c``)
    - any remaining ``<``, ``>``, ``|``, backtick or ``=`` becomes ``_``
    - Python keywords (``class``, ``None``, ``lambda``, ...) get a leading ``_``

    Args:
        m_Name: Raw name from the TypeTree dump.
        **kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        The sanitized name.
    """
    import keyword

    m_Name = m_Name.replace("<>", "__generic_")  # Generic templates
    for c in "<>|`=":  # Templated names and other characters invalid in identifiers
        m_Name = m_Name.replace(c, "_")
    # keyword.iskeyword covers every hard keyword of the running interpreter,
    # including ones that are legal C# identifiers (None, lambda, del, global,
    # nonlocal, async, await) and would otherwise break the generated code.
    # Soft keywords (match, case, type, _) are valid identifiers and are kept.
    if keyword.iskeyword(m_Name):
        m_Name = "_" + m_Name
    return m_Name


from UnityPy import classes as UnityBuiltin
from TypeTreeGeneratorAPI import TypeTreeNode
from logging import getLogger
from coloredlogs import install

import os, shutil
from typing import Dict, List

# Shared logger for every codegen pass
logger = getLogger(name="codegen")


def translate_type(
    m_Type: str, strip=False, fallback=True, typenames: dict = None, **kwargs
):
    """Map a TypeTree type name to the Python type expression used in annotations.

    Resolution order: ``BASE_TYPE_MAP``, attributes of ``UnityPy.classes``, then
    the keys of ``typenames``. ``T[]`` becomes ``List[T]`` and ``PPtr<T>`` becomes
    ``PPtr[T]`` (recursively). With ``strip=True`` the ``List``/``PPtr`` wrapper
    is dropped and only the inner type is returned.

    Args:
        m_Type: Raw TypeTree type name.
        strip: Return the element/pointee type instead of the wrapped form.
        fallback: If the type cannot be resolved, log a warning and return
            ``"object"``; otherwise return ``m_Type`` unchanged.
        typenames: Mapping whose keys are additionally known type names.
            Defaults to no extra names.
        **kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        The translated type expression as a string.
    """
    # Avoid a shared mutable default argument
    if typenames is None:
        typenames = {}
    if m_Type in BASE_TYPE_MAP:
        return BASE_TYPE_MAP[m_Type]
    if getattr(UnityBuiltin, m_Type, None):
        return m_Type
    if m_Type in typenames:
        return m_Type
    if m_Type.endswith("[]"):
        inner = translate_type(m_Type[:-2], strip, fallback, typenames)
        return inner if strip else f"List[{inner}]"
    if m_Type.startswith("PPtr<"):
        inner = translate_type(m_Type[5:-1], strip, fallback, typenames)
        return inner if strip else f"PPtr[{inner}]"
    if fallback:
        logger.warning("Unknown type %s, using fallback", m_Type)
        return "object"
    return m_Type


def declare_field(name: str, type: str, org_type: str = None):
    """Render one annotated class attribute (without indentation).

    ``name`` is sanitized with ``translate_name``. If the ``object`` fallback
    appears anywhere in ``type`` (``object``, ``List[object]``, ``PPtr[object]``,
    ``List[PPtr[object]]``, ...), a ``# XXX: Fallback of {org_type}`` marker is
    appended so unresolved types are easy to find in the output.

    Args:
        name: Raw field name from the TypeTree dump.
        type: Translated Python type expression.
        org_type: Original TypeTree type name, only used in the marker comment.

    Returns:
        The ``name : type`` declaration line.
    """
    name = translate_name(name)
    # Match "object" as a whole word so nested fallbacks are flagged, while
    # names that merely contain it (e.g. "object_id", "Object") are not.
    if re.search(r"\bobject\b", type) is None:
        return f"{name} : {type}"
    return f"{name} : {type} # XXX: Fallback of {org_type}"


from io import TextIOWrapper


def topsort(graph: dict):
    """Return the nodes of ``graph`` in dependency-first (DFS post-order) order.

    ``graph`` maps each node to an iterable of nodes it depends on; dependencies
    do not need to be keys themselves. Every key and every dependency appears
    exactly once in the result, after its dependencies whenever that is possible.
    Dependencies are visited in sorted order, so the result is deterministic for
    a given input.

    Cycles (e.g. two classes holding PPtr references to each other, or a class
    referencing itself) cannot be ordered. The edge that closes the cycle is
    skipped instead of dropping the nodes on it, so no class vanishes from the
    generated output.

    Args:
        graph: Mapping of node -> dependencies.

    Returns:
        List of nodes in topological order (best effort when cycles exist).
    """
    graph = {node: sorted(deps) for node, deps in graph.items()}
    # 0 = unvisited, 1 = on the current DFS path, 2 = finished
    state = defaultdict(int)
    order = list()

    def visit(node):
        state[node] = 1
        for dep in graph.get(node, ()):
            # state 1 is a back edge (cycle) and is skipped; state 2 is already emitted
            if state[dep] == 0:
                visit(dep)
        state[node] = 2
        order.append(node)

    for node in graph:
        if not state[node]:
            visit(node)
    return order


def process_namespace(
    f: TextIOWrapper,
    classname_nodes: Dict[str, List[TypeTreeNode]],
    namespace: str = "",
    import_root: str = "",
    import_defs: dict = None,
):
    """Write the generated module for one namespace to ``f``.

    Emits the file preamble and imports, orders the classes with ``topsort`` so
    that (where possible) a class follows the types its fields reference, then
    emits each class as an ``@UTTCGen``-decorated class with one annotated
    attribute per field.

    Args:
        f: Open text handle of the namespace's ``__init__.py``.
        classname_nodes: Class name (without namespace) -> TypeTree nodes.
        namespace: Dotted namespace name; empty/None for the default namespace.
        import_root: Relative import prefix of the generated package root
            (e.g. ``".."``); empty for the root module itself.
        import_defs: Class name -> namespace it is defined in, for classes
            imported from other namespaces. Defaults to none.

    Raises:
        ValueError: If a class's base type is neither a UnityPy class nor a
            class already emitted earlier in this namespace.
    """
    # Avoid a shared mutable default argument
    if import_defs is None:
        import_defs = {}
    # Computed outside the f-strings: reusing the outer quote type inside an
    # f-string replacement field is a SyntaxError before Python 3.12 (PEP 701).
    display_name = namespace or "<default namespace>"

    def emit_line(*lines: str):
        # Write each argument as its own line; with no arguments, write a blank line
        for line in lines:
            f.write(line)
            f.write("\n")
        if not lines:
            f.write("\n")

    def encode_node(obj):
        # json.dumps hook: serialize TypeTreeNode instances via their attribute dict
        if isinstance(obj, TypeTreeNode):
            return obj.__dict__
        return obj

    logger.info(f"Subpass 1: Generating class dependency graph for {display_name}")
    emit_line("# fmt: off")
    emit_line("# Auto-generated by https://github.com/mos9527/UnityPyTypetreeCodegen")
    emit_line(f"# Python definition for {display_name}", "")
    if import_root:
        emit_line(f"from {import_root} import *")
    for clazz, parent in import_defs.items():
        emit_line(f"from {import_root}{parent or ''} import {clazz}")

    emit_line()
    # Emit by topo order; strip=True reduces List[T] / PPtr[T] to T for the graph
    graph = {
        clazz: {
            translate_type(field.m_Type, strip=True, fallback=False) for field in fields
        }
        for clazz, fields in classname_nodes.items()
    }
    topo = topsort(graph)
    # Every type name resolvable from this module (loop invariant)
    typenames = classname_nodes | import_defs

    logger.info(f"Subpass 2: Generating code for {namespace}")
    # dp[name]: number of level-1 fields class `name` accounts for (inherited + own); -1 = unknown
    dp = defaultdict(lambda: -1)
    for clazz in topo:
        fullname = f"{namespace}.{clazz}" if namespace else clazz
        fields = classname_nodes.get(clazz, None)
        if not fields:
            # topo also contains dependencies that are not classes of this namespace
            logger.debug(
                f"Class {clazz} has no fields defined in TypeTree dump, skipped"
            )
            continue
        # Heuristic: If there is a lvl1 field, it's a subclass
        lvl1 = [(i, field) for i, field in enumerate(fields) if field.m_Level == 1]
        clazz = translate_name(clazz)
        clazz_fields = list()

        clazz_typetree = json.dumps(fields, default=encode_node)
        emit_line(f"@UTTCGen('{fullname}', {clazz_typetree})")
        if lvl1:
            # The root node's type is used as the base class
            parent = translate_type(fields[0].m_Type, strip=True, fallback=False)
            emit_line(f"class {clazz}({translate_name(parent)}):")
            if dp[parent] == -1:
                # Reuse parent's fields with best possible effort
                builtin = getattr(UnityBuiltin, parent, None)
                if not builtin:
                    raise ValueError(
                        f"Base class {parent} of {fullname} is neither a UnityPy class "
                        "nor a class emitted before it"
                    )
                dp[parent] = len(builtin.__annotations__)
            parent_dep1 = dp[parent]
            cur_dep1 = parent_dep1
            for dep, (i, field) in enumerate(lvl1):
                if dep < parent_dep1:
                    # Skip parent fields at lvl1
                    continue
                if i + 1 < len(fields) and fields[i + 1].m_Type == "Array":
                    if field.m_Type.startswith("List"):
                        # Rename this to Type[]; the element type is taken from the
                        # node three positions later (presumably Array -> size -> data)
                        field.m_Type = fields[i + 3].m_Type + "[]"
                name = field.m_Name
                type_ = translate_type(field.m_Type, typenames=typenames)
                emit_line(f"\t{declare_field(name, type_, field.m_Type)}")
                clazz_fields.append((name, type_, field.m_Type))
                cur_dep1 += 1
            dp[clazz] = cur_dep1
        else:
            # No inheritance
            emit_line(f"class {clazz}:")
            for field in fields:
                name = field.m_Name
                type_ = translate_type(field.m_Type, typenames=typenames)
                emit_line(f"\t{declare_field(name, type_, field.m_Type)}")
                clazz_fields.append((name, type_, field.m_Type))
            dp[clazz] = len(fields)
        if not clazz_fields:
            # Empty class. Consider MRO
            emit_line("\tpass")


def process_typetree(fullname_nodes: Dict[str, List[TypeTreeNode]], outdir: str):
    """Generate the Python package for a whole TypeTree dump under ``outdir``.

    Pass 1 groups classes by namespace (everything before the last ``.`` of the
    full name). Pass 2 converts raw node dicts to ``TypeTreeNode`` in place and
    records which classes each namespace needs from other namespaces. Pass 3
    writes ``HEADER`` to ``outdir/__init__.py`` and one ``__init__.py`` per
    namespace in a matching sub-directory.

    A class name defined in several namespaces is logged as an error; imports
    then resolve to its first definition.

    Args:
        fullname_nodes: Fully-qualified class name -> TypeTree nodes
            (``TypeTreeNode`` instances or dicts of their fields).
        outdir: Root directory of the generated package.
    """
    handles: Dict[str, TextIOWrapper] = dict()

    def __open(fname: str):
        # One shared handle per output file: the root __init__.py receives both
        # HEADER and the default namespace's classes.
        fname = os.path.join(outdir, fname)
        if fname not in handles:
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            # Python source is UTF-8 by default (PEP 3120); don't depend on the locale
            handles[fname] = open(fname, "w", encoding="utf-8")
        return handles[fname]

    try:
        namespaces = defaultdict(dict)
        namespacesT = dict()  # class name -> namespace of its first definition
        logger.info("Pass 1: Building namespace")
        for key in fullname_nodes:
            fullkey = key.split(".")
            if len(fullkey) == 1:
                namespace, clazz = None, fullkey[0]
            else:
                namespace, clazz = ".".join(fullkey[:-1]), fullkey[-1]
            namespaces[namespace][clazz] = fullname_nodes[key]
            if clazz not in namespacesT:
                namespacesT[clazz] = namespace
            else:
                logger.error(
                    f"Class {clazz} already defined in {namespacesT[clazz]} but found again in {namespace}"
                )
                logger.error(
                    "Need manual intervention to resolve the conflict. Using first definition for now."
                )
        logger.info("Pass 2: Generating import graph")
        # Build import graph
        namespaceDeps = defaultdict(set)
        for namespace, classname_nodes in namespaces.items():
            for fields in classname_nodes.values():
                for i, field in enumerate(fields):
                    if not isinstance(field, TypeTreeNode):
                        # Raw dicts (e.g. from a JSON dump) are converted in place
                        field = fields[i] = TypeTreeNode(**field)

                    ftype = translate_type(field.m_Type, strip=True, fallback=False)
                    if ftype in namespacesT and namespacesT[ftype] != namespace:
                        namespaceDeps[namespace].add(ftype)

        logger.info("Pass 3: Emitting namespace as Python modules")
        __open("__init__.py").write(HEADER)
        # XXX: This part can be trivally parallelized
        for namespace, classname_nodes in sorted(
            namespaces.items(), key=lambda x: x[0].count(".") if x[0] else 0
        ):
            # CubismTaskHandler -> generated/__init__.py
            # Live2D.Cubism.Core.CubismMoc -> generated/Live2D/Cubism/Core/__init__.py
            if namespace:
                # Package depth + 1 dots reach the generated package root
                dotss = "." * (namespace.count(".") + 2)
                f = __open(os.path.join(*namespace.split("."), "__init__.py"))
                deps = dict(sorted((k, namespacesT[k]) for k in namespaceDeps[namespace]))
                process_namespace(f, classname_nodes, namespace, dotss, deps)
            else:
                f = __open("__init__.py")
                process_namespace(f, classname_nodes, namespace)
    finally:
        # Flush and release every output file, even if generation failed midway
        for handle in handles.values():
            handle.close()


import re, logging
from UnityPy.helpers.TypeTreeGenerator import TypeTreeGenerator


def __main__():
    """Command-line entry point.

    Loads a TypeTree dump from game assemblies (``--files``) and/or a JSON file
    (``--json``; it replaces the ``--files`` result when both are given), keeps
    the class names matching ``--filter``, then regenerates ``--outdir``.

    Returns:
        0 on success, -1 when no typetree could be loaded (``--outdir`` is then
        left untouched).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--files",
        help="Load typetree dump from game assembly folder containing the DLLs",
    )
    parser.add_argument(
        "--json",
        help="Load tree dump in json format {str[fullname]: List[TypeTreeNode]},...",
    )
    parser.add_argument(
        "--unity-version",
        help="Unity version to use for typetree generation",
        default="2022.3.21f1",
    )
    parser.add_argument(
        "--filter",
        help="Filter classnames by regex",
        default=".*",
    )
    parser.add_argument(
        "--log-level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        default="WARNING",
    )
    parser.add_argument(
        "--outdir",
        help="Output directory for generated code",
        default="generated",
    )
    args = parser.parse_args()
    logging.basicConfig(level=args.log_level)
    typetree = dict()
    if args.files:
        generator = TypeTreeGenerator(args.unity_version)
        generator.load_local_dll_folder(args.files)
        for module, fullname in generator.get_monobehavior_definitions():
            try:
                typetree[fullname] = generator.get_nodes(module, fullname)
            except Exception:
                # Best effort: skip this type but keep the traceback for diagnosis
                logger.exception(
                    "Failed to generate typetree for %s.%s", module, fullname
                )
    if args.json:
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding
        with open(args.json, "r", encoding="utf-8") as f:
            typetree = json.load(f)
    if not typetree:
        return -1
    # Only wipe the previous output once there is something to regenerate
    shutil.rmtree(args.outdir, ignore_errors=True)
    os.makedirs(args.outdir, exist_ok=True)
    regex = re.compile(args.filter)
    typetree = {k: v for k, v in typetree.items() if regex.match(k)}
    process_typetree(typetree, args.outdir)
    return 0


import sys

if __name__ == "__main__":
    # Use __main__()'s return value as the process exit status
    raise SystemExit(__main__())
